1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-04-27 11:31:51 +02:00

fix loss function

This commit is contained in:
Francois Vieille 2019-03-24 18:00:28 +01:00
parent b5c1fb710f
commit 70eafeba17

View file

@ -3,7 +3,7 @@
from iss.models import AbstractAutoEncoderModel from iss.models import AbstractAutoEncoderModel
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
from keras.optimizers import Adadelta, Adam from keras.optimizers import Adadelta, Adam
from keras.models import Model from keras.models import Model, load_model
from keras import backend as K from keras import backend as K
import numpy as np import numpy as np
@ -22,6 +22,9 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
self.lr = config['learning_rate'] self.lr = config['learning_rate']
self.build_model() self.build_model()
def load(self, which = 'final_model'):
self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which), custom_objects= {'my_loss':self.my_loss})
def build_model(self): def build_model(self):
input_shape = self.input_shape input_shape = self.input_shape
latent_shape = self.latent_shape latent_shape = self.latent_shape
@ -77,9 +80,11 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
optimizer = Adam(lr = self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False) optimizer = Adam(lr = self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
def my_loss(picture, picture_dec):
loss = K.mean(K.binary_crossentropy(picture, picture_dec))
return loss
# self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy') # self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
self.model.compile(optimizer = optimizer, loss = my_loss) self.model.compile(optimizer = optimizer, loss = self.my_loss)
def my_loss(self, picture, picture_dec):
loss = K.mean(K.binary_crossentropy(picture, picture_dec))
return loss