mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-04 22:53:09 +02:00
name of model
This commit is contained in:
parent
d1b24f4891
commit
3e2147f146
|
@ -14,7 +14,7 @@ class AbstractModel:
|
||||||
if not os.path.exists(self.save_directory):
|
if not os.path.exists(self.save_directory):
|
||||||
os.makedirs(self.save_directory)
|
os.makedirs(self.save_directory)
|
||||||
|
|
||||||
self.model.save('{}/final_model.hdf5'.format(self.save_directory))
|
self.model.save('{}/final_{}.hdf5'.format(self.save_directory, self.model_name))
|
||||||
|
|
||||||
|
|
||||||
def load(self, which = 'final_model'):
|
def load(self, which = 'final_model'):
|
||||||
|
@ -34,3 +34,24 @@ class AbstractAutoEncoderModel(AbstractModel):
|
||||||
super().__init__(save_directory, model_name)
|
super().__init__(save_directory, model_name)
|
||||||
self.encoder_model = None
|
self.encoder_model = None
|
||||||
self.decoder_model = None
|
self.decoder_model = None
|
||||||
|
|
||||||
|
def get_encoded_prediction(self, pictures):
|
||||||
|
return self.encoder_model.predict(pictures)
|
||||||
|
|
||||||
|
def get_full_encoded_prediction(self, generator, nb_batch = None):
|
||||||
|
|
||||||
|
generator.reset()
|
||||||
|
div = np.divmod(generator.n, generator.batch_size)
|
||||||
|
|
||||||
|
if nb_batch is None:
|
||||||
|
nb_batch = div[0] + 1 * (div[1] != 0) - 1
|
||||||
|
|
||||||
|
if nb_batch <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
predictions = self.get_encoded_prediction(generator.next()[1])
|
||||||
|
while generator.batch_index <= (nb_batch - 1):
|
||||||
|
predictions = np.concatenate((predictions, self.get_encoded_prediction(generator.next()[1]) ), axis = 0)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ class ModelTrainer:
|
||||||
Tools.create_dir_if_not_exists(log_dir)
|
Tools.create_dir_if_not_exists(log_dir)
|
||||||
|
|
||||||
self.csv_logger = CSVLogger(
|
self.csv_logger = CSVLogger(
|
||||||
filename = '{}/training.log'.format(log_dir),
|
filename = '{}/{}training.log'.format(log_dir, self.model.model_name),
|
||||||
append = config['callbacks']['csv_logger']['append']
|
append = config['callbacks']['csv_logger']['append']
|
||||||
)
|
)
|
||||||
self.callbacks.extend([self.csv_logger])
|
self.callbacks.extend([self.csv_logger])
|
||||||
|
|
Loading…
Reference in a new issue