2019-03-09 22:41:37 +01:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
from keras.preprocessing.image import ImageDataGenerator
|
|
|
|
|
|
|
|
class ImageDataGeneratorWrapper:
    """Build Keras directory iterators for the autoencoder train and test sets.

    A single ``ImageDataGenerator`` (pixel values rescaled by 1/255) is created
    and used to build a train and a test iterator via ``flow_from_directory``.
    Target size and batch size are read from ``config.get('models')[model]``;
    the directories are read from ``config.get('directory')['autoencoder']``.
    """

    def __init__(self, config, model):
        """Store the configuration and eagerly build both generators.

        Args:
            config: object exposing ``get('models')`` and ``get('directory')``
                (e.g. a dict).
            model: key into ``config.get('models')`` selecting the entry whose
                ``input_height``, ``input_width`` and ``batch_size`` are used.
        """
        self.config = config
        self.model = model
        self.datagen = None
        self.train_generator = None
        self.test_generator = None

        self.image_data_generator(config)
        self.set_train_generator()
        self.set_test_generator()

    def _model_config(self):
        """Return the configuration entry of the selected model."""
        return self.config.get('models')[self.model]

    def image_data_generator(self, config):
        """Create the shared ``ImageDataGenerator`` and return ``self``.

        ``config`` is accepted for interface compatibility but is not used.
        """
        self.datagen = ImageDataGenerator(
            rescale=1. / 255,
        )
        return self

    def build_generator(self, directory):
        """Return a ``flow_from_directory`` iterator over ``directory``.

        Images are loaded as RGB and resized to
        ``(input_height, input_width)`` of the selected model.
        ``class_mode='input'`` makes Keras yield the images themselves as the
        targets, which is what an autoencoder needs.
        """
        # TODO: expose these flow_from_directory options as parameters if needed.
        model_config = self._model_config()
        return self.datagen.flow_from_directory(
            directory,
            target_size=(model_config['input_height'], model_config['input_width']),
            color_mode='rgb',
            classes=None,
            class_mode='input',
            batch_size=model_config['batch_size'],
        )

    def set_train_generator(self):
        """Build the training iterator and return ``self``."""
        # flow_from_directory reads images from the sub-folders of the given
        # directory, so the parent of the configured train folder is passed.
        # NOTE(review): if the train and test folders share a parent, both
        # iterators will read the same images. Confirm the directory layout.
        train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        """Return the training iterator (``None`` if not built yet)."""
        return self.train_generator

    def set_test_generator(self):
        """Build the test iterator and return ``self``."""
        # Same parent-folder convention as set_train_generator.
        test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        """Return the test iterator (``None`` if not built yet)."""
        # Previously returned self.train_generator by mistake.
        return self.test_generator
|
|
|
|
|
|
|
|
|
|
|
|
|