1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-03 06:03:10 +02:00
smart-iss-posts/iss/models/AbstractModel.py
2019-03-13 15:25:46 +01:00

36 lines
1,017 B
Python

# -*- coding: utf-8 -*-
from keras.models import load_model
import numpy as np
import os
class AbstractModel:
    """Thin wrapper around a Keras model offering save / load / predict helpers.

    ``self.model`` starts out as ``None``; it must be assigned (for example
    through :meth:`load`) before :meth:`save`, :meth:`predict` or
    :meth:`predict_one` can be used.
    """

    def __init__(self, save_directory, model_name):
        """Remember where the model is persisted and what it is called.

        Args:
            save_directory: Directory used by :meth:`save` and :meth:`load`.
            model_name: Name stored on the instance as ``model_name``.
        """
        self.save_directory = save_directory
        self.model = None
        self.model_name = model_name

    def save(self):
        """Write the model to ``<save_directory>/final_model.hdf5``.

        The directory is created first when it does not exist yet.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.save_directory, exist_ok=True)
        self.model.save('{}/final_model.hdf5'.format(self.save_directory))

    def load(self, which='final_model'):
        """Load ``<save_directory>/<which>.hdf5`` into ``self.model``.

        Args:
            which: File name, without the ``.hdf5`` extension, to load.
        """
        self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which))

    def predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None):
        """Delegate to ``self.model.predict`` and return its result.

        Args:
            x: Input passed through unchanged to the underlying model.
            batch_size: Forwarded as ``batch_size``.
            verbose: Forwarded as ``verbose``.
            steps: Forwarded as ``steps``.
            callbacks: Forwarded as ``callbacks`` when not ``None``.

        Returns:
            Whatever ``self.model.predict`` returns.
        """
        # Keyword arguments keep the call independent of the backend's
        # positional parameter order.
        options = {'batch_size': batch_size, 'verbose': verbose, 'steps': steps}
        # Only forward callbacks when the caller supplied some, so backends
        # whose predict() has no such parameter keep working for default calls.
        if callbacks is not None:
            options['callbacks'] = callbacks
        return self.model.predict(x, **options)

    def predict_one(self, x, batch_size=1, verbose=0, steps=None):
        """Predict for a single sample.

        A new leading axis is inserted in front of ``x`` before delegating
        to :meth:`predict`.

        Args:
            x: One sample, without a leading batch axis.
            batch_size: Forwarded to :meth:`predict` (defaults to 1).
            verbose: Forwarded to :meth:`predict`.
            steps: Forwarded to :meth:`predict`.

        Returns:
            The result of :meth:`predict` on the one-element batch.
        """
        single_batch = np.expand_dims(x, axis=0)
        return self.predict(single_batch, batch_size, verbose, steps)
class AbstractAutoEncoderModel(AbstractModel):
    """Auto-encoder flavoured base class built on top of ``AbstractModel``.

    On top of what the parent initialiser sets, instances carry an
    ``encoded_layer`` attribute that starts out as ``None``.
    """

    def __init__(self, save_directory, model_name):
        """Run the parent initialiser, then set ``encoded_layer`` to ``None``.

        Args:
            save_directory: Passed through to the parent initialiser.
            model_name: Passed through to the parent initialiser.
        """
        super(AbstractAutoEncoderModel, self).__init__(save_directory, model_name)
        # Placeholder only; presumably concrete subclasses assign the
        # encoder's output layer here -- confirm against subclasses.
        self.encoded_layer = None