mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-03 14:13:10 +02:00
55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from keras.preprocessing.image import ImageDataGenerator
|
|
|
|
class ImageDataGeneratorWrapper:
|
|
|
|
def __init__(self, config, model):
|
|
|
|
self.config = config
|
|
self.model = model
|
|
self.datagen = None
|
|
self.train_generator = None
|
|
self.test_generator = None
|
|
|
|
self.image_data_generator(config)
|
|
|
|
self.set_train_generator()
|
|
self.set_test_generator()
|
|
|
|
def image_data_generator(self, config):
|
|
self.datagen = ImageDataGenerator(
|
|
rescale = 1./255
|
|
)
|
|
return self
|
|
|
|
def build_generator(self, directory):
|
|
# voir plus tars si besoin de parametrer
|
|
return self.datagen.flow_from_directory(
|
|
directory,
|
|
target_size = (self.config.get('models')[self.model]['input_height'], self.config.get('models')[self.model]['input_width']),
|
|
color_mode = 'rgb',
|
|
classes = None,
|
|
class_mode = 'input',
|
|
batch_size = self.config.get('models')[self.model]['batch_size'],
|
|
)
|
|
|
|
def set_train_generator(self):
|
|
train_dir = self.config.get('directory')['autoencoder']['train'] + '/..'
|
|
self.train_generator = self.build_generator(directory = train_dir)
|
|
return self
|
|
|
|
def get_train_generator(self):
|
|
return self.train_generator
|
|
|
|
def set_test_generator(self):
|
|
test_dir = self.config.get('directory')['autoencoder']['test'] + '/..'
|
|
self.test_generator = self.build_generator(directory = test_dir)
|
|
return self
|
|
|
|
def get_test_generator(self):
|
|
return self.train_generator
|
|
|
|
|
|
|