1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-17 21:06:33 +02:00
smart-iss-posts/iss/exec/training.py

37 lines
949 B
Python
Raw Normal View History

2019-11-11 04:16:43 +01:00
import os
from iss.init_config import CONFIG
from iss.models.DataLoader import ImageDataGeneratorWrapper
from iss.models.ModelTrainer import ModelTrainer
from iss.models import SimpleAutoEncoder
from iss.models import SimpleConvAutoEncoder
from iss.models import VarAutoEncoder
from iss.models import VarConvAutoEncoder
## Global variables
# Key into CONFIG['models'] selecting the model configuration; it is also
# forwarded to the data loader as its `model` argument.
_MODEL_TYPE = 'simple_conv'
# Forwarded to model.load(which=...) when _LOAD is enabled.
_LOAD_NAME = None
# Set to True to call model.load() before training starts.
_LOAD = False

## Data loader
data_loader = ImageDataGeneratorWrapper(CONFIG, model=_MODEL_TYPE)

## Model
if _MODEL_TYPE in ['simple_conv']:
    model = SimpleConvAutoEncoder(CONFIG.get('models')[_MODEL_TYPE])
else:
    # Without this branch `model` would stay unbound and the script would
    # die further down with an opaque NameError.
    raise ValueError(
        'Unsupported _MODEL_TYPE {!r}: this script only knows how to build '
        "'simple_conv'".format(_MODEL_TYPE)
    )

if _LOAD:
    model.load(which=_LOAD_NAME)

# Call summary() on the encoder, the decoder and the full model.
model.encoder_model.summary()
model.decoder_model.summary()
model.model.summary()

## Trainer
trainer = ModelTrainer(
    model,
    data_loader,
    CONFIG.get('models')[_MODEL_TYPE],
    callbacks=[],
)

## Training
try:
    trainer.train()
except KeyboardInterrupt:
    # On Ctrl-C, call save() on the trainer's model instead of letting the
    # interrupt propagate.
    trainer.model.save()