1
0
Fork 0
mirror of https://github.com/prise6/smart-iss-posts synced 2024-04-27 03:21:51 +02:00
smart-iss-posts/iss/exec/training.py
2019-12-08 02:25:27 +01:00

54 lines
1.4 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import click
from iss.init_config import CONFIG
from iss.models.DataLoader import ImageDataGeneratorWrapper
from iss.models.ModelTrainer import ModelTrainer
from iss.models import SimpleAutoEncoder
from iss.models import SimpleConvAutoEncoder
from iss.models import VarAutoEncoder
from iss.models import VarConvAutoEncoder
@click.command()
@click.option('--model-type', default='simple_conv', show_default=True, type=str,
              help="Model to train: 'simple_conv' or 'simple'.")
@click.option('--load', default=False, is_flag=True,
              help='Call load() on the model before training.')
@click.option('--load-name', default=None, show_default=True, type=str,
              help="Value forwarded as 'which' to the model's load() when --load is set.")
def main(model_type, load, load_name):
    """Build the selected autoencoder and train it.

    The model class is chosen from --model-type and configured from the
    matching entry of CONFIG's 'models' section. When --load is given, the
    model's load() is called with --load-name first. summary() is then called
    on the encoder, the decoder and the full model, and the trainer is run.
    A KeyboardInterrupt during training triggers a save of the trainer's
    model instead of propagating.

    Raises ValueError when --model-type is not a supported value.
    """
    # Supported --model-type values and the class that implements each one.
    model_classes = {
        'simple_conv': SimpleConvAutoEncoder,
        'simple': SimpleAutoEncoder,
    }

    # Validate first so that a typo fails with a clear message before the
    # data loader is constructed.
    if model_type not in model_classes:
        raise ValueError(
            "Unsupported model type {!r}; expected one of: {}".format(
                model_type, ', '.join(sorted(model_classes))
            )
        )

    # Data loader
    data_loader = ImageDataGeneratorWrapper(CONFIG, model=model_type)

    # Model: the same per-model configuration is shared with the trainer below.
    model_config = CONFIG.get('models')[model_type]
    model = model_classes[model_type](model_config)

    if load:
        model.load(which=load_name)

    model.encoder_model.summary()
    model.decoder_model.summary()
    model.model.summary()

    # Trainer
    trainer = ModelTrainer(model, data_loader, model_config, callbacks=[])

    # Training
    try:
        trainer.train()
    except KeyboardInterrupt:
        # A manual stop should not discard the work done so far.
        trainer.model.save()
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()