1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-06-02 22:02:12 +02:00
smart-iss-posts/iss/models/DataLoader.py

55 lines
1.4 KiB
Python
Raw Normal View History

2019-03-09 22:41:37 +01:00
# -*- coding: utf-8 -*-
from keras.preprocessing.image import ImageDataGenerator
class ImageDataGeneratorWrapper:
    """Create and hold the Keras image generators for one configured model.

    On construction the wrapper creates a shared ``ImageDataGenerator`` and
    uses it to build one directory iterator for the training images and one
    for the test images. The iterators are exposed through
    ``get_train_generator`` and ``get_test_generator``.
    """

    def __init__(self, config, model):
        """Build the data generator and both directory iterators.

        Args:
            config: Project configuration object. It must provide a
                ``get(key)`` method; the keys ``'models'`` and
                ``'directory'`` are read (presumably nested dicts loaded
                from a config file -- confirm against the caller).
            model: Key selecting the entry of ``config.get('models')`` that
                holds ``input_height``, ``input_width`` and ``batch_size``.
        """
        self.config = config
        self.model = model

        self.datagen = None
        self.train_generator = None
        self.test_generator = None

        self.image_data_generator(config)
        self.set_train_generator()
        self.set_test_generator()

    def image_data_generator(self, config):
        """Create the shared ``ImageDataGenerator`` and return ``self``.

        The only transformation configured is a rescale by ``1/255``.
        ``config`` is accepted for interface compatibility but is currently
        unused.
        """
        self.datagen = ImageDataGenerator(
            rescale=1. / 255,
        )
        return self

    def build_generator(self, directory):
        """Return a directory iterator over the images under ``directory``.

        Image size and batch size come from the configuration entry of the
        selected model. ``class_mode='input'`` is requested, which per the
        Keras documentation yields each image batch as both input and
        target.
        """
        # TODO: revisit later if these arguments need to be configurable.
        model_settings = self.config.get('models')[self.model]
        image_size = (
            model_settings['input_height'],
            model_settings['input_width'],
        )
        return self.datagen.flow_from_directory(
            directory,
            target_size=image_size,
            color_mode='rgb',
            classes=None,
            class_mode='input',
            batch_size=model_settings['batch_size'],
        )

    def set_train_generator(self):
        """Build the training iterator, store it and return ``self``."""
        # The parent of the configured folder is passed on purpose;
        # presumably so the folder itself acts as the single "class"
        # sub-directory that flow_from_directory looks for -- confirm.
        train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        """Return the iterator built by ``set_train_generator``."""
        return self.train_generator

    def set_test_generator(self):
        """Build the test iterator, store it and return ``self``."""
        # Same parent-directory convention as for the training set.
        test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        """Return the iterator built by ``set_test_generator``."""
        # Fixed: this previously returned the training generator.
        return self.test_generator