diff --git a/iss/models/ModelTrainer.py b/iss/models/ModelTrainer.py index 7af8a1a..73fd30a 100644 --- a/iss/models/ModelTrainer.py +++ b/iss/models/ModelTrainer.py @@ -3,7 +3,7 @@ from keras.callbacks import ModelCheckpoint, CSVLogger from iss.models.Callbacks import DisplayPictureCallback from iss.tools.tools import Tools - +import keras class ModelTrainer: @@ -25,6 +25,8 @@ class ModelTrainer: def train(self): + print(keras.__version__) + self.model.model.fit_generator( generator = self.data_loader.get_train_generator(), steps_per_epoch = self.steps_per_epoch,