mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-18 13:26:33 +02:00
tensorboard callback
This commit is contained in:
parent
0811e3d3a5
commit
11cdb40b3f
|
@ -1,29 +1,77 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from keras.callbacks import Callback
|
import os
|
||||||
|
import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from iss.tools.tools import Tools
|
import tensorflow as tf
|
||||||
|
from keras.callbacks import Callback
|
||||||
from IPython.display import display
|
from IPython.display import display
|
||||||
|
|
||||||
|
from iss.tools.tools import Tools
|
||||||
|
|
||||||
|
|
||||||
class DisplayPictureCallback(Callback):
|
class DisplayPictureCallback(Callback):
|
||||||
|
|
||||||
def __init__(self, model, epoch_laps, data_loader):
|
def __init__(self, model, epoch_laps, data_loader):
|
||||||
|
|
||||||
self.model_class = model
|
self.model_class = model
|
||||||
self.epoch_laps = epoch_laps
|
self.epoch_laps = epoch_laps
|
||||||
self.data_loader = data_loader
|
self.data_loader = data_loader
|
||||||
super(DisplayPictureCallback, self).__init__()
|
super(DisplayPictureCallback, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, logs):
|
def on_epoch_end(self, epoch, logs):
|
||||||
if epoch % self.epoch_laps == 0:
|
if epoch % self.epoch_laps == 0:
|
||||||
|
|
||||||
input_pict = self.data_loader.next()[0][1]
|
input_pict = self.data_loader.next()[0][1]
|
||||||
output_pict = self.model_class.predict_one(input_pict)
|
output_pict = self.model_class.predict_one(input_pict)
|
||||||
|
|
||||||
display(Tools.display_one_picture_scaled(input_pict))
|
display(Tools.display_one_picture_scaled(input_pict))
|
||||||
display(Tools.display_index_picture_scaled(output_pict))
|
display(Tools.display_index_picture_scaled(output_pict))
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
class TensorboardCallback(Callback):
|
||||||
|
|
||||||
|
def __init__(self, log_dir, limit_image = 1, model = None, data_loader = None):
|
||||||
|
self.log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||||
|
self.limit_image = limit_image
|
||||||
|
self.model_class = model
|
||||||
|
self.data_loader = data_loader
|
||||||
|
self.writer = tf.summary.FileWriter(self.log_dir)
|
||||||
|
super(TensorboardCallback, self).__init__()
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
print(logs)
|
||||||
|
image_summaries = []
|
||||||
|
|
||||||
|
for input_pict in self.data_loader.next()[0][:self.limit_image]:
|
||||||
|
output_pict = self.model_class.predict_one(input_pict)[0]
|
||||||
|
input_im_bytes = Tools.bytes_image(input_pict*255)
|
||||||
|
output_im_bytes = Tools.bytes_image(output_pict*255)
|
||||||
|
|
||||||
|
image_summaries.append(tf.Summary.Value(tag = 'input', image = tf.Summary.Image(encoded_image_string = input_im_bytes)))
|
||||||
|
image_summaries.append(tf.Summary.Value(tag = 'output', image = tf.Summary.Image(encoded_image_string = output_im_bytes)))
|
||||||
|
|
||||||
|
|
||||||
|
image_summary = tf.Summary(value = image_summaries)
|
||||||
|
self.writer.add_summary(image_summary, epoch)
|
||||||
|
self._write_logs(logs, epoch)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _write_logs(self, logs, index):
|
||||||
|
for name, value in logs.items():
|
||||||
|
if name in ['batch', 'size']:
|
||||||
|
continue
|
||||||
|
summary = tf.Summary()
|
||||||
|
summary_value = summary.value.add()
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
summary_value.simple_value = value.item()
|
||||||
|
else:
|
||||||
|
summary_value.simple_value = value
|
||||||
|
summary_value.tag = name
|
||||||
|
self.writer.add_summary(summary, index)
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from keras.callbacks import ModelCheckpoint, CSVLogger
|
from keras.callbacks import ModelCheckpoint, CSVLogger
|
||||||
from iss.models.Callbacks import DisplayPictureCallback
|
from iss.models.Callbacks import DisplayPictureCallback, TensorboardCallback
|
||||||
from iss.tools.tools import Tools
|
from iss.tools.tools import Tools
|
||||||
import keras
|
import keras
|
||||||
|
|
||||||
|
@ -91,12 +91,11 @@ class ModelTrainer:
|
||||||
if 'tensorboard' in config['callbacks']:
|
if 'tensorboard' in config['callbacks']:
|
||||||
log_dir = config['callbacks']['tensorboard']['log_dir']
|
log_dir = config['callbacks']['tensorboard']['log_dir']
|
||||||
Tools.create_dir_if_not_exists(log_dir)
|
Tools.create_dir_if_not_exists(log_dir)
|
||||||
self.callbacks.extend([keras.callbacks.TensorBoard(
|
self.callbacks.extend([TensorboardCallback(
|
||||||
log_dir = log_dir,
|
log_dir = log_dir,
|
||||||
histogram_freq=0,
|
limit_image = config['callbacks']['tensorboard']['limit_image'],
|
||||||
batch_size=32,
|
model = self.model,
|
||||||
write_graph=False,
|
data_loader = self.data_loader.get_train_generator()
|
||||||
write_images = True
|
|
||||||
)])
|
)])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue