From a6969908fdba0cf6c3d9e872828df64211cdd512 Mon Sep 17 00:00:00 2001 From: prise6 Date: Sun, 10 Mar 2019 17:50:45 +0100 Subject: [PATCH] fix for colab --- iss/models/ModelTrainer.py | 40 +++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/iss/models/ModelTrainer.py b/iss/models/ModelTrainer.py index 73fd30a..58c5e62 100644 --- a/iss/models/ModelTrainer.py +++ b/iss/models/ModelTrainer.py @@ -27,19 +27,33 @@ class ModelTrainer: print(keras.__version__) - self.model.model.fit_generator( - generator = self.data_loader.get_train_generator(), - steps_per_epoch = self.steps_per_epoch, - epochs = self.epochs, - verbose = self.verbose, - initial_epoch = self.initial_epoch, - callbacks = self.callbacks, - workers = self.workers, - use_multiprocessing = self.use_multiprocessing, - validation_data = self.data_loader.get_test_generator(), - validation_steps = self.validation_steps, - validation_freq = self.validation_freq - ) + if self.validation_freq is not None: + self.model.model.fit_generator( + generator = self.data_loader.get_train_generator(), + steps_per_epoch = self.steps_per_epoch, + epochs = self.epochs, + verbose = self.verbose, + initial_epoch = self.initial_epoch, + callbacks = self.callbacks, + workers = self.workers, + use_multiprocessing = self.use_multiprocessing, + validation_data = self.data_loader.get_test_generator(), + validation_steps = self.validation_steps, + validation_freq = self.validation_freq + ) + else: + self.model.model.fit_generator( + generator = self.data_loader.get_train_generator(), + steps_per_epoch = self.steps_per_epoch, + epochs = self.epochs, + verbose = self.verbose, + initial_epoch = self.initial_epoch, + callbacks = self.callbacks, + workers = self.workers, + use_multiprocessing = self.use_multiprocessing, + validation_data = self.data_loader.get_test_generator(), + validation_steps = self.validation_steps + ) self.model.save()