Classe pour entraîner un autoencodeur
This commit is contained in:
parent
bcaeb4cb1f
commit
88acd0d60d
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.models import load_model
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class AbstractModel:
    """Base class for Keras model wrappers: saving, loading and prediction.

    Subclasses are expected to build a compiled Keras model into
    ``self.model`` (see e.g. ``SimpleAutoEncoder.build_model``).
    """

    def __init__(self, save_directory, model_name):
        # Directory where model snapshots (.hdf5) are persisted.
        self.save_directory = save_directory
        # Built by subclasses or by load(); None until then.
        self.model = None
        self.model_name = model_name

    def save(self):
        """Persist the current model to ``<save_directory>/final_model.hdf5``."""
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.save_directory, exist_ok=True)
        self.model.save('{}/final_model.hdf5'.format(self.save_directory))

    def load(self, which='final_model'):
        """Load a previously saved snapshot (default: the final model)."""
        self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which))

    def predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None):
        """Run batch prediction on ``x``.

        Bug fix: ``callbacks`` was accepted but silently dropped; it is now
        forwarded to the underlying Keras ``Model.predict`` call.
        """
        return self.model.predict(x, batch_size=batch_size, verbose=verbose,
                                  steps=steps, callbacks=callbacks)

    def predict_one(self, x, batch_size=1, verbose=0, steps=None):
        """Predict a single sample by prepending a batch dimension of 1."""
        x = np.expand_dims(x, axis=0)
        return self.predict(x, batch_size, verbose, steps)
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.callbacks import Callback
|
||||
import numpy as np
|
||||
from iss.tools.tools import Tools
|
||||
from IPython.display import display
|
||||
|
||||
class DisplayPictureCallback(Callback):
    """Keras callback that, every ``epoch_laps`` epochs, displays an input
    picture next to its autoencoder reconstruction to track progress visually
    (intended for notebook use via IPython ``display``).
    """

    def __init__(self, model, epoch_laps, data_loader):
        # `model` is the AbstractModel wrapper (provides predict_one), not the
        # raw Keras model — hence the distinct attribute name.
        self.model_class = model
        self.epoch_laps = epoch_laps
        self.data_loader = data_loader
        super(DisplayPictureCallback, self).__init__()

    def on_epoch_end(self, epoch, logs=None):
        # Bug fix: the Keras Callback API declares on_epoch_end(epoch, logs=None);
        # the parameter previously had no default. Debug print("ok") removed.
        if epoch % self.epoch_laps == 0:
            # Grab one picture from the next batch; with class_mode='input'
            # a batch is (inputs, targets), so [0][1] is the 2nd input picture.
            input_pict = self.data_loader.next()[0][1]
            output_pict = self.model_class.predict_one(input_pict)

            display(Tools.display_one_picture_scaled(input_pict))
            display(Tools.display_index_picture_scaled(output_pict))

        return self
|
@ -0,0 +1,53 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
class ImageDataGeneratorWrapper:
    """Wraps a Keras ``ImageDataGenerator`` and exposes train/test directory
    flows configured from the project config.
    """

    def __init__(self, config):
        self.config = config
        self.datagen = None
        self.train_generator = None
        self.test_generator = None

        self.image_data_generator(config)

        self.set_train_generator()
        self.set_test_generator()

    def image_data_generator(self, config):
        """Build the underlying generator; pixels are rescaled to [0, 1]."""
        self.datagen = ImageDataGenerator(
            rescale=1. / 255
        )
        return self

    def build_generator(self, directory):
        """Create a directory flow for autoencoder training.

        ``class_mode='input'`` makes each target identical to its input.
        """
        # TODO: revisit later whether more parameters should be configurable.
        simple_cfg = self.config.get('models')['simple']
        return self.datagen.flow_from_directory(
            directory,
            target_size=(simple_cfg['input_height'], simple_cfg['input_width']),
            color_mode='rgb',
            classes=None,
            class_mode='input',
            batch_size=simple_cfg['batch_size'],
        )

    def set_train_generator(self):
        # '/..' presumably points flow_from_directory at the parent of the
        # configured folder (it expects a directory of class subfolders) —
        # TODO confirm against the config layout.
        train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        return self.train_generator

    def set_test_generator(self):
        test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        # Bug fix: previously returned the TRAIN generator (copy-paste error),
        # so validation silently ran on training data.
        return self.test_generator
|
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.callbacks import ModelCheckpoint, CSVLogger
|
||||
from iss.models.Callbacks import DisplayPictureCallback
|
||||
from iss.tools.tools import Tools
|
||||
|
||||
|
||||
class ModelTrainer:
    """Drives ``fit_generator`` training of a wrapped model and builds the
    configured callbacks (CSV logging, checkpointing, picture display).

    Parameters
    ----------
    model : AbstractModel wrapper exposing ``.model``, ``.model_name`` and ``.save()``.
    data_loader : object providing ``get_train_generator()`` / ``get_test_generator()``.
    config : dict of training hyper-parameters plus a ``'callbacks'`` sub-dict.
    callbacks : optional list of extra Keras callbacks to register.
    """

    def __init__(self, model, data_loader, config, callbacks=None):
        self.model = model
        self.data_loader = data_loader

        self.epochs = config['epochs']
        self.verbose = config['verbose']
        self.initial_epoch = config['initial_epoch']
        self.workers = config['workers']
        self.use_multiprocessing = config['use_multiprocessing']
        self.steps_per_epoch = config['steps_per_epoch']
        self.validation_steps = config['validation_steps']
        self.validation_freq = config['validation_freq']
        # Bug fix: the default was a mutable `callbacks=[]`, shared across ALL
        # instances — init_callbacks() then accumulated callbacks globally.
        # Copy the caller's list so their argument is never mutated either.
        self.callbacks = list(callbacks) if callbacks is not None else []

        self.init_callbacks(config)

    def train(self):
        """Run training on the train generator, validate on the test
        generator, then persist the final model."""
        self.model.model.fit_generator(
            generator=self.data_loader.get_train_generator(),
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.epochs,
            verbose=self.verbose,
            initial_epoch=self.initial_epoch,
            callbacks=self.callbacks,
            workers=self.workers,
            use_multiprocessing=self.use_multiprocessing,
            validation_data=self.data_loader.get_test_generator(),
            validation_steps=self.validation_steps,
            validation_freq=self.validation_freq
        )

        self.model.save()

    def init_callbacks(self, config):
        """Append the callbacks enabled in ``config['callbacks']``
        (csv_logger / checkpoint / display_picture) to ``self.callbacks``."""

        if 'csv_logger' in config['callbacks']:
            log_dir = config['callbacks']['csv_logger']['directory']
            Tools.create_dir_if_not_exists(log_dir)

            self.csv_logger = CSVLogger(
                filename='{}/training.log'.format(log_dir),
                append=config['callbacks']['csv_logger']['append']
            )
            self.callbacks.append(self.csv_logger)

        if 'checkpoint' in config['callbacks']:
            checkpoint_dir = config['callbacks']['checkpoint']['directory']
            Tools.create_dir_if_not_exists(checkpoint_dir)
            # Save periodic snapshots named <model_name>-<epoch>.hdf5.
            self.checkpointer = ModelCheckpoint(
                filepath=checkpoint_dir + '/' + self.model.model_name + '-{epoch:02d}.hdf5',
                verbose=config['callbacks']['checkpoint']['verbose'],
                period=config['callbacks']['checkpoint']['period']
            )
            self.callbacks.append(self.checkpointer)

        if 'display_picture' in config['callbacks']:
            self.picture_displayer = DisplayPictureCallback(
                model=self.model,
                data_loader=self.data_loader.get_train_generator(),
                epoch_laps=config['callbacks']['display_picture']['epoch_laps']
            )
            self.callbacks.append(self.picture_displayer)

        return self
|
@ -0,0 +1,49 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from iss.models.AbstractModel import AbstractModel
|
||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
|
||||
from keras.optimizers import Adadelta
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
|
||||
class SimpleAutoEncoder(AbstractModel):
    """Fully-connected autoencoder (2000 → 100 → 30 → 100 → 2000) over the
    flattened picture, trained with per-pixel binary cross-entropy.
    """

    def __init__(self, config):
        save_directory = config['save_directory']
        model_name = config['model_name']

        super().__init__(save_directory, model_name)

        # (height, width, channels) of the input pictures.
        self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
        self.lr = config['learning_rate']
        self.build_model()

    def build_model(self):
        """Build and compile the encoder/decoder stack into ``self.model``."""
        input_shape = self.input_shape

        picture = Input(shape=input_shape)

        # Encoder: flatten, then compress down to a 30-dim bottleneck.
        x = Flatten()(picture)
        layer_1 = Dense(2000, activation='relu', name='enc_1')(x)
        layer_2 = Dense(100, activation='relu', name='enc_2')(layer_1)
        layer_3 = Dense(30, activation='relu', name='enc_3')(layer_2)
        # Decoder: mirror of the encoder.
        layer_4 = Dense(100, activation='relu', name='dec_1')(layer_3)
        layer_5 = Dense(2000, activation='relu', name='dec_2')(layer_4)

        # Bug fix: the output activation was 'softmax', which normalizes the
        # WHOLE flattened picture to sum to 1 — wrong for reconstruction.
        # 'sigmoid' maps each pixel independently into [0, 1], matching both
        # the binary cross-entropy loss and the 1/255 input rescaling.
        x = Dense(np.prod(input_shape), activation='sigmoid')(layer_5)
        decoded = Reshape(input_shape)(x)

        self.model = Model(picture, decoded)

        optimizer = Adadelta(lr=self.lr, rho=0.95, epsilon=None, decay=0.0)

        self.model.compile(optimizer=optimizer, loss='binary_crossentropy')
Loading…
Reference in New Issue