1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-18 21:36:33 +02:00
smart-iss-posts/iss/models/DataLoader.py

58 lines
1.5 KiB
Python
Raw Normal View History

2019-03-09 22:41:37 +01:00
# -*- coding: utf-8 -*-
2019-11-11 04:16:43 +01:00
import os
2019-03-09 22:41:37 +01:00
from keras.preprocessing.image import ImageDataGenerator
class ImageDataGeneratorWrapper:
    """Build Keras directory-backed image generators for one configured model.

    The wrapper reads per-model settings (``sampling``, ``input_height``,
    ``input_width``, ``batch_size``) from ``config.get('models')[model]``.
    It reads the train/test directories from
    ``config.get('sampling')[sampling_type]['directory']``.
    It then creates one generator for training and one for testing.
    """

    def __init__(self, config, model):
        """Create the data generator and the train/test directory iterators.

        Args:
            config: configuration object exposing ``get(section)``; the
                ``'models'`` and ``'sampling'`` sections are read.
            model: key of the model entry inside ``config.get('models')``.
        """
        self.config = config
        self.model = model
        self.datagen = None
        self.train_generator = None
        self.test_generator = None
        self.image_data_generator(config)

        sampling_type = self._model_config()['sampling']
        sampling_dirs = self.config.get('sampling')[sampling_type]['directory']
        # The parent of each configured directory is passed because
        # flow_from_directory treats sub-directories as classes; the configured
        # folder thus becomes the (single) class folder.
        # NOTE(review): if the train and test folders share the same parent,
        # both generators will see both sets of images — confirm the layout.
        train_dir = os.path.join(sampling_dirs['train'], '..')
        test_dir = os.path.join(sampling_dirs['test'], '..')
        self.set_train_generator(train_dir)
        self.set_test_generator(test_dir)

    def _model_config(self):
        """Return the configuration entry of ``self.model``."""
        return self.config.get('models')[self.model]

    def image_data_generator(self, config):
        """Create the shared ``ImageDataGenerator`` (pixels rescaled by 1/255).

        ``config`` is accepted for interface compatibility but is not used.
        Returns ``self`` to allow chaining.
        """
        self.datagen = ImageDataGenerator(
            rescale=1. / 255
        )
        return self

    def build_generator(self, directory):
        """Return a directory iterator over ``directory`` for this model.

        Images are resized to ``(input_height, input_width)``, loaded as RGB,
        and batched by ``batch_size`` from the model configuration.
        ``class_mode='input'`` makes each batch's target equal to its input.
        """
        # Revisit later if more of these options need to be configurable.
        model_config = self._model_config()
        return self.datagen.flow_from_directory(
            directory,
            target_size=(model_config['input_height'], model_config['input_width']),
            color_mode='rgb',
            classes=None,
            class_mode='input',
            batch_size=model_config['batch_size'],
        )

    def set_train_generator(self, train_dir):
        """Build and store the training iterator over ``train_dir``; returns ``self``."""
        self.train_generator = self.build_generator(directory=train_dir)
        return self

    def get_train_generator(self):
        """Return the training iterator."""
        return self.train_generator

    def set_test_generator(self, test_dir):
        """Build and store the test iterator over ``test_dir``; returns ``self``."""
        self.test_generator = self.build_generator(directory=test_dir)
        return self

    def get_test_generator(self):
        """Return the test iterator (previously returned the train iterator by mistake)."""
        return self.test_generator