1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-19 22:06:33 +02:00

fix for colab

This commit is contained in:
prise6 2019-03-10 17:50:45 +01:00
parent 7c8922940a
commit a6969908fd

View file

@ -27,19 +27,33 @@ class ModelTrainer:
print(keras.__version__) print(keras.__version__)
self.model.model.fit_generator( if self.validation_freq is not None:
generator = self.data_loader.get_train_generator(), self.model.model.fit_generator(
steps_per_epoch = self.steps_per_epoch, generator = self.data_loader.get_train_generator(),
epochs = self.epochs, steps_per_epoch = self.steps_per_epoch,
verbose = self.verbose, epochs = self.epochs,
initial_epoch = self.initial_epoch, verbose = self.verbose,
callbacks = self.callbacks, initial_epoch = self.initial_epoch,
workers = self.workers, callbacks = self.callbacks,
use_multiprocessing = self.use_multiprocessing, workers = self.workers,
validation_data = self.data_loader.get_test_generator(), use_multiprocessing = self.use_multiprocessing,
validation_steps = self.validation_steps, validation_data = self.data_loader.get_test_generator(),
validation_freq = self.validation_freq validation_steps = self.validation_steps,
) validation_freq = self.validation_freq
)
else:
self.model.model.fit_generator(
generator = self.data_loader.get_train_generator(),
steps_per_epoch = self.steps_per_epoch,
epochs = self.epochs,
verbose = self.verbose,
initial_epoch = self.initial_epoch,
callbacks = self.callbacks,
workers = self.workers,
use_multiprocessing = self.use_multiprocessing,
validation_data = self.data_loader.get_test_generator(),
validation_steps = self.validation_steps
)
self.model.save() self.model.save()