1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-04-25 10:40:26 +02:00

name of model

This commit is contained in:
Francois Vieille 2019-03-24 17:57:37 +01:00
parent d1b24f4891
commit 3e2147f146
2 changed files with 23 additions and 2 deletions

View file

@ -14,7 +14,7 @@ class AbstractModel:
if not os.path.exists(self.save_directory):
os.makedirs(self.save_directory)
self.model.save('{}/final_model.hdf5'.format(self.save_directory))
self.model.save('{}/final_{}.hdf5'.format(self.save_directory, self.model_name))
def load(self, which = 'final_model'):
@ -34,3 +34,24 @@ class AbstractAutoEncoderModel(AbstractModel):
super().__init__(save_directory, model_name)
self.encoder_model = None
self.decoder_model = None
def get_encoded_prediction(self, pictures):
return self.encoder_model.predict(pictures)
def get_full_encoded_prediction(self, generator, nb_batch = None):
generator.reset()
div = np.divmod(generator.n, generator.batch_size)
if nb_batch is None:
nb_batch = div[0] + 1 * (div[1] != 0) - 1
if nb_batch <= 0:
return
predictions = self.get_encoded_prediction(generator.next()[1])
while generator.batch_index <= (nb_batch - 1):
predictions = np.concatenate((predictions, self.get_encoded_prediction(generator.next()[1]) ), axis = 0)
return predictions

View file

@ -64,7 +64,7 @@ class ModelTrainer:
Tools.create_dir_if_not_exists(log_dir)
self.csv_logger = CSVLogger(
filename = '{}/training.log'.format(log_dir),
filename = '{}/{}training.log'.format(log_dir, self.model.model_name),
append = config['callbacks']['csv_logger']['append']
)
self.callbacks.extend([self.csv_logger])