mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-05 07:03:10 +02:00
re write abstract class
This commit is contained in:
parent
0efb5bf975
commit
9c76c82920
|
@ -32,4 +32,5 @@ class AbstractAutoEncoderModel(AbstractModel):
|
||||||
|
|
||||||
def __init__(self, save_directory, model_name):
|
def __init__(self, save_directory, model_name):
|
||||||
super().__init__(save_directory, model_name)
|
super().__init__(save_directory, model_name)
|
||||||
self.encoded_layer = None
|
self.encoder_model = None
|
||||||
|
self.decoder_model = None
|
||||||
|
|
|
@ -4,6 +4,7 @@ from iss.models import AbstractAutoEncoderModel
|
||||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
|
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
|
||||||
from keras.optimizers import Adadelta, Adam
|
from keras.optimizers import Adadelta, Adam
|
||||||
from keras.models import Model
|
from keras.models import Model
|
||||||
|
from keras import backend as K
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
|
@ -17,11 +18,13 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
|
|
||||||
self.activation = config['activation']
|
self.activation = config['activation']
|
||||||
self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
|
self.input_shape = (config['input_height'], config['input_width'], config['input_channel'])
|
||||||
|
self.latent_shape = (config['latent_height'], config['latent_width'], config['latent_channel'])
|
||||||
self.lr = config['learning_rate']
|
self.lr = config['learning_rate']
|
||||||
self.build_model()
|
self.build_model()
|
||||||
|
|
||||||
def build_model(self):
|
def build_model(self):
|
||||||
input_shape = self.input_shape
|
input_shape = self.input_shape
|
||||||
|
latent_shape = self.latent_shape
|
||||||
|
|
||||||
picture = Input(shape = input_shape)
|
picture = Input(shape = input_shape)
|
||||||
|
|
||||||
|
@ -41,8 +44,12 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
x = Activation('relu')(x)
|
x = Activation('relu')(x)
|
||||||
encoded = MaxPooling2D((2, 2))(x)
|
encoded = MaxPooling2D((2, 2))(x)
|
||||||
|
|
||||||
|
self.encoder_model = Model(picture, encoded, name = "encoder")
|
||||||
|
|
||||||
# decoded network
|
# decoded network
|
||||||
x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(encoded)
|
latent_input = Input(shape = latent_shape)
|
||||||
|
|
||||||
|
x = Conv2D(16, (3, 3), padding = 'same', name = 'dec_conv_1')(latent_input)
|
||||||
x = BatchNormalization()(x)
|
x = BatchNormalization()(x)
|
||||||
x = Activation('relu')(x)
|
x = Activation('relu')(x)
|
||||||
x = UpSampling2D((2, 2))(x)
|
x = UpSampling2D((2, 2))(x)
|
||||||
|
@ -63,9 +70,11 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
x = Dense(np.prod(input_shape), activation = self.activation)(x)
|
x = Dense(np.prod(input_shape), activation = self.activation)(x)
|
||||||
decoded = Reshape((input_shape))(x)
|
decoded = Reshape((input_shape))(x)
|
||||||
|
|
||||||
self.model = Model(picture, decoded)
|
self.decoder_model = Model(latent_input, decoded, name = "decoder")
|
||||||
|
|
||||||
|
picture_dec = self.decoder_model(self.encoder_model(picture))
|
||||||
|
self.model = Model(picture, picture_dec)
|
||||||
|
|
||||||
# optimizer = Adadelta(lr = self.lr, rho = 0.95, epsilon = None, decay = 0.0)
|
|
||||||
optimizer = Adam(lr = 0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
|
optimizer = Adam(lr = 0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
|
||||||
|
|
||||||
self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
|
self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
|
||||||
|
|
Loading…
Reference in a new issue