1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-04-27 11:31:51 +02:00

variational model

This commit is contained in:
Francois Vieille 2019-03-14 10:46:59 +01:00
parent 624150da81
commit 04713d43d1
2 changed files with 73 additions and 0 deletions

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
from iss.models import AbstractAutoEncoderModel
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation, Lambda
from keras.optimizers import Adadelta, Adam
from keras.models import Model
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
class VarAutoEncoder(AbstractAutoEncoderModel):
def __init__(self, config):
save_directory = config['save_directory']
model_name = config['model_name']
super().__init__(save_directory, model_name)
self.activation = config['activation']
self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
self.latent_shape = config['latent_shape']
self.lr = config['learning_rate']
self.build_model()
def sampling(self, args):
z_mean, z_log_var = args
batch = K.shape(z_mean)[0]
dim = K.int_shape(z_mean)[1]
epsilon = K.random_normal(shape=(batch, dim))
return z_mean + K.exp(0.5 * z_log_var) * epsilon
def build_model(self):
input_shape = self.input_shape
latent_shape = self.latent_shape
picture = Input(shape = input_shape)
x = Flatten()(picture)
x = Dense(1000, activation = 'relu', name = 'enc_1')(x)
x = Dense(100, activation = 'relu', name = 'enc_2')(x)
z_mean = Dense(latent_shape, name = 'enc_z_mean')(x)
z_log_var = Dense(latent_shape, name = 'enc_z_log_var')(x)
z = Lambda(self.sampling, name='enc_z')([z_mean, z_log_var])
self.encoder_model = Model(picture, [z_mean, z_log_var, z], name = "encoder")
latent_input = Input(shape = (latent_shape,), name = "enc_z_sampling")
x = Dense(100, activation = 'relu', name = 'dec_1')(latent_input)
x = Dense(1000, activation = 'relu', name = 'dec_2')(x)
x = Dense(np.prod(input_shape), activation = self.activation)(x)
decoded = Reshape((input_shape))(x)
self.decoder_model = Model(latent_input, decoded, name = "decoder")
picture_dec = self.decoder_model(self.encoder_model(picture)[2])
self.model = Model(picture, picture_dec, name = "autoencoder")
def my_loss(picture, picture_dec):
xent_loss = K.mean(K.binary_crossentropy(picture, picture_dec), axis = (-1, -2, -3))
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
loss = K.mean(xent_loss + kl_loss)
return loss
optimizer = Adam(lr = self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
self.model.compile(optimizer = optimizer, loss = my_loss)

View file

@ -3,3 +3,4 @@ from .AbstractModel import AbstractModel
from .AbstractModel import AbstractAutoEncoderModel from .AbstractModel import AbstractAutoEncoderModel
from .SimpleConvAutoEncoder import SimpleConvAutoEncoder from .SimpleConvAutoEncoder import SimpleConvAutoEncoder
from .SimpleAutoEncoder import SimpleAutoEncoder from .SimpleAutoEncoder import SimpleAutoEncoder
from .VariationalAutoEncoder import VarAutoEncoder