diff --git a/iss/models/AbstractModel.py b/iss/models/AbstractModel.py index 553a031..35c8325 100644 --- a/iss/models/AbstractModel.py +++ b/iss/models/AbstractModel.py @@ -26,3 +26,10 @@ class AbstractModel: def predict_one(self, x, batch_size = 1, verbose = 0, steps = None): x = np.expand_dims(x, axis = 0) return self.predict(x, batch_size, verbose, steps) + + +class AbstractAutoEncoderModel(AbstractModel): + + def __init__(self, save_directory, model_name): + super().__init__(save_directory, model_name) + self.encoded_layer = None diff --git a/iss/models/DataLoader.py b/iss/models/DataLoader.py index 8ba8e4f..6a60e25 100644 --- a/iss/models/DataLoader.py +++ b/iss/models/DataLoader.py @@ -4,9 +4,10 @@ from keras.preprocessing.image import ImageDataGenerator class ImageDataGeneratorWrapper: - def __init__(self, config): + def __init__(self, config, model): self.config = config + self.model = model self.datagen = None self.train_generator = None self.test_generator = None @@ -26,11 +27,11 @@ class ImageDataGeneratorWrapper: # voir plus tars si besoin de parametrer return self.datagen.flow_from_directory( directory, - target_size = (self.config.get('models')['simple']['input_height'], self.config.get('models')['simple']['input_width']), + target_size = (self.config.get('models')[self.model]['input_height'], self.config.get('models')[self.model]['input_width']), color_mode = 'rgb', classes = None, class_mode = 'input', - batch_size = self.config.get('models')['simple']['batch_size'], + batch_size = self.config.get('models')[self.model]['batch_size'], ) def set_train_generator(self): diff --git a/iss/models/SimpleAutoEncoder.py b/iss/models/SimpleAutoEncoder.py index 3309084..73b23c1 100644 --- a/iss/models/SimpleAutoEncoder.py +++ b/iss/models/SimpleAutoEncoder.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -from iss.models.AbstractModel import AbstractModel +from iss.models import AbstractAutoEncoderModel from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten from keras.optimizers import Adadelta, Adam from keras.models import Model import numpy as np -class SimpleAutoEncoder(AbstractModel): +class SimpleAutoEncoder(AbstractAutoEncoderModel): def __init__(self, config): diff --git a/iss/models/SimpleConvAutoEncoder.py b/iss/models/SimpleConvAutoEncoder.py index a40d5ba..ea25771 100644 --- a/iss/models/SimpleConvAutoEncoder.py +++ b/iss/models/SimpleConvAutoEncoder.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -from iss.models.AbstractModel import AbstractModel -from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten +from iss.models import AbstractAutoEncoderModel +from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation from keras.optimizers import Adadelta, Adam from keras.models import Model import numpy as np -class SimpleConvAutoEncoder(AbstractModel): +class SimpleConvAutoEncoder(AbstractAutoEncoderModel): def __init__(self, config): @@ -26,16 +26,39 @@ class SimpleConvAutoEncoder(AbstractModel): picture = Input(shape = input_shape) # encoded network - x = Conv2D(4, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_1')(picture) + x = Conv2D(64, (3, 3), padding = 'same', name = 'enc_conv_1')(picture) + x = BatchNormalization()(x) + x = Activation('relu')(x) x = MaxPooling2D((2, 2))(x) - x = Conv2D(8, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_2')(x) + + x = Conv2D(32, (3, 3), padding = 'same', name = 'enc_conv_2')(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2))(x) + + x = Conv2D(16, (3, 3), padding = 'same', name = 'enc_conv_3')(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) encoded = MaxPooling2D((2, 2))(x) # decoded network - x = Conv2D(8, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_1')(encoded) + x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(encoded) + x = BatchNormalization()(x) + x = Activation('relu')(x) x = UpSampling2D((2, 2))(x) - x = Conv2D(4, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_2')(x) + + x = Conv2D(32, (3, 3), padding = 'same', name = 'dec_conv_2')(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) x = UpSampling2D((2, 2))(x) + + x = Conv2D(64, (3, 3), padding = 'same', name = 'dec_conv_3')(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = UpSampling2D((2, 2))(x) + + x = Conv2D(3, (3, 3), padding = 'same', name = 'dec_conv_4')(x) + x = BatchNormalization()(x) x = Flatten()(x) x = Dense(np.prod(input_shape), activation = self.activation)(x) decoded = Reshape((input_shape))(x) diff --git a/iss/models/__init__.py b/iss/models/__init__.py index cbc4cfa..027e674 100644 --- a/iss/models/__init__.py +++ b/iss/models/__init__.py @@ -1,3 +1,5 @@ +from .AbstractModel import AbstractModel +from .AbstractModel import AbstractAutoEncoderModel from .SimpleConvAutoEncoder import SimpleConvAutoEncoder -from .SimpleAutoEncoder import SimpleAutoEncoder \ No newline at end of file +from .SimpleAutoEncoder import SimpleAutoEncoder