class pour entrainer un autoencoder

This commit is contained in:
Francois 2019-03-09 22:41:37 +01:00
parent bcaeb4cb1f
commit 88acd0d60d
5 changed files with 235 additions and 0 deletions

View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
from keras.models import load_model
import numpy as np
import os
class AbstractModel:
def __init__(self, save_directory, model_name):
self.save_directory = save_directory
self.model = None
self.model_name = model_name
def save(self):
if not os.path.exists(self.save_directory):
os.makedirs(self.save_directory)
self.model.save('{}/final_model.hdf5'.format(self.save_directory))
def load(self, which = 'final_model'):
self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which))
def predict(self, x, batch_size = None, verbose = 0, steps = None, callbacks = None):
return self.model.predict(x, batch_size, verbose, steps)
def predict_one(self, x, batch_size = 1, verbose = 0, steps = None):
x = np.expand_dims(x, axis = 0)
return self.predict(x, batch_size, verbose, steps)

31
iss/models/Callbacks.py Normal file
View File

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
from keras.callbacks import Callback
import numpy as np
from iss.tools.tools import Tools
from IPython.display import display
class DisplayPictureCallback(Callback):
def __init__(self, model, epoch_laps, data_loader):
self.model_class = model
self.epoch_laps = epoch_laps
self.data_loader = data_loader
super(DisplayPictureCallback, self).__init__()
def on_epoch_end(self, epoch, logs):
if epoch % self.epoch_laps == 0:
print("ok")
input_pict = self.data_loader.next()[0][1]
output_pict = self.model_class.predict_one(input_pict)
display(Tools.display_one_picture_scaled(input_pict))
display(Tools.display_index_picture_scaled(output_pict))
return self

53
iss/models/DataLoader.py Normal file
View File

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
from keras.preprocessing.image import ImageDataGenerator
class ImageDataGeneratorWrapper:
def __init__(self, config):
self.config = config
self.datagen = None
self.train_generator = None
self.test_generator = None
self.image_data_generator(config)
self.set_train_generator()
self.set_test_generator()
def image_data_generator(self, config):
self.datagen = ImageDataGenerator(
rescale = 1./255
)
return self
def build_generator(self, directory):
# voir plus tars si besoin de parametrer
return self.datagen.flow_from_directory(
directory,
target_size = (self.config.get('models')['simple']['input_height'], self.config.get('models')['simple']['input_width']),
color_mode = 'rgb',
classes = None,
class_mode = 'input',
batch_size = self.config.get('models')['simple']['batch_size'],
)
def set_train_generator(self):
train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
self.train_generator = self.build_generator(directory = train_dir)
return self
def get_train_generator(self):
return self.train_generator
def set_test_generator(self):
test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
self.test_generator = self.build_generator(directory = test_dir)
return self
def get_test_generator(self):
return self.train_generator

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
from keras.callbacks import ModelCheckpoint, CSVLogger
from iss.models.Callbacks import DisplayPictureCallback
from iss.tools.tools import Tools
class ModelTrainer:
def __init__(self, model, data_loader, config, callbacks=[]):
self.model = model
self.data_loader = data_loader
self.epochs = config['epochs']
self.verbose = config['verbose']
self.initial_epoch = config['initial_epoch']
self.workers = config['workers']
self.use_multiprocessing = config['use_multiprocessing']
self.steps_per_epoch = config['steps_per_epoch']
self.validation_steps = config['validation_steps']
self.validation_freq = config['validation_freq']
self.callbacks = callbacks
self.init_callbacks(config)
def train(self):
self.model.model.fit_generator(
generator = self.data_loader.get_train_generator(),
steps_per_epoch = self.steps_per_epoch,
epochs = self.epochs,
verbose = self.verbose,
initial_epoch = self.initial_epoch,
callbacks = self.callbacks,
workers = self.workers,
use_multiprocessing = self.use_multiprocessing,
validation_data = self.data_loader.get_test_generator(),
validation_steps = self.validation_steps,
validation_freq = self.validation_freq
)
self.model.save()
def init_callbacks(self, config):
if 'csv_logger' in config['callbacks']:
log_dir = config['callbacks']['csv_logger']['directory']
Tools.create_dir_if_not_exists(log_dir)
self.csv_logger = CSVLogger(
filename = '{}/training.log'.format(log_dir),
append = config['callbacks']['csv_logger']['append']
)
self.callbacks.extend([self.csv_logger])
if 'checkpoint' in config['callbacks']:
chekpt_dir = config['callbacks']['checkpoint']['directory']
Tools.create_dir_if_not_exists(chekpt_dir)
self.checkpointer = ModelCheckpoint(
filepath = chekpt_dir + '/' + self.model.model_name + '-{epoch:02d}.hdf5',
verbose = config['callbacks']['checkpoint']['verbose'],
period = config['callbacks']['checkpoint']['period']
)
self.callbacks.extend([self.checkpointer])
if 'display_picture' in config['callbacks']:
self.picture_displayer = DisplayPictureCallback(
model = self.model,
data_loader = self.data_loader.get_train_generator(),
epoch_laps = config['callbacks']['display_picture']['epoch_laps']
)
self.callbacks.extend([self.picture_displayer])
return self

View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
from iss.models.AbstractModel import AbstractModel
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
from keras.optimizers import Adadelta
from keras.models import Model
import numpy as np
class SimpleAutoEncoder(AbstractModel):
def __init__(self, config):
save_directory = config['save_directory']
model_name = config['model_name']
super().__init__(save_directory, model_name)
self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
self.lr = config['learning_rate']
self.build_model()
def build_model(self):
input_shape = self.input_shape
picture = Input(shape = input_shape)
x = Flatten()(picture)
layer_1 = Dense(2000, activation = 'relu', name = 'enc_1')(x)
layer_2 = Dense(100, activation = 'relu', name = 'enc_2')(layer_1)
layer_3 = Dense(30, activation = 'relu', name = 'enc_3')(layer_2)
layer_4 = Dense(100, activation = 'relu', name = 'dec_1')(layer_3)
layer_5 = Dense(2000, activation = 'relu', name = 'dec_2')(layer_4)
# encoded network
# x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_1')(picture)
# encoded = MaxPooling2D((2, 2))(x)
# decoded network
# x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_1')(encoded)
# x = UpSampling2D((2, 2))(x)
# x = Flatten()(x)
x = Dense(np.prod(input_shape), activation = 'softmax')(layer_5)
decoded = Reshape((input_shape))(x)
self.model = Model(picture, decoded)
optimizer = Adadelta(lr = self.lr, rho = 0.95, epsilon = None, decay = 0.0)
self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')