From 88acd0d60d2c74e5158a67b96b88e536147c8cbe Mon Sep 17 00:00:00 2001 From: Francois Date: Sat, 9 Mar 2019 22:41:37 +0100 Subject: [PATCH] class pour entrainer un autoencoder --- iss/models/AbstractModel.py | 28 +++++++++++++ iss/models/Callbacks.py | 31 ++++++++++++++ iss/models/DataLoader.py | 53 +++++++++++++++++++++++ iss/models/ModelTrainer.py | 74 +++++++++++++++++++++++++++++++++ iss/models/SimpleAutoEncoder.py | 49 ++++++++++++++++++++++ 5 files changed, 235 insertions(+) create mode 100644 iss/models/AbstractModel.py create mode 100644 iss/models/Callbacks.py create mode 100644 iss/models/DataLoader.py create mode 100644 iss/models/ModelTrainer.py create mode 100644 iss/models/SimpleAutoEncoder.py diff --git a/iss/models/AbstractModel.py b/iss/models/AbstractModel.py new file mode 100644 index 0000000..553a031 --- /dev/null +++ b/iss/models/AbstractModel.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +from keras.models import load_model +import numpy as np +import os + +class AbstractModel: + def __init__(self, save_directory, model_name): + self.save_directory = save_directory + self.model = None + self.model_name = model_name + + def save(self): + if not os.path.exists(self.save_directory): + os.makedirs(self.save_directory) + + self.model.save('{}/final_model.hdf5'.format(self.save_directory)) + + + def load(self, which = 'final_model'): + self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which)) + + def predict(self, x, batch_size = None, verbose = 0, steps = None, callbacks = None): + return self.model.predict(x, batch_size, verbose, steps) + + def predict_one(self, x, batch_size = 1, verbose = 0, steps = None): + x = np.expand_dims(x, axis = 0) + return self.predict(x, batch_size, verbose, steps) diff --git a/iss/models/Callbacks.py b/iss/models/Callbacks.py new file mode 100644 index 0000000..fd6c51b --- /dev/null +++ b/iss/models/Callbacks.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +from keras.callbacks import Callback +import numpy as np +from iss.tools.tools import Tools +from IPython.display import display + +class DisplayPictureCallback(Callback): + + def __init__(self, model, epoch_laps, data_loader): + + self.model_class = model + self.epoch_laps = epoch_laps + self.data_loader = data_loader + super(DisplayPictureCallback, self).__init__() + + + def on_epoch_end(self, epoch, logs): + if epoch % self.epoch_laps == 0: + + print("ok") + + input_pict = self.data_loader.next()[0][1] + output_pict = self.model_class.predict_one(input_pict) + + display(Tools.display_one_picture_scaled(input_pict)) + display(Tools.display_index_picture_scaled(output_pict)) + + return self + + diff --git a/iss/models/DataLoader.py b/iss/models/DataLoader.py new file mode 100644 index 0000000..3c1bd47 --- /dev/null +++ b/iss/models/DataLoader.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +from keras.preprocessing.image import ImageDataGenerator + +class ImageDataGeneratorWrapper: + + def __init__(self, config): + + self.config = config + self.datagen = None + self.train_generator = None + self.test_generator = None + + self.image_data_generator(config) + + self.set_train_generator() + self.set_test_generator() + + def image_data_generator(self, config): + self.datagen = ImageDataGenerator( + rescale = 1./255 + ) + return self + + def build_generator(self, directory): + # voir plus tars si besoin de parametrer + return self.datagen.flow_from_directory( + directory, + target_size = (self.config.get('models')['simple']['input_height'], self.config.get('models')['simple']['input_width']), + color_mode = 'rgb', + classes = None, + class_mode = 'input', + batch_size = self.config.get('models')['simple']['batch_size'], + ) + + def set_train_generator(self): + train_dir = self.config.get('directory')['autoencoder']['train'] + '/..' + self.train_generator = self.build_generator(directory = train_dir) + return self + + def get_train_generator(self): + return self.train_generator + + def set_test_generator(self): + test_dir = self.config.get('directory')['autoencoder']['test'] + '/..' + self.test_generator = self.build_generator(directory = test_dir) + return self + + def get_test_generator(self): + return self.train_generator + + + diff --git a/iss/models/ModelTrainer.py b/iss/models/ModelTrainer.py new file mode 100644 index 0000000..7af8a1a --- /dev/null +++ b/iss/models/ModelTrainer.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +from keras.callbacks import ModelCheckpoint, CSVLogger +from iss.models.Callbacks import DisplayPictureCallback +from iss.tools.tools import Tools + + +class ModelTrainer: + + def __init__(self, model, data_loader, config, callbacks=[]): + self.model = model + self.data_loader = data_loader + + self.epochs = config['epochs'] + self.verbose = config['verbose'] + self.initial_epoch = config['initial_epoch'] + self.workers = config['workers'] + self.use_multiprocessing = config['use_multiprocessing'] + self.steps_per_epoch = config['steps_per_epoch'] + self.validation_steps = config['validation_steps'] + self.validation_freq = config['validation_freq'] + self.callbacks = callbacks + + self.init_callbacks(config) + + def train(self): + + self.model.model.fit_generator( + generator = self.data_loader.get_train_generator(), + steps_per_epoch = self.steps_per_epoch, + epochs = self.epochs, + verbose = self.verbose, + initial_epoch = self.initial_epoch, + callbacks = self.callbacks, + workers = self.workers, + use_multiprocessing = self.use_multiprocessing, + validation_data = self.data_loader.get_test_generator(), + validation_steps = self.validation_steps, + validation_freq = self.validation_freq + ) + + self.model.save() + + def init_callbacks(self, config): + + if 'csv_logger' in config['callbacks']: + log_dir = config['callbacks']['csv_logger']['directory'] + Tools.create_dir_if_not_exists(log_dir) + + self.csv_logger = CSVLogger( + filename = '{}/training.log'.format(log_dir), + append = config['callbacks']['csv_logger']['append'] + ) + self.callbacks.extend([self.csv_logger]) + + if 'checkpoint' in config['callbacks']: + chekpt_dir = config['callbacks']['checkpoint']['directory'] + Tools.create_dir_if_not_exists(chekpt_dir) + self.checkpointer = ModelCheckpoint( + filepath = chekpt_dir + '/' + self.model.model_name + '-{epoch:02d}.hdf5', + verbose = config['callbacks']['checkpoint']['verbose'], + period = config['callbacks']['checkpoint']['period'] + ) + self.callbacks.extend([self.checkpointer]) + + if 'display_picture' in config['callbacks']: + self.picture_displayer = DisplayPictureCallback( + model = self.model, + data_loader = self.data_loader.get_train_generator(), + epoch_laps = config['callbacks']['display_picture']['epoch_laps'] + ) + self.callbacks.extend([self.picture_displayer]) + + return self diff --git a/iss/models/SimpleAutoEncoder.py b/iss/models/SimpleAutoEncoder.py new file mode 100644 index 0000000..0b011d0 --- /dev/null +++ b/iss/models/SimpleAutoEncoder.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +from iss.models.AbstractModel import AbstractModel +from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten +from keras.optimizers import Adadelta +from keras.models import Model +import numpy as np + +class SimpleAutoEncoder(AbstractModel): + + def __init__(self, config): + + save_directory = config['save_directory'] + model_name = config['model_name'] + + super().__init__(save_directory, model_name) + + self.input_shape = (config['input_height'], config['input_width'], config['input_channel']) + self.lr = config['learning_rate'] + self.build_model() + + def build_model(self): + input_shape = self.input_shape + + picture = Input(shape = input_shape) + + x = Flatten()(picture) + layer_1 = Dense(2000, activation = 'relu', name = 'enc_1')(x) + layer_2 = Dense(100, activation = 'relu', name = 'enc_2')(layer_1) + layer_3 = Dense(30, activation = 'relu', name = 'enc_3')(layer_2) + layer_4 = Dense(100, activation = 'relu', name = 'dec_1')(layer_3) + layer_5 = Dense(2000, activation = 'relu', name = 'dec_2')(layer_4) + + # encoded network + # x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_1')(picture) + # encoded = MaxPooling2D((2, 2))(x) + + # decoded network + # x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_1')(encoded) + # x = UpSampling2D((2, 2))(x) + # x = Flatten()(x) + x = Dense(np.prod(input_shape), activation = 'softmax')(layer_5) + decoded = Reshape((input_shape))(x) + + self.model = Model(picture, decoded) + + optimizer = Adadelta(lr = self.lr, rho = 0.95, epsilon = None, decay = 0.0) + + self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')