1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-05-03 14:13:10 +02:00
smart-iss-posts/iss/exec/training.py
2019-11-11 04:16:43 +01:00

37 lines
949 B
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from iss.init_config import CONFIG
from iss.models.DataLoader import ImageDataGeneratorWrapper
from iss.models.ModelTrainer import ModelTrainer
from iss.models import SimpleAutoEncoder
from iss.models import SimpleConvAutoEncoder
from iss.models import VarAutoEncoder
from iss.models import VarConvAutoEncoder
## Global settings
# Selects the model class below; also used as the key into CONFIG['models']
# and passed to the data loader as its `model` argument.
_MODEL_TYPE = 'simple_conv'
# Forwarded to model.load(which=...) when _LOAD is true.
_LOAD_NAME = None
# When true, model.load() is called before training starts.
_LOAD = False

## Data loader
data_loader = ImageDataGeneratorWrapper(CONFIG, model=_MODEL_TYPE)

## Model
if _MODEL_TYPE == 'simple_conv':
    model = SimpleConvAutoEncoder(CONFIG.get('models')[_MODEL_TYPE])
else:
    # Without this branch `model` would stay unbound and the script would
    # fail further down with an opaque NameError.
    raise ValueError(
        "Unsupported _MODEL_TYPE {!r}; this script only builds 'simple_conv'".format(_MODEL_TYPE)
    )

if _LOAD:
    model.load(which=_LOAD_NAME)

# Call summary() on the encoder, the decoder and the full model
# (presumably Keras models, whose summary() prints the layer table -- confirm).
model.encoder_model.summary()
model.decoder_model.summary()
model.model.summary()

## Trainer
trainer = ModelTrainer(model, data_loader, CONFIG.get('models')[_MODEL_TYPE], callbacks=[])

## Training
try:
    trainer.train()
except KeyboardInterrupt:
    # Ctrl-C during training: call save() on the trainer's model instead of
    # propagating the interrupt, so the script then exits normally.
    trainer.model.save()