tensorboard callback

This commit is contained in:
Francois Vieille 2019-11-11 20:33:35 +01:00
parent 0811e3d3a5
commit 11cdb40b3f
2 changed files with 68 additions and 21 deletions

View File

@ -1,29 +1,77 @@
# -*- coding: utf-8 -*-
from keras.callbacks import Callback
import os
import datetime
import numpy as np
from iss.tools.tools import Tools
import tensorflow as tf
from keras.callbacks import Callback
from IPython.display import display
from iss.tools.tools import Tools
class DisplayPictureCallback(Callback):
def __init__(self, model, epoch_laps, data_loader):
def __init__(self, model, epoch_laps, data_loader):
self.model_class = model
self.epoch_laps = epoch_laps
self.data_loader = data_loader
super(DisplayPictureCallback, self).__init__()
self.model_class = model
self.epoch_laps = epoch_laps
self.data_loader = data_loader
super(DisplayPictureCallback, self).__init__()
def on_epoch_end(self, epoch, logs):
if epoch % self.epoch_laps == 0:
def on_epoch_end(self, epoch, logs):
if epoch % self.epoch_laps == 0:
input_pict = self.data_loader.next()[0][1]
output_pict = self.model_class.predict_one(input_pict)
input_pict = self.data_loader.next()[0][1]
output_pict = self.model_class.predict_one(input_pict)
display(Tools.display_one_picture_scaled(input_pict))
display(Tools.display_index_picture_scaled(output_pict))
return self
display(Tools.display_one_picture_scaled(input_pict))
display(Tools.display_index_picture_scaled(output_pict))
return self
class TensorboardCallback(Callback):
def __init__(self, log_dir, limit_image = 1, model = None, data_loader = None):
self.log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
self.limit_image = limit_image
self.model_class = model
self.data_loader = data_loader
self.writer = tf.summary.FileWriter(self.log_dir)
super(TensorboardCallback, self).__init__()
def on_epoch_end(self, epoch, logs=None):
print(logs)
image_summaries = []
for input_pict in self.data_loader.next()[0][:self.limit_image]:
output_pict = self.model_class.predict_one(input_pict)[0]
input_im_bytes = Tools.bytes_image(input_pict*255)
output_im_bytes = Tools.bytes_image(output_pict*255)
image_summaries.append(tf.Summary.Value(tag = 'input', image = tf.Summary.Image(encoded_image_string = input_im_bytes)))
image_summaries.append(tf.Summary.Value(tag = 'output', image = tf.Summary.Image(encoded_image_string = output_im_bytes)))
image_summary = tf.Summary(value = image_summaries)
self.writer.add_summary(image_summary, epoch)
self._write_logs(logs, epoch)
return self
def _write_logs(self, logs, index):
for name, value in logs.items():
if name in ['batch', 'size']:
continue
summary = tf.Summary()
summary_value = summary.value.add()
if isinstance(value, np.ndarray):
summary_value.simple_value = value.item()
else:
summary_value.simple_value = value
summary_value.tag = name
self.writer.add_summary(summary, index)
self.writer.flush()

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from keras.callbacks import ModelCheckpoint, CSVLogger
from iss.models.Callbacks import DisplayPictureCallback
from iss.models.Callbacks import DisplayPictureCallback, TensorboardCallback
from iss.tools.tools import Tools
import keras
@ -91,12 +91,11 @@ class ModelTrainer:
if 'tensorboard' in config['callbacks']:
log_dir = config['callbacks']['tensorboard']['log_dir']
Tools.create_dir_if_not_exists(log_dir)
self.callbacks.extend([keras.callbacks.TensorBoard(
self.callbacks.extend([TensorboardCallback(
log_dir = log_dir,
histogram_freq=0,
batch_size=32,
write_graph=False,
write_images = True
limit_image = config['callbacks']['tensorboard']['limit_image'],
model = self.model,
data_loader = self.data_loader.get_train_generator()
)])