tensorboard callback
This commit is contained in:
parent
0811e3d3a5
commit
11cdb40b3f
|
@ -1,29 +1,77 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.callbacks import Callback
|
||||
import os
|
||||
import datetime
|
||||
import numpy as np
|
||||
from iss.tools.tools import Tools
|
||||
import tensorflow as tf
|
||||
from keras.callbacks import Callback
|
||||
from IPython.display import display
|
||||
|
||||
from iss.tools.tools import Tools
|
||||
|
||||
|
||||
class DisplayPictureCallback(Callback):
|
||||
|
||||
def __init__(self, model, epoch_laps, data_loader):
|
||||
def __init__(self, model, epoch_laps, data_loader):
|
||||
|
||||
self.model_class = model
|
||||
self.epoch_laps = epoch_laps
|
||||
self.data_loader = data_loader
|
||||
super(DisplayPictureCallback, self).__init__()
|
||||
self.model_class = model
|
||||
self.epoch_laps = epoch_laps
|
||||
self.data_loader = data_loader
|
||||
super(DisplayPictureCallback, self).__init__()
|
||||
|
||||
|
||||
def on_epoch_end(self, epoch, logs):
|
||||
if epoch % self.epoch_laps == 0:
|
||||
def on_epoch_end(self, epoch, logs):
|
||||
if epoch % self.epoch_laps == 0:
|
||||
|
||||
input_pict = self.data_loader.next()[0][1]
|
||||
output_pict = self.model_class.predict_one(input_pict)
|
||||
input_pict = self.data_loader.next()[0][1]
|
||||
output_pict = self.model_class.predict_one(input_pict)
|
||||
|
||||
display(Tools.display_one_picture_scaled(input_pict))
|
||||
display(Tools.display_index_picture_scaled(output_pict))
|
||||
|
||||
return self
|
||||
display(Tools.display_one_picture_scaled(input_pict))
|
||||
display(Tools.display_index_picture_scaled(output_pict))
|
||||
|
||||
return self
|
||||
|
||||
class TensorboardCallback(Callback):
|
||||
|
||||
def __init__(self, log_dir, limit_image = 1, model = None, data_loader = None):
|
||||
self.log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||
self.limit_image = limit_image
|
||||
self.model_class = model
|
||||
self.data_loader = data_loader
|
||||
self.writer = tf.summary.FileWriter(self.log_dir)
|
||||
super(TensorboardCallback, self).__init__()
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
print(logs)
|
||||
image_summaries = []
|
||||
|
||||
for input_pict in self.data_loader.next()[0][:self.limit_image]:
|
||||
output_pict = self.model_class.predict_one(input_pict)[0]
|
||||
input_im_bytes = Tools.bytes_image(input_pict*255)
|
||||
output_im_bytes = Tools.bytes_image(output_pict*255)
|
||||
|
||||
image_summaries.append(tf.Summary.Value(tag = 'input', image = tf.Summary.Image(encoded_image_string = input_im_bytes)))
|
||||
image_summaries.append(tf.Summary.Value(tag = 'output', image = tf.Summary.Image(encoded_image_string = output_im_bytes)))
|
||||
|
||||
|
||||
image_summary = tf.Summary(value = image_summaries)
|
||||
self.writer.add_summary(image_summary, epoch)
|
||||
self._write_logs(logs, epoch)
|
||||
|
||||
return self
|
||||
|
||||
def _write_logs(self, logs, index):
|
||||
for name, value in logs.items():
|
||||
if name in ['batch', 'size']:
|
||||
continue
|
||||
summary = tf.Summary()
|
||||
summary_value = summary.value.add()
|
||||
if isinstance(value, np.ndarray):
|
||||
summary_value.simple_value = value.item()
|
||||
else:
|
||||
summary_value.simple_value = value
|
||||
summary_value.tag = name
|
||||
self.writer.add_summary(summary, index)
|
||||
self.writer.flush()
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from keras.callbacks import ModelCheckpoint, CSVLogger
|
||||
from iss.models.Callbacks import DisplayPictureCallback
|
||||
from iss.models.Callbacks import DisplayPictureCallback, TensorboardCallback
|
||||
from iss.tools.tools import Tools
|
||||
import keras
|
||||
|
||||
|
@ -91,12 +91,11 @@ class ModelTrainer:
|
|||
if 'tensorboard' in config['callbacks']:
|
||||
log_dir = config['callbacks']['tensorboard']['log_dir']
|
||||
Tools.create_dir_if_not_exists(log_dir)
|
||||
self.callbacks.extend([keras.callbacks.TensorBoard(
|
||||
self.callbacks.extend([TensorboardCallback(
|
||||
log_dir = log_dir,
|
||||
histogram_freq=0,
|
||||
batch_size=32,
|
||||
write_graph=False,
|
||||
write_images = True
|
||||
limit_image = config['callbacks']['tensorboard']['limit_image'],
|
||||
model = self.model,
|
||||
data_loader = self.data_loader.get_train_generator()
|
||||
)])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue