1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-05 07:03:10 +02:00

re write abstract class

This commit is contained in:
Francois Vieille 2019-03-13 22:55:32 +01:00
parent 0efb5bf975
commit 9c76c82920
2 changed files with 14 additions and 4 deletions

View file

@ -32,4 +32,5 @@ class AbstractAutoEncoderModel(AbstractModel):
def __init__(self, save_directory, model_name): def __init__(self, save_directory, model_name):
super().__init__(save_directory, model_name) super().__init__(save_directory, model_name)
self.encoded_layer = None self.encoder_model = None
self.decoder_model = None

View file

@ -4,6 +4,7 @@ from iss.models import AbstractAutoEncoderModel
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
from keras.optimizers import Adadelta, Adam from keras.optimizers import Adadelta, Adam
from keras.models import Model from keras.models import Model
from keras import backend as K
import numpy as np import numpy as np
class SimpleConvAutoEncoder(AbstractAutoEncoderModel): class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
@ -17,11 +18,13 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
self.activation = config['activation'] self.activation = config['activation']
self.input_shape = (config['input_height'], config['input_width'], config['input_channel']) self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
self.latent_shape = (config['latent_height'], config['latent_width'], config['latent_channel'])
self.lr = config['learning_rate'] self.lr = config['learning_rate']
self.build_model() self.build_model()
def build_model(self): def build_model(self):
input_shape = self.input_shape input_shape = self.input_shape
latent_shape = self.latent_shape
picture = Input(shape = input_shape) picture = Input(shape = input_shape)
@ -41,8 +44,12 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
x = Activation('relu')(x) x = Activation('relu')(x)
encoded = MaxPooling2D((2, 2))(x) encoded = MaxPooling2D((2, 2))(x)
self.encoder_model = Model(picture, encoded, name = "encoder")
# decoded network # decoded network
x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(encoded) latent_input = Input(shape = latent_shape)
x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(latent_input)
x = BatchNormalization()(x) x = BatchNormalization()(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = UpSampling2D((2, 2))(x) x = UpSampling2D((2, 2))(x)
@ -63,9 +70,11 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
x = Dense(np.prod(input_shape), activation = self.activation)(x) x = Dense(np.prod(input_shape), activation = self.activation)(x)
decoded = Reshape((input_shape))(x) decoded = Reshape((input_shape))(x)
self.model = Model(picture, decoded) self.decoder_model = Model(latent_input, decoded, name = "decoder")
picture_dec = self.decoder_model(self.encoder_model(picture))
self.model = Model(picture, picture_dec)
# optimizer = Adadelta(lr = self.lr, rho = 0.95, epsilon = None, decay = 0.0)
optimizer = Adam(lr = 0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False) optimizer = Adam(lr = 0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy') self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')