diff --git a/iss/models/AbstractModel.py b/iss/models/AbstractModel.py index 0af2c5f..b443854 100644 --- a/iss/models/AbstractModel.py +++ b/iss/models/AbstractModel.py @@ -14,7 +14,7 @@ class AbstractModel: if not os.path.exists(self.save_directory): os.makedirs(self.save_directory) - self.model.save('{}/final_model.hdf5'.format(self.save_directory)) + self.model.save('{}/final_{}.hdf5'.format(self.save_directory, self.model_name)) def load(self, which = 'final_model'): @@ -34,3 +34,24 @@ class AbstractAutoEncoderModel(AbstractModel): super().__init__(save_directory, model_name) self.encoder_model = None self.decoder_model = None + + def get_encoded_prediction(self, pictures): + return self.encoder_model.predict(pictures) + + def get_full_encoded_prediction(self, generator, nb_batch = None): + + generator.reset() + div = np.divmod(generator.n, generator.batch_size) + + if nb_batch is None: + nb_batch = div[0] + 1 * (div[1] != 0) - 1 + + if nb_batch <= 0: + return + + predictions = self.get_encoded_prediction(generator.next()[1]) + while generator.batch_index <= (nb_batch - 1): + predictions = np.concatenate((predictions, self.get_encoded_prediction(generator.next()[1]) ), axis = 0) + + return predictions + diff --git a/iss/models/ModelTrainer.py b/iss/models/ModelTrainer.py index 58c5e62..92b2d91 100644 --- a/iss/models/ModelTrainer.py +++ b/iss/models/ModelTrainer.py @@ -64,7 +64,7 @@ class ModelTrainer: Tools.create_dir_if_not_exists(log_dir) self.csv_logger = CSVLogger( - filename = '{}/training.log'.format(log_dir), + filename = '{}/{}training.log'.format(log_dir, self.model.model_name), append = config['callbacks']['csv_logger']['append'] ) self.callbacks.extend([self.csv_logger])