mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-04-25 10:40:26 +02:00
name of model
This commit is contained in:
parent
d1b24f4891
commit
3e2147f146
|
@ -14,7 +14,7 @@ class AbstractModel:
|
|||
if not os.path.exists(self.save_directory):
|
||||
os.makedirs(self.save_directory)
|
||||
|
||||
self.model.save('{}/final_model.hdf5'.format(self.save_directory))
|
||||
self.model.save('{}/final_{}.hdf5'.format(self.save_directory, self.model_name))
|
||||
|
||||
|
||||
def load(self, which = 'final_model'):
|
||||
|
@ -34,3 +34,24 @@ class AbstractAutoEncoderModel(AbstractModel):
|
|||
super().__init__(save_directory, model_name)
|
||||
self.encoder_model = None
|
||||
self.decoder_model = None
|
||||
|
||||
def get_encoded_prediction(self, pictures):
|
||||
return self.encoder_model.predict(pictures)
|
||||
|
||||
def get_full_encoded_prediction(self, generator, nb_batch = None):
|
||||
|
||||
generator.reset()
|
||||
div = np.divmod(generator.n, generator.batch_size)
|
||||
|
||||
if nb_batch is None:
|
||||
nb_batch = div[0] + 1 * (div[1] != 0) - 1
|
||||
|
||||
if nb_batch <= 0:
|
||||
return
|
||||
|
||||
predictions = self.get_encoded_prediction(generator.next()[1])
|
||||
while generator.batch_index <= (nb_batch - 1):
|
||||
predictions = np.concatenate((predictions, self.get_encoded_prediction(generator.next()[1]) ), axis = 0)
|
||||
|
||||
return predictions
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ class ModelTrainer:
|
|||
Tools.create_dir_if_not_exists(log_dir)
|
||||
|
||||
self.csv_logger = CSVLogger(
|
||||
filename = '{}/training.log'.format(log_dir),
|
||||
filename = '{}/{}training.log'.format(log_dir, self.model.model_name),
|
||||
append = config['callbacks']['csv_logger']['append']
|
||||
)
|
||||
self.callbacks.extend([self.csv_logger])
|
||||
|
|
Loading…
Reference in a new issue