mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-04-27 11:31:51 +02:00
fix loss function
This commit is contained in:
parent
b5c1fb710f
commit
70eafeba17
|
@ -3,7 +3,7 @@
|
||||||
from iss.models import AbstractAutoEncoderModel
|
from iss.models import AbstractAutoEncoderModel
|
||||||
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
|
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Flatten, BatchNormalization, Activation
|
||||||
from keras.optimizers import Adadelta, Adam
|
from keras.optimizers import Adadelta, Adam
|
||||||
from keras.models import Model
|
from keras.models import Model, load_model
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -22,6 +22,9 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
self.lr = config['learning_rate']
|
self.lr = config['learning_rate']
|
||||||
self.build_model()
|
self.build_model()
|
||||||
|
|
||||||
|
def load(self, which = 'final_model'):
|
||||||
|
self.model = load_model('{}/{}.hdf5'.format(self.save_directory, which), custom_objects= {'my_loss':self.my_loss})
|
||||||
|
|
||||||
def build_model(self):
|
def build_model(self):
|
||||||
input_shape = self.input_shape
|
input_shape = self.input_shape
|
||||||
latent_shape = self.latent_shape
|
latent_shape = self.latent_shape
|
||||||
|
@ -77,9 +80,11 @@ class SimpleConvAutoEncoder(AbstractAutoEncoderModel):
|
||||||
|
|
||||||
optimizer = Adam(lr = self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
|
optimizer = Adam(lr = self.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
|
||||||
|
|
||||||
def my_loss(picture, picture_dec):
|
|
||||||
loss = K.mean(K.binary_crossentropy(picture, picture_dec))
|
|
||||||
return loss
|
|
||||||
|
|
||||||
# self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
|
# self.model.compile(optimizer = optimizer, loss = 'binary_crossentropy')
|
||||||
self.model.compile(optimizer = optimizer, loss = my_loss)
|
self.model.compile(optimizer = optimizer, loss = self.my_loss)
|
||||||
|
|
||||||
|
def my_loss(self, picture, picture_dec):
|
||||||
|
loss = K.mean(K.binary_crossentropy(picture, picture_dec))
|
||||||
|
return loss
|
Loading…
Reference in a new issue