2019-11-11 04:16:43 +01:00
|
|
|
|
import os
|
2019-12-08 02:25:27 +01:00
|
|
|
|
import click
|
2019-11-11 04:16:43 +01:00
|
|
|
|
|
|
|
|
|
from iss.init_config import CONFIG
|
|
|
|
|
from iss.models.DataLoader import ImageDataGeneratorWrapper
|
|
|
|
|
from iss.models.ModelTrainer import ModelTrainer
|
|
|
|
|
from iss.models import SimpleAutoEncoder
|
|
|
|
|
from iss.models import SimpleConvAutoEncoder
|
|
|
|
|
from iss.models import VarAutoEncoder
|
|
|
|
|
from iss.models import VarConvAutoEncoder
|
|
|
|
|
|
|
|
|
|
|
2019-12-08 02:25:27 +01:00
|
|
|
|
@click.command()
@click.option('--model-type', default='simple_conv', show_default=True, type=str,
              help="Which autoencoder to train; also selects its entry in "
                   "CONFIG['models'].")
@click.option('--load', default=False, is_flag=True,
              help="Restore a previously saved model before training.")
@click.option('--load-name', default=None, show_default=True, type=str,
              help="Value passed as `which` to the model's load() when --load is set.")
def main(model_type, load, load_name):
    """Train an autoencoder on the image data configured in CONFIG.

    Builds the data loader and the model selected by --model-type,
    optionally restores a saved model, prints the encoder, decoder and
    full-model summaries, then runs training. Interrupting training with
    Ctrl-C stops it and saves the model instead of discarding progress.
    """
    # Supported --model-type values and the model class each one builds.
    model_classes = {
        'simple_conv': SimpleConvAutoEncoder,
        'simple': SimpleAutoEncoder,
    }

    # Validate the model type before any expensive work (data loading,
    # model construction) so a typo fails fast with an actionable message
    # instead of a bare Exception after the data loader has been built.
    model_class = model_classes.get(model_type)
    if model_class is None:
        raise click.BadParameter(
            "unsupported model type {!r}; expected one of: {}".format(
                model_type, ', '.join(sorted(model_classes))),
            param_hint="'--model-type'",
        )

    # Data loader
    data_loader = ImageDataGeneratorWrapper(CONFIG, model=model_type)

    # Per-model configuration, shared by the model and its trainer.
    model_config = CONFIG.get('models')[model_type]

    # Model
    model = model_class(model_config)
    if load:
        model.load(which=load_name)

    model.encoder_model.summary()
    model.decoder_model.summary()
    model.model.summary()

    # Trainer
    trainer = ModelTrainer(model, data_loader, model_config, callbacks=[])

    # Training: on Ctrl-C, save the model rather than losing the progress
    # made so far (the interrupt is deliberately not re-raised).
    try:
        trainer.train()
    except KeyboardInterrupt:
        trainer.model.save()
|
2019-11-11 04:16:43 +01:00
|
|
|
|
|
|
|
|
|
|
2019-12-08 02:25:27 +01:00
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: hand control to the click command,
    # which parses the command-line arguments itself.
    main()
|