From 11cdb40b3fdd5e7d542093a1c7a3837f6c337654 Mon Sep 17 00:00:00 2001 From: Francois Vieille Date: Mon, 11 Nov 2019 20:33:35 +0100 Subject: [PATCH] tensorboard callback --- iss/models/Callbacks.py | 78 ++++++++++++++++++++++++++++++-------- iss/models/ModelTrainer.py | 11 +++--- 2 files changed, 68 insertions(+), 21 deletions(-) diff --git a/iss/models/Callbacks.py b/iss/models/Callbacks.py index 653b384..1102365 100644 --- a/iss/models/Callbacks.py +++ b/iss/models/Callbacks.py @@ -1,29 +1,77 @@ # -*- coding: utf-8 -*- -from keras.callbacks import Callback +import os +import datetime import numpy as np -from iss.tools.tools import Tools +import tensorflow as tf +from keras.callbacks import Callback from IPython.display import display +from iss.tools.tools import Tools + + class DisplayPictureCallback(Callback): - def __init__(self, model, epoch_laps, data_loader): + def __init__(self, model, epoch_laps, data_loader): - self.model_class = model - self.epoch_laps = epoch_laps - self.data_loader = data_loader - super(DisplayPictureCallback, self).__init__() + self.model_class = model + self.epoch_laps = epoch_laps + self.data_loader = data_loader + super(DisplayPictureCallback, self).__init__() - def on_epoch_end(self, epoch, logs): - if epoch % self.epoch_laps == 0: + def on_epoch_end(self, epoch, logs): + if epoch % self.epoch_laps == 0: - input_pict = self.data_loader.next()[0][1] - output_pict = self.model_class.predict_one(input_pict) + input_pict = self.data_loader.next()[0][1] + output_pict = self.model_class.predict_one(input_pict) - display(Tools.display_one_picture_scaled(input_pict)) - display(Tools.display_index_picture_scaled(output_pict)) - - return self + display(Tools.display_one_picture_scaled(input_pict)) + display(Tools.display_index_picture_scaled(output_pict)) + + return self + +class TensorboardCallback(Callback): + + def __init__(self, log_dir, limit_image = 1, model = None, data_loader = None): + self.log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) + self.limit_image = limit_image + self.model_class = model + self.data_loader = data_loader + self.writer = tf.summary.FileWriter(self.log_dir) + super(TensorboardCallback, self).__init__() + + def on_epoch_end(self, epoch, logs=None): + print(logs) + image_summaries = [] + + for input_pict in self.data_loader.next()[0][:self.limit_image]: + output_pict = self.model_class.predict_one(input_pict)[0] + input_im_bytes = Tools.bytes_image(input_pict*255) + output_im_bytes = Tools.bytes_image(output_pict*255) + + image_summaries.append(tf.Summary.Value(tag = 'input', image = tf.Summary.Image(encoded_image_string = input_im_bytes))) + image_summaries.append(tf.Summary.Value(tag = 'output', image = tf.Summary.Image(encoded_image_string = output_im_bytes))) + + + image_summary = tf.Summary(value = image_summaries) + self.writer.add_summary(image_summary, epoch) + self._write_logs(logs, epoch) + + return self + + def _write_logs(self, logs, index): + for name, value in logs.items(): + if name in ['batch', 'size']: + continue + summary = tf.Summary() + summary_value = summary.value.add() + if isinstance(value, np.ndarray): + summary_value.simple_value = value.item() + else: + summary_value.simple_value = value + summary_value.tag = name + self.writer.add_summary(summary, index) + self.writer.flush() diff --git a/iss/models/ModelTrainer.py b/iss/models/ModelTrainer.py index fa40f91..621c542 100644 --- a/iss/models/ModelTrainer.py +++ b/iss/models/ModelTrainer.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from keras.callbacks import ModelCheckpoint, CSVLogger -from iss.models.Callbacks import DisplayPictureCallback +from iss.models.Callbacks import DisplayPictureCallback, TensorboardCallback from iss.tools.tools import Tools import keras @@ -91,12 +91,11 @@ class ModelTrainer: if 'tensorboard' in config['callbacks']: log_dir = config['callbacks']['tensorboard']['log_dir'] Tools.create_dir_if_not_exists(log_dir) - self.callbacks.extend([keras.callbacks.TensorBoard( + self.callbacks.extend([TensorboardCallback( log_dir = log_dir, - histogram_freq=0, - batch_size=32, - write_graph=False, - write_images = True + limit_image = config['callbacks']['tensorboard']['limit_image'], + model = self.model, + data_loader = self.data_loader.get_train_generator() )])