1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-06-29 02:20:01 +02:00
smart-iss-posts/iss/models/SimpleAutoEncoder.py

53 lines
1.9 KiB
Python
Raw Normal View History

2019-03-09 22:41:37 +01:00
# -*- coding: utf-8 -*-
2019-03-13 15:25:46 +01:00
from iss.models import AbstractAutoEncoderModel
2019-03-09 22:41:37 +01:00
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten
2019-03-10 19:36:42 +01:00
from keras.optimizers import Adadelta, Adam
2019-03-09 22:41:37 +01:00
from keras.models import Model
import numpy as np
2019-03-13 15:25:46 +01:00
class SimpleAutoEncoder(AbstractAutoEncoderModel):
    """Fully-connected (dense) autoencoder for fixed-size pictures.

    The encoder flattens the input picture and passes it through dense ReLU
    layers of 1000 and 100 units down to a ``latent_shape``-unit ReLU code.
    The decoder mirrors this (100 -> 1000 units) and projects back to
    ``input_height * input_width * input_channel`` values, which are reshaped
    to the input shape.

    Expected ``config`` keys:
        save_directory, model_name: forwarded to ``AbstractAutoEncoderModel``.
        activation: activation of the decoder's output layer.
        input_height, input_width, input_channel: shape of the input picture.
        latent_shape: number of units in the latent code.
        learning_rate: learning rate given to the Adam optimizer.

    After construction, ``encoder_model``, ``decoder_model`` and the compiled
    end-to-end ``model`` are available as attributes.
    """

    def __init__(self, config):
        save_directory = config['save_directory']
        model_name = config['model_name']
        super().__init__(save_directory, model_name)

        self.activation = config['activation']
        self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
        self.latent_shape = config['latent_shape']
        self.lr = config['learning_rate']

        self.build_model()

    def build_model(self):
        """Build and compile the encoder, decoder and end-to-end autoencoder.

        Sets ``self.encoder_model``, ``self.decoder_model`` and ``self.model``.
        The full model is compiled with Adam and binary cross-entropy loss
        (which presumably expects pixel values scaled to [0, 1] -- confirm
        against the data pipeline).
        """
        input_shape = self.input_shape
        picture = Input(shape=input_shape)

        # encoder network
        x = Flatten()(picture)
        layer_1 = Dense(1000, activation='relu', name='enc_1')(x)
        layer_2 = Dense(100, activation='relu', name='enc_2')(layer_1)
        encoded = Dense(self.latent_shape, activation='relu', name='enc_3')(layer_2)
        self.encoder_model = Model(picture, encoded, name="encoder")

        # decoder network
        latent_input = Input(shape=(self.latent_shape,))
        layer_4 = Dense(100, activation='relu', name='dec_1')(latent_input)
        layer_5 = Dense(1000, activation='relu', name='dec_2')(layer_4)
        # np.prod returns a NumPy integer; cast so Dense receives a plain int.
        x = Dense(int(np.prod(input_shape)), activation=self.activation)(layer_5)
        decoded = Reshape(input_shape)(x)
        self.decoder_model = Model(latent_input, decoded, name="decoder")

        # End-to-end autoencoder: picture -> latent code -> reconstruction.
        picture_dec = self.decoder_model(self.encoder_model(picture))
        self.model = Model(picture, picture_dec)

        # Honour the configured learning rate (config['learning_rate']).
        optimizer = Adam(lr=self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
        self.model.compile(optimizer=optimizer, loss='binary_crossentropy')