1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-03 06:03:10 +02:00
smart-iss-posts/iss/models/DataLoader.py
2019-11-11 04:16:43 +01:00

58 lines
1.5 KiB
Python

# -*- coding: utf-8 -*-
import os
from keras.preprocessing.image import ImageDataGenerator
class ImageDataGeneratorWrapper:
    """Build Keras directory-backed train/test image iterators for one model.

    Model settings are read from ``config.get('models')[model]`` and the image
    directories from ``config.get('sampling')[sampling_type]['directory']``,
    where ``sampling_type`` is the model's ``'sampling'`` entry.
    """

    def __init__(self, config, model):
        """Create the image data generator and both directory iterators.

        :param config: configuration object exposing ``get(section)``; it must
            provide the ``'models'`` and ``'sampling'`` sections.
        :param model: key of this model inside the ``'models'`` section.
        """
        self.config = config
        self.model = model
        self.datagen = None
        self.train_generator = None
        self.test_generator = None

        self.image_data_generator(config)

        sampling_type = self._model_config()['sampling']
        directories = self.config.get('sampling')[sampling_type]['directory']
        # flow_from_directory treats every sub-directory of its argument as a
        # class, so the parent of each configured directory is passed in.
        # NOTE(review): if the train and test directories share the same
        # parent, both iterators will read images from both folders — confirm
        # the on-disk layout.
        train_dir = os.path.join(directories['train'], '..')
        test_dir = os.path.join(directories['test'], '..')

        self.set_train_generator(train_dir)
        self.set_test_generator(test_dir)

    def _model_config(self):
        """Return this model's entry of the ``'models'`` configuration section."""
        return self.config.get('models')[self.model]

    def image_data_generator(self, config):
        """Create the ImageDataGenerator shared by both iterators; return ``self``.

        ``config`` is accepted for interface compatibility but is not used.
        """
        # Multiply pixel values by 1/255 (maps 8-bit [0, 255] values to [0, 1]).
        self.datagen = ImageDataGenerator(
            rescale=1. / 255
        )
        return self

    def build_generator(self, directory):
        """Return a ``flow_from_directory`` iterator over ``directory``.

        Target size and batch size come from the model configuration keys
        ``input_height``, ``input_width`` and ``batch_size``.
        """
        # TODO: expose further flow_from_directory options if they are needed.
        model_config = self._model_config()
        return self.datagen.flow_from_directory(
            directory,
            target_size=(model_config['input_height'], model_config['input_width']),
            color_mode='rgb',
            classes=None,
            # 'input' makes Keras use the image itself as the target
            # (autoencoder-style training).
            class_mode='input',
            batch_size=model_config['batch_size'],
        )

    def set_train_generator(self, train_dir):
        """Build the training iterator from ``train_dir``; return ``self``."""
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        """Return the training iterator."""
        return self.train_generator

    def set_test_generator(self, test_dir):
        """Build the test iterator from ``test_dir``; return ``self``."""
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        """Return the test iterator (not the training one)."""
        return self.test_generator