mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-03 06:03:10 +02:00
30 lines
704 B
Python
30 lines
704 B
Python
# -*- coding: utf-8 -*-
|
|
|
|
from keras.callbacks import Callback
|
|
import numpy as np
|
|
from iss.tools.tools import Tools
|
|
from IPython.display import display
|
|
|
|
class DisplayPictureCallback(Callback):
    """Keras callback that periodically shows a sample input and the model's output.

    On every epoch whose index is a multiple of ``epoch_laps`` (including
    epoch 0), one picture is taken from ``data_loader``. It is passed through
    ``model.predict_one``, and both the input and the prediction are rendered
    via ``Tools`` helpers and IPython's ``display``. This is useful for
    eyeballing training progress from a notebook.
    """

    def __init__(self, model, epoch_laps, data_loader):
        """Store the model, display period and data source.

        Args:
            model: object exposing ``predict_one(picture)``. It is stored as
                ``model_class``, presumably so it does not clash with the
                ``model`` attribute that Keras assigns to every callback.
            epoch_laps: display period, in epochs. Must be non-zero because it
                is used as a modulus in ``on_epoch_end``.
            data_loader: iterator yielding batches. ``batch[0][1]`` is used as
                the sample picture. This assumes ``(inputs, targets)``-style
                batches holding at least two samples; confirm against the
                loader.

        Raises:
            ValueError: if ``epoch_laps`` is zero/falsy. Previously this
                surfaced as a ``ZeroDivisionError`` in the middle of training.
        """
        super(DisplayPictureCallback, self).__init__()
        if not epoch_laps:
            raise ValueError("epoch_laps must be a non-zero number of epochs")
        self.model_class = model
        self.epoch_laps = epoch_laps
        self.data_loader = data_loader

    def _next_batch(self):
        """Return the next batch from ``data_loader``, for old- and new-style iterators."""
        loader = self.data_loader
        # Python 3 iterators/generators only define __next__. Legacy objects
        # may only expose a Python 2 style ``.next()`` method, so fall back to it.
        if hasattr(loader, "__next__"):
            return next(loader)
        return loader.next()

    def on_epoch_end(self, epoch, logs=None):
        """Display an input picture and its prediction every ``epoch_laps`` epochs.

        Args:
            epoch: index of the epoch that just ended (supplied by Keras).
            logs: metrics dict supplied by Keras; unused here. It defaults to
                ``None`` to match the Keras ``Callback`` signature.

        Returns:
            ``self``. Keras' callback machinery does not use this value; it is
            kept for compatibility with existing callers.
        """
        if epoch % self.epoch_laps == 0:
            # [0] -> input part of the batch, [1] -> its second sample.
            input_pict = self._next_batch()[0][1]
            output_pict = self.model_class.predict_one(input_pict)

            display(Tools.display_one_picture_scaled(input_pict))
            display(Tools.display_index_picture_scaled(output_pict))

        return self
|
|
|
|
|