mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-06-14 11:35:00 +02:00
nouveau modele de convolution
This commit is contained in:
parent
9be8896b22
commit
ebd7f74e98
|
@ -26,3 +26,10 @@ class AbstractModel:
|
|||
def predict_one(self, x, batch_size = 1, verbose = 0, steps = None):
|
||||
x = np.expand_dims(x, axis = 0)
|
||||
return self.predict(x, batch_size, verbose, steps)
|
||||
|
||||
|
||||
class AbstractAutoEncoderModel(AbstractModel):
|
||||
|
||||
def __init__(self, save_directory, model_name):
|
||||
super().__init__(save_directory, model_name)
|
||||
self.encoded_layer = None
|
||||
|
|
|
@ -4,9 +4,10 @@ from keras.preprocessing.image import ImageDataGenerator
|
|||
|
||||
class ImageDataGeneratorWrapper:
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, model):
|
||||
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.datagen = None
|
||||
self.train_generator = None
|
||||
self.test_generator = None
|
||||
|
@ -26,11 +27,11 @@ class ImageDataGeneratorWrapper:
|
|||
# voir plus tars si besoin de parametrer
|
||||
return self.datagen.flow_from_directory(
|
||||
directory,
|
||||
target_size = (self.config.get('models')['simple']['input_height'], self.config.get('models')['simple']['input_width']),
|
||||
target_size = (self.config.get('models')[self.model]['input_height'], self.config.get('models')[self.model]['input_width']),
|
||||
color_mode = 'rgb',
|
||||
classes = None,
|
||||
class_mode = 'input',
|
||||
batch_size = self.config.get('models')['simple']['batch_size'],
|
||||
batch_size = self.config.get('models')[self.model]['batch_size'],
|
||||
)
|
||||
|
||||
def set_train_generator(self):
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from iss.models.AbstractModel import AbstractModel
|
||||
from iss.models import AbstractAutoEncoderModel
|
||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
|
||||
from keras.optimizers import Adadelta, Adam
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
|
||||
class SimpleAutoEncoder(AbstractModel):
|
||||
class SimpleAutoEncoder(AbstractAutoEncoderModel):
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from iss.models.AbstractModel import AbstractModel
|
||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
|
||||
from iss.models import AbstractAutoEncoderModel
|
||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
|
||||
from keras.optimizers import Adadelta, Adam
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
|
||||
class SimpleConvAutoEncoder(AbstractModel):
|
||||
class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
|
@ -26,16 +26,39 @@ class SimpleConvAutoEncoder(AbstractModel):
|
|||
picture = Input(shape = input_shape)
|
||||
|
||||
# encoded network
|
||||
x = Conv2D(4, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_1')(picture)
|
||||
x = Conv2D(64, (3, 3), padding = 'same', name = 'enc_conv_1')(picture)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((2, 2))(x)
|
||||
x = Conv2D(8, (3, 3), activation = 'relu', padding = 'same', name = 'enc_conv_2')(x)
|
||||
|
||||
x = Conv2D(32, (3, 3), padding = 'same', name = 'enc_conv_2')(x)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((2, 2))(x)
|
||||
|
||||
x = Conv2D(16, (3, 3), padding = 'same', name = 'enc_conv_3')(x)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
encoded = MaxPooling2D((2, 2))(x)
|
||||
|
||||
# decoded network
|
||||
x = Conv2D(8, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_1')(encoded)
|
||||
x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(encoded)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
x = UpSampling2D((2, 2))(x)
|
||||
x = Conv2D(4, (3, 3), activation = 'relu', padding = 'same', name = 'dec_conv_2')(x)
|
||||
|
||||
x = Conv2D(32, (3, 3), padding = 'same', name = 'dec_conv_2')(x)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
x = UpSampling2D((2, 2))(x)
|
||||
|
||||
x = Conv2D(64, (3, 3), padding = 'same', name = 'dec_conv_3')(x)
|
||||
x = BatchNormalization()(x)
|
||||
x = Activation('relu')(x)
|
||||
x = UpSampling2D((2, 2))(x)
|
||||
|
||||
x = Conv2D(3, (3, 3), padding = 'same', name = 'dec_conv_4')(x)
|
||||
x = BatchNormalization()(x)
|
||||
x = Flatten()(x)
|
||||
x = Dense(np.prod(input_shape), activation = self.activation)(x)
|
||||
decoded = Reshape((input_shape))(x)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
|
||||
from .AbstractModel import AbstractModel
|
||||
from .AbstractModel import AbstractAutoEncoderModel
|
||||
from .SimpleConvAutoEncoder import SimpleConvAutoEncoder
|
||||
from .SimpleAutoEncoder import SimpleAutoEncoder
|
Loading…
Reference in a new issue