mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-18 21:36:33 +02:00
class pour entrainer un autoencoder
This commit is contained in:
parent
bcaeb4cb1f
commit
88acd0d60d
28
iss/models/AbstractModel.py
Normal file
28
iss/models/AbstractModel.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from keras.models import load_model
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
class AbstractModel:
|
||||||
|
def __init__(self, save_directory, model_name):
|
||||||
|
self.save_directory = save_directory
|
||||||
|
self.model = None
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
if not os.path.exists(self.save_directory):
|
||||||
|
os.makedirs(self.save_directory)
|
||||||
|
|
||||||
|
self.model.save('{}/final_model.hdf5'.format(self.save_directory))
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, which = 'final_model'):
|
||||||
|
self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which))
|
||||||
|
|
||||||
|
def predict(self, x, batch_size = None, verbose = 0, steps = None, callbacks = None):
|
||||||
|
return self.model.predict(x, batch_size, verbose, steps)
|
||||||
|
|
||||||
|
def predict_one(self, x, batch_size = 1, verbose = 0, steps = None):
|
||||||
|
x = np.expand_dims(x, axis = 0)
|
||||||
|
return self.predict(x, batch_size, verbose, steps)
|
31
iss/models/Callbacks.py
Normal file
31
iss/models/Callbacks.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from keras.callbacks import Callback
|
||||||
|
import numpy as np
|
||||||
|
from iss.tools.tools import Tools
|
||||||
|
from IPython.display import display
|
||||||
|
|
||||||
|
class DisplayPictureCallback(Callback):
|
||||||
|
|
||||||
|
def __init__(self, model, epoch_laps, data_loader):
|
||||||
|
|
||||||
|
self.model_class = model
|
||||||
|
self.epoch_laps = epoch_laps
|
||||||
|
self.data_loader = data_loader
|
||||||
|
super(DisplayPictureCallback, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs):
|
||||||
|
if epoch % self.epoch_laps == 0:
|
||||||
|
|
||||||
|
print("ok")
|
||||||
|
|
||||||
|
input_pict = self.data_loader.next()[0][1]
|
||||||
|
output_pict = self.model_class.predict_one(input_pict)
|
||||||
|
|
||||||
|
display(Tools.display_one_picture_scaled(input_pict))
|
||||||
|
display(Tools.display_index_picture_scaled(output_pict))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
53
iss/models/DataLoader.py
Normal file
53
iss/models/DataLoader.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
class ImageDataGeneratorWrapper:
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.datagen = None
|
||||||
|
self.train_generator = None
|
||||||
|
self.test_generator = None
|
||||||
|
|
||||||
|
self.image_data_generator(config)
|
||||||
|
|
||||||
|
self.set_train_generator()
|
||||||
|
self.set_test_generator()
|
||||||
|
|
||||||
|
def image_data_generator(self, config):
|
||||||
|
self.datagen = ImageDataGenerator(
|
||||||
|
rescale = 1./255
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build_generator(self, directory):
|
||||||
|
# voir plus tars si besoin de parametrer
|
||||||
|
return self.datagen.flow_from_directory(
|
||||||
|
directory,
|
||||||
|
target_size = (self.config.get('models')['simple']['input_height'], self.config.get('models')['simple']['input_width']),
|
||||||
|
color_mode = 'rgb',
|
||||||
|
classes = None,
|
||||||
|
class_mode = 'input',
|
||||||
|
batch_size = self.config.get('models')['simple']['batch_size'],
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_train_generator(self):
|
||||||
|
train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
|
||||||
|
self.train_generator = self.build_generator(directory = train_dir)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_train_generator(self):
|
||||||
|
return self.train_generator
|
||||||
|
|
||||||
|
def set_test_generator(self):
|
||||||
|
test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
|
||||||
|
self.test_generator = self.build_generator(directory = test_dir)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_test_generator(self):
|
||||||
|
return self.train_generator
|
||||||
|
|
||||||
|
|
||||||
|
|
74
iss/models/ModelTrainer.py
Normal file
74
iss/models/ModelTrainer.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from keras.callbacks import ModelCheckpoint, CSVLogger
|
||||||
|
from iss.models.Callbacks import DisplayPictureCallback
|
||||||
|
from iss.tools.tools import Tools
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTrainer:
|
||||||
|
|
||||||
|
def __init__(self, model, data_loader, config, callbacks=[]):
|
||||||
|
self.model = model
|
||||||
|
self.data_loader = data_loader
|
||||||
|
|
||||||
|
self.epochs = config['epochs']
|
||||||
|
self.verbose = config['verbose']
|
||||||
|
self.initial_epoch = config['initial_epoch']
|
||||||
|
self.workers = config['workers']
|
||||||
|
self.use_multiprocessing = config['use_multiprocessing']
|
||||||
|
self.steps_per_epoch = config['steps_per_epoch']
|
||||||
|
self.validation_steps = config['validation_steps']
|
||||||
|
self.validation_freq = config['validation_freq']
|
||||||
|
self.callbacks = callbacks
|
||||||
|
|
||||||
|
self.init_callbacks(config)
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
|
||||||
|
self.model.model.fit_generator(
|
||||||
|
generator = self.data_loader.get_train_generator(),
|
||||||
|
steps_per_epoch = self.steps_per_epoch,
|
||||||
|
epochs = self.epochs,
|
||||||
|
verbose = self.verbose,
|
||||||
|
initial_epoch = self.initial_epoch,
|
||||||
|
callbacks = self.callbacks,
|
||||||
|
workers = self.workers,
|
||||||
|
use_multiprocessing = self.use_multiprocessing,
|
||||||
|
validation_data = self.data_loader.get_test_generator(),
|
||||||
|
validation_steps = self.validation_steps,
|
||||||
|
validation_freq = self.validation_freq
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model.save()
|
||||||
|
|
||||||
|
def init_callbacks(self, config):
|
||||||
|
|
||||||
|
if 'csv_logger' in config['callbacks']:
|
||||||
|
log_dir = config['callbacks']['csv_logger']['directory']
|
||||||
|
Tools.create_dir_if_not_exists(log_dir)
|
||||||
|
|
||||||
|
self.csv_logger = CSVLogger(
|
||||||
|
filename = '{}/training.log'.format(log_dir),
|
||||||
|
append = config['callbacks']['csv_logger']['append']
|
||||||
|
)
|
||||||
|
self.callbacks.extend([self.csv_logger])
|
||||||
|
|
||||||
|
if 'checkpoint' in config['callbacks']:
|
||||||
|
chekpt_dir = config['callbacks']['checkpoint']['directory']
|
||||||
|
Tools.create_dir_if_not_exists(chekpt_dir)
|
||||||
|
self.checkpointer = ModelCheckpoint(
|
||||||
|
filepath = chekpt_dir + '/' + self.model.model_name + '-{epoch:02d}.hdf5',
|
||||||
|
verbose = config['callbacks']['checkpoint']['verbose'],
|
||||||
|
period = config['callbacks']['checkpoint']['period']
|
||||||
|
)
|
||||||
|
self.callbacks.extend([self.checkpointer])
|
||||||
|
|
||||||
|
if 'display_picture' in config['callbacks']:
|
||||||
|
self.picture_displayer = DisplayPictureCallback(
|
||||||
|
model = self.model,
|
||||||
|
data_loader = self.data_loader.get_train_generator(),
|
||||||
|
epoch_laps = config['callbacks']['display_picture']['epoch_laps']
|
||||||
|
)
|
||||||
|
self.callbacks.extend([self.picture_displayer])
|
||||||
|
|
||||||
|
return self
|
49
iss/models/SimpleAutoEncoder.py
Normal file
49
iss/models/SimpleAutoEncoder.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from iss.models.AbstractModel import AbstractModel
|
||||||
|
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
|
||||||
|
from keras.optimizers import Adadelta
|
||||||
|
from keras.models import Model
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SimpleAutoEncoder(AbstractModel):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
|
||||||
|
save_directory = config['save_directory']
|
||||||
|
model_name = config['model_name']
|
||||||
|
|
||||||
|
super().__init__(save_directory, model_name)
|
||||||
|
|
||||||
|
self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
|
||||||
|
self.lr = config['learning_rate']
|
||||||
|
self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_shape = self.input_shape
|
||||||
|
|
||||||
|
picture = Input(shape = input_shape)
|
||||||
|
|
||||||
|
x = Flatten()(picture)
|
||||||
|
layer_1 = Dense(2000, activation = 'relu', name = 'enc_1')(x)
|
||||||
|
layer_2 = Dense(100, activation = 'relu', name = 'enc_2')(layer_1)
|
||||||
|
layer_3 = Dense(30, activation = 'relu', name = 'enc_3')(layer_2)
|
||||||
|
layer_4 = Dense(100, activation = 'relu', name = 'dec_1')(layer_3)
|
||||||
|
layer_5 = Dense(2000, activation = 'relu', name = 'dec_2')(layer_4)
|
||||||
|
|
||||||
|
# encoded network
|
||||||
|
# x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_1')(picture)
|
||||||
|
# encoded = MaxPooling2D((2, 2))(x)
|
||||||
|
|
||||||
|
# decoded network
|
||||||
|
# x = Conv2D(1, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_1')(encoded)
|
||||||
|
# x = UpSampling2D((2, 2))(x)
|
||||||
|
# x = Flatten()(x)
|
||||||
|
x = Dense(np.prod(input_shape), activation = 'softmax')(layer_5)
|
||||||
|
decoded = Reshape((input_shape))(x)
|
||||||
|
|
||||||
|
self.model = Model(picture, decoded)
|
||||||
|
|
||||||
|
optimizer = Adadelta(lr = self.lr, rho = 0.95, epsilon = None, decay = 0.0)
|
||||||
|
|
||||||
|
self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
|
Loading…
Reference in a new issue