2019-03-09 22:41:37 +01:00
|
|
|
# -*- coding: utf-8 -*-
|
2019-11-11 04:16:43 +01:00
|
|
|
import os
|
2019-03-09 22:41:37 +01:00
|
|
|
|
|
|
|
from keras.preprocessing.image import ImageDataGenerator
|
|
|
|
|
|
|
|
class ImageDataGeneratorWrapper:
    """Build Keras directory-backed train/test image generators for one model.

    Everything is read from ``config`` (any object exposing ``.get(key)``,
    e.g. a dict):

    * ``config.get('models')[model]`` must provide ``'sampling'``,
      ``'input_height'``, ``'input_width'`` and ``'batch_size'``;
    * ``config.get('sampling')[<sampling>]['directory']`` must provide
      ``'train'`` and ``'test'`` paths.

    The ImageDataGenerator and both generators are created eagerly in
    ``__init__``.
    """

    def __init__(self, config, model):
        """Create the shared ImageDataGenerator and the train/test generators.

        Args:
            config: configuration object exposing ``.get(key)``.
            model: key of this model's entry under ``config.get('models')``.
        """
        self.config = config
        self.model = model
        self.datagen = None
        self.train_generator = None
        self.test_generator = None

        self.image_data_generator(config)

        sampling_type = self._model_config()['sampling']
        directories = self.config.get('sampling')[sampling_type]['directory']
        # The parent of each configured path is handed to flow_from_directory,
        # which treats each sub-directory as one class. Presumably the
        # configured path is itself the image sub-directory — TODO confirm.
        train_dir = os.path.join(directories['train'], '..')
        test_dir = os.path.join(directories['test'], '..')

        self.set_train_generator(train_dir)
        self.set_test_generator(test_dir)

    def _model_config(self):
        """Return the configuration entry of ``self.model``."""
        return self.config.get('models')[self.model]

    def image_data_generator(self, config):
        """Create the shared ImageDataGenerator (pixel values scaled by 1/255).

        Args:
            config: unused; kept so existing callers keep working.

        Returns:
            self, to allow call chaining.
        """
        self.datagen = ImageDataGenerator(
            rescale=1. / 255,
        )
        return self

    def build_generator(self, directory):
        """Return a ``flow_from_directory`` iterator over ``directory``.

        Images are loaded as RGB and resized to
        ``(input_height, input_width)`` taken from the model configuration.
        ``class_mode='input'`` makes each target batch equal to its input
        batch, and ``batch_size`` also comes from the model configuration.

        Args:
            directory: directory whose sub-directories contain the images.

        Returns:
            The iterator returned by ``self.datagen.flow_from_directory``.
        """
        # TODO: expose more flow_from_directory options if they ever need
        # to be configurable.
        model_config = self._model_config()
        return self.datagen.flow_from_directory(
            directory,
            target_size=(model_config['input_height'], model_config['input_width']),
            color_mode='rgb',
            classes=None,
            class_mode='input',
            batch_size=model_config['batch_size'],
        )

    def set_train_generator(self, train_dir):
        """Build the training generator from ``train_dir``; return self."""
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        """Return the training generator."""
        return self.train_generator

    def set_test_generator(self, test_dir):
        """Build the test generator from ``test_dir``; return self."""
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        """Return the test generator (previously returned the train one by mistake)."""
        return self.test_generator
|
|
|
|
|
|
|
|
|
|
|
|
|