1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-18 13:26:33 +02:00

tensorboard callback

This commit is contained in:
Francois Vieille 2019-11-11 20:33:35 +01:00
parent 0811e3d3a5
commit 11cdb40b3f
2 changed files with 68 additions and 21 deletions

View file

@ -1,29 +1,77 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from keras.callbacks import Callback import os
import datetime
import numpy as np import numpy as np
from iss.tools.tools import Tools import tensorflow as tf
from keras.callbacks import Callback
from IPython.display import display from IPython.display import display
from iss.tools.tools import Tools
class DisplayPictureCallback(Callback): class DisplayPictureCallback(Callback):
def __init__(self, model, epoch_laps, data_loader): def __init__(self, model, epoch_laps, data_loader):
self.model_class = model self.model_class = model
self.epoch_laps = epoch_laps self.epoch_laps = epoch_laps
self.data_loader = data_loader self.data_loader = data_loader
super(DisplayPictureCallback, self).__init__() super(DisplayPictureCallback, self).__init__()
def on_epoch_end(self, epoch, logs): def on_epoch_end(self, epoch, logs):
if epoch % self.epoch_laps == 0: if epoch % self.epoch_laps == 0:
input_pict = self.data_loader.next()[0][1] input_pict = self.data_loader.next()[0][1]
output_pict = self.model_class.predict_one(input_pict) output_pict = self.model_class.predict_one(input_pict)
display(Tools.display_one_picture_scaled(input_pict)) display(Tools.display_one_picture_scaled(input_pict))
display(Tools.display_index_picture_scaled(output_pict)) display(Tools.display_index_picture_scaled(output_pict))
return self return self
class TensorboardCallback(Callback):
def __init__(self, log_dir, limit_image = 1, model = None, data_loader = None):
self.log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
self.limit_image = limit_image
self.model_class = model
self.data_loader = data_loader
self.writer = tf.summary.FileWriter(self.log_dir)
super(TensorboardCallback, self).__init__()
def on_epoch_end(self, epoch, logs=None):
print(logs)
image_summaries = []
for input_pict in self.data_loader.next()[0][:self.limit_image]:
output_pict = self.model_class.predict_one(input_pict)[0]
input_im_bytes = Tools.bytes_image(input_pict*255)
output_im_bytes = Tools.bytes_image(output_pict*255)
image_summaries.append(tf.Summary.Value(tag = 'input', image = tf.Summary.Image(encoded_image_string = input_im_bytes)))
image_summaries.append(tf.Summary.Value(tag = 'output', image = tf.Summary.Image(encoded_image_string = output_im_bytes)))
image_summary = tf.Summary(value = image_summaries)
self.writer.add_summary(image_summary, epoch)
self._write_logs(logs, epoch)
return self
def _write_logs(self, logs, index):
for name, value in logs.items():
if name in ['batch', 'size']:
continue
summary = tf.Summary()
summary_value = summary.value.add()
if isinstance(value, np.ndarray):
summary_value.simple_value = value.item()
else:
summary_value.simple_value = value
summary_value.tag = name
self.writer.add_summary(summary, index)
self.writer.flush()

View file

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from keras.callbacks import ModelCheckpoint, CSVLogger from keras.callbacks import ModelCheckpoint, CSVLogger
from iss.models.Callbacks import DisplayPictureCallback from iss.models.Callbacks import DisplayPictureCallback, TensorboardCallback
from iss.tools.tools import Tools from iss.tools.tools import Tools
import keras import keras
@ -91,12 +91,11 @@ class ModelTrainer:
if 'tensorboard' in config['callbacks']: if 'tensorboard' in config['callbacks']:
log_dir = config['callbacks']['tensorboard']['log_dir'] log_dir = config['callbacks']['tensorboard']['log_dir']
Tools.create_dir_if_not_exists(log_dir) Tools.create_dir_if_not_exists(log_dir)
self.callbacks.extend([keras.callbacks.TensorBoard( self.callbacks.extend([TensorboardCallback(
log_dir = log_dir, log_dir = log_dir,
histogram_freq=0, limit_image = config['callbacks']['tensorboard']['limit_image'],
batch_size=32, model = self.model,
write_graph=False, data_loader = self.data_loader.get_train_generator()
write_images = True
)]) )])