mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-05 07:03:10 +02:00
predict picture embedding and save it
This commit is contained in:
parent
78ebb24f25
commit
ac1f75d28e
|
@ -29,8 +29,11 @@ class AbstractClustering:
|
||||||
self.colors = [Tools.get_color_from_label(label, n_classes) for label in self.final_labels]
|
self.colors = [Tools.get_color_from_label(label, n_classes) for label in self.final_labels]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def predict_embedding(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -51,6 +51,9 @@ class N2DClustering(AbstractClustering):
|
||||||
cluster in np.unique(self.final_labels)}
|
cluster in np.unique(self.final_labels)}
|
||||||
return self.silhouette_score_labels
|
return self.silhouette_score_labels
|
||||||
|
|
||||||
|
def predict_embedding(self, pictures_np):
|
||||||
|
return self.umap_fit.transform(pictures_np)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
Tools.create_dir_if_not_exists(self.save_directory)
|
Tools.create_dir_if_not_exists(self.save_directory)
|
||||||
|
|
||||||
|
@ -58,5 +61,5 @@ class N2DClustering(AbstractClustering):
|
||||||
joblib.dump(self.kmeans_fit, os.path.join(self.save_directory, self.kmeans_save_name))
|
joblib.dump(self.kmeans_fit, os.path.join(self.save_directory, self.kmeans_save_name))
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
self.umap_fit = joblib.load(os.path.join(self.save_directory, self.pca_save_name))
|
self.umap_fit = joblib.load(os.path.join(self.save_directory, self.umap_save_name))
|
||||||
self.kmeans_fit = joblib.load(os.path.join(self.save_directory, self.kmeans_save_name))
|
self.kmeans_fit = joblib.load(os.path.join(self.save_directory, self.kmeans_save_name))
|
|
@ -1,7 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
class DataBaseManager:
|
class MysqlDataBaseManager:
|
||||||
|
|
||||||
def __init__(self, connexion, config):
|
def __init__(self, connexion, config):
|
||||||
self.conn = connexion
|
self.conn = connexion
|
||||||
|
@ -9,25 +9,66 @@ class DataBaseManager:
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
|
||||||
def createPicturesTable(self, force = False):
|
def create_pictures_location_table(self, force = False):
|
||||||
|
|
||||||
if force:
|
if force:
|
||||||
self.cursor.execute("DROP TABLE IF EXISTS `iss`.`pictures`;")
|
self.cursor.execute("DROP TABLE IF EXISTS `iss`.`pictures_location`;")
|
||||||
|
|
||||||
self.cursor.execute("""
|
self.cursor.execute("""
|
||||||
CREATE TABLE `iss`.`pictures` (
|
CREATE TABLE IF NOT EXISTS `iss`.`pictures_location` (
|
||||||
`pictures_latitude` FLOAT(10, 6) NULL,
|
`pictures_latitude` FLOAT(10, 6) NULL,
|
||||||
`pictures_longitude` FLOAT(10, 6 ) NULL ,
|
`pictures_longitude` FLOAT(10, 6 ) NULL ,
|
||||||
`pictures_id` VARCHAR( 15 ) PRIMARY KEY ,
|
`pictures_id` VARCHAR( 15 ) PRIMARY KEY ,
|
||||||
`pictures_timestamp` TIMESTAMP NULL ,
|
`pictures_timestamp` TIMESTAMP NULL ,
|
||||||
`pictures_location` TEXT NULL
|
`pictures_location_text` TEXT NULL
|
||||||
) ENGINE = MYISAM ;
|
) ENGINE = MYISAM ;
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
def insertRowPictures(self, array):
|
def insert_row_pictures_location(self, array):
|
||||||
|
|
||||||
sql_insert_template = "INSERT INTO `iss`.`pictures` (pictures_latitude, pictures_longitude, pictures_id, pictures_timestamp, pictures_location) VALUES (%s, %s, %s, %s, %s);"
|
sql_insert_template = "INSERT INTO `iss`.`pictures_location` (pictures_latitude, pictures_longitude, pictures_id, pictures_timestamp, pictures_location_text) VALUES (%s, %s, %s, %s, %s);"
|
||||||
|
|
||||||
|
self.cursor.executemany(sql_insert_template, array)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
return self.cursor.rowcount
|
||||||
|
|
||||||
|
|
||||||
|
def create_pictures_embedding_table(self, force = False):
|
||||||
|
|
||||||
|
if force:
|
||||||
|
self.cursor.execute("DROP TABLE IF EXISTS `iss`.`pictures_embedding`;")
|
||||||
|
|
||||||
|
self.cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS `iss`.`pictures_embedding` (
|
||||||
|
`pictures_id` VARCHAR( 15 ) ,
|
||||||
|
`pictures_x` FLOAT(8, 4),
|
||||||
|
`pictures_y` FLOAT(8, 4),
|
||||||
|
`clustering_type` VARCHAR(15),
|
||||||
|
`clustering_version` VARCHAR(5),
|
||||||
|
`clustering_model_type` VARCHAR(15),
|
||||||
|
`clustering_model_name` VARCHAR(15),
|
||||||
|
UNIQUE KEY `unique_key` (`pictures_id`,`clustering_type`, `clustering_version`, `clustering_model_type`,`clustering_model_name`),
|
||||||
|
KEY `index_key_1` (`pictures_id`)
|
||||||
|
) ENGINE = MYISAM ;
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def drop_embedding_partition(self, clustering_type, clustering_version, clustering_model_type, clustering_model_name):
|
||||||
|
|
||||||
|
req = "DELETE FROM `iss`.`pictures_embedding` WHERE clustering_type = %s AND clustering_version = %s AND clustering_model_type = %s AND clustering_model_name = %s"
|
||||||
|
|
||||||
|
self.cursor.execute(req, (clustering_type, clustering_version, clustering_model_type, clustering_model_name))
|
||||||
|
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
return self.cursor.rowcount
|
||||||
|
|
||||||
|
|
||||||
|
def insert_row_pictures_embedding(self, array):
|
||||||
|
|
||||||
|
sql_insert_template = "INSERT INTO `iss`.`pictures_embedding` (pictures_id, pictures_x, pictures_y, clustering_type, clustering_version, clustering_model_type, clustering_model_name) VALUES (%s, %s, %s, %s, %s, %s, %s);"
|
||||||
|
|
||||||
self.cursor.executemany(sql_insert_template, array)
|
self.cursor.executemany(sql_insert_template, array)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
106
iss/exec/bdd.py
106
iss/exec/bdd.py
|
@ -3,28 +3,96 @@ import time
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from iss.init_config import CONFIG
|
from iss.init_config import CONFIG
|
||||||
from iss.data.DataBaseManager import DataBaseManager
|
from iss.data.DataBaseManager import MysqlDataBaseManager
|
||||||
|
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
|
||||||
CON_MYSQL = mysql.connector.connect(
|
from iss.tools import Tools
|
||||||
host = CONFIG.get('mysql')['database']['server'],
|
|
||||||
user = CONFIG.get('mysql')['database']['user'],
|
|
||||||
passwd = CONFIG.get('mysql')['database']['password'],
|
|
||||||
database = CONFIG.get('mysql')['database']['name'],
|
|
||||||
port = CONFIG.get('mysql')['database']['port']
|
|
||||||
)
|
|
||||||
|
|
||||||
dbm = DataBaseManager(CON_MYSQL, CONFIG)
|
|
||||||
|
|
||||||
|
|
||||||
history = pd.read_csv(os.path.join(CONFIG.get("directory")['data_dir'], "raw", "history", "history.txt"), sep=";", names=['latitude', 'longitude', 'id', 'location'])
|
def create_db_manager(config):
|
||||||
history['timestamp'] = pd.to_datetime(history.id, format="%Y%m%d-%H%M%S").dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
history.fillna('NULL', inplace=True)
|
|
||||||
history = history[['latitude', 'longitude', 'id', 'timestamp', 'location']]
|
|
||||||
history_tuple = [tuple(x) for x in history.values]
|
|
||||||
|
|
||||||
dbm.createPicturesTable(force=True)
|
CON_MYSQL = mysql.connector.connect(
|
||||||
count = dbm.insertRowPictures(history_tuple)
|
host = config.get('mysql')['database']['server'],
|
||||||
|
user = config.get('mysql')['database']['user'],
|
||||||
|
passwd = config.get('mysql')['database']['password'],
|
||||||
|
database = config.get('mysql')['database']['name'],
|
||||||
|
port = config.get('mysql')['database']['port']
|
||||||
|
)
|
||||||
|
|
||||||
print(count)
|
return MysqlDataBaseManager(CON_MYSQL, config)
|
||||||
|
|
||||||
|
|
||||||
|
def populate_locations(config, db_manager):
|
||||||
|
|
||||||
|
history = pd.read_csv(os.path.join(CONFIG.get("directory")['data_dir'], "raw", "history", "history.txt"), sep=";", names=['latitude', 'longitude', 'id', 'location'])
|
||||||
|
history['timestamp'] = pd.to_datetime(history.id, format="%Y%m%d-%H%M%S").dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
history.fillna('NULL', inplace=True)
|
||||||
|
history = history[['latitude', 'longitude', 'id', 'timestamp', 'location']]
|
||||||
|
history_tuple = [tuple(x) for x in history.values]
|
||||||
|
|
||||||
|
db_manager.create_pictures_location_table(force=True)
|
||||||
|
count = db_manager.insert_row_pictures_location(history_tuple)
|
||||||
|
|
||||||
|
print("Nombre d'insertion: %s" % count)
|
||||||
|
|
||||||
|
|
||||||
|
def populate_embedding(config, db_manager, clustering_type, clustering_version, clustering_model_type, clustering_model_name, drop=False):
|
||||||
|
|
||||||
|
db_manager.create_pictures_embedding_table()
|
||||||
|
clustering_config = config.get('clustering')[clustering_type]
|
||||||
|
clustering_config['version'] = clustering_version
|
||||||
|
clustering_config['model']['type'] = clustering_model_type
|
||||||
|
clustering_config['model']['name'] = clustering_model_name
|
||||||
|
|
||||||
|
if drop:
|
||||||
|
db_manager.drop_embedding_partition(clustering_type, clustering_version, clustering_model_type, clustering_model_name)
|
||||||
|
|
||||||
|
if clustering_type == 'n2d':
|
||||||
|
clustering = N2DClustering(clustering_config)
|
||||||
|
elif clustering_type == 'classical':
|
||||||
|
clustering = ClassicalClustering(clustering_config)
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
clustering.load()
|
||||||
|
model, model_config = Tools.load_model(CONFIG, clustering_model_type, clustering_model_name)
|
||||||
|
filenames = Tools.list_directory_filenames(CONFIG.get('directory')['collections'])
|
||||||
|
generator = Tools.load_latent_representation(CONFIG, model, model_config, filenames, 496, None, True)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for ids, latents in generator:
|
||||||
|
pictures_embedding = clustering.predict_embedding(latents)
|
||||||
|
rows = []
|
||||||
|
for i, id in enumerate(ids):
|
||||||
|
rows.append((
|
||||||
|
id,
|
||||||
|
float(np.round(pictures_embedding[i][0], 4)),
|
||||||
|
float(np.round(pictures_embedding[i][1], 4)),
|
||||||
|
clustering_type,
|
||||||
|
clustering_version,
|
||||||
|
clustering_model_type,
|
||||||
|
clustering_model_name
|
||||||
|
))
|
||||||
|
count += db_manager.insert_row_pictures_embedding(rows)
|
||||||
|
print("Nombre d'insertion: %s / %s" % (count, len(filenames)))
|
||||||
|
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def main(action = 'populate_embedding'):
|
||||||
|
|
||||||
|
db_manager = create_db_manager(CONFIG)
|
||||||
|
|
||||||
|
if action == 'population_locations':
|
||||||
|
populate_locations(CONFIG, db_manager)
|
||||||
|
elif action == 'populate_embedding':
|
||||||
|
populate_embedding(CONFIG, db_manager, 'n2d', 1, 'simple_conv', 'model_colab')
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -17,44 +17,6 @@ from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClusterin
|
||||||
_DEBUG = True
|
_DEBUG = True
|
||||||
|
|
||||||
|
|
||||||
def load_model(config, clustering_type):
|
|
||||||
"""
|
|
||||||
Load model according to config
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type = config.get('clustering')[clustering_type]['model']['type']
|
|
||||||
model_name = config.get('clustering')[clustering_type]['model']['name']
|
|
||||||
config.get('models')[model_type]['model_name'] = model_name
|
|
||||||
|
|
||||||
if model_type == 'simple_conv':
|
|
||||||
model = SimpleConvAutoEncoder(config.get('models')[model_type])
|
|
||||||
elif model_type == 'simple':
|
|
||||||
model = SimpleAutoEncoder(config.get('models')[model_type])
|
|
||||||
else:
|
|
||||||
raise Exception
|
|
||||||
|
|
||||||
model_config = config.get('models')[model_type]
|
|
||||||
|
|
||||||
return model, model_config
|
|
||||||
|
|
||||||
|
|
||||||
def load_images(config, clustering_type, model, model_config, batch_size, n_batch):
|
|
||||||
"""
|
|
||||||
load images and predictions
|
|
||||||
"""
|
|
||||||
model_type = config.get('clustering')[clustering_type]['model']['type']
|
|
||||||
filenames = Tools.list_directory_filenames(os.path.join(config.get('sampling')['autoencoder']['directory']['train']))
|
|
||||||
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
|
|
||||||
|
|
||||||
pictures_id, pictures_preds = Tools.encoded_pictures_from_generator(generator_imgs, model)
|
|
||||||
if model_type in ['simple_conv']:
|
|
||||||
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], model_config['latent_width']*model_config['latent_height']*model_config['latent_channel']))
|
|
||||||
else:
|
|
||||||
intermediate_output = pictures_preds
|
|
||||||
|
|
||||||
return pictures_id, intermediate_output
|
|
||||||
|
|
||||||
|
|
||||||
def run_clustering(config, clustering_type, pictures_id, intermediate_output):
|
def run_clustering(config, clustering_type, pictures_id, intermediate_output):
|
||||||
"""
|
"""
|
||||||
Apply clustering on images
|
Apply clustering on images
|
||||||
|
@ -217,9 +179,9 @@ def plot_mosaics(config, clustering_type, clustering, output_image_width, output
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
_CLUSTERING_TYPE = 'classical'
|
_CLUSTERING_TYPE = 'n2d'
|
||||||
_BATCH_SIZE = 496
|
_BATCH_SIZE = 496
|
||||||
_N_BATCH = 1
|
_N_BATCH = 10
|
||||||
_PLOTS = True
|
_PLOTS = True
|
||||||
_MOSAICS = True
|
_MOSAICS = True
|
||||||
_SILHOUETTE = True
|
_SILHOUETTE = True
|
||||||
|
@ -228,8 +190,9 @@ def main():
|
||||||
_MOSAIC_NROW = 10
|
_MOSAIC_NROW = 10
|
||||||
_MOSAIC_NCOL_MAX = 10
|
_MOSAIC_NCOL_MAX = 10
|
||||||
|
|
||||||
model, model_config = load_model(CONFIG, _CLUSTERING_TYPE)
|
model, model_config = Tools.load_model(CONFIG, CONFIG.get('clustering')[_CLUSTERING_TYPE]['model']['type'], CONFIG.get('clustering')[_CLUSTERING_TYPE]['model']['name'])
|
||||||
pictures_id, intermediate_output = load_images(CONFIG, _CLUSTERING_TYPE, model, model_config, _BATCH_SIZE, _N_BATCH)
|
filenames = Tools.list_directory_filenames(CONFIG.get('sampling')['autoencoder']['directory']['train'])
|
||||||
|
pictures_id, intermediate_output = Tools.load_latent_representation(CONFIG, model, model_config, filenames, _BATCH_SIZE, _N_BATCH, False)
|
||||||
|
|
||||||
clustering = run_clustering(CONFIG, _CLUSTERING_TYPE, pictures_id, intermediate_output)
|
clustering = run_clustering(CONFIG, _CLUSTERING_TYPE, pictures_id, intermediate_output)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import base64
|
import base64
|
||||||
|
@ -57,19 +58,30 @@ class Tools:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encoded_pictures_from_generator(generator, model):
|
def encoded_pictures_from_generator(generator, model, by_step=False):
|
||||||
|
if by_step:
|
||||||
|
return Tools.encoded_pictures_from_generator_by_step(generator, model)
|
||||||
|
|
||||||
predictions_list = []
|
predictions_list = []
|
||||||
predictions_id = []
|
predictions_id = []
|
||||||
for imgs in generator:
|
for imgs in generator:
|
||||||
predictions_id.append(imgs[0])
|
tmp_id = [os.path.splitext(os.path.basename(id))[0] for id in imgs[0]]
|
||||||
predictions_list.append(model.get_encoded_prediction(imgs[1]))
|
tmp_pred = model.get_encoded_prediction(imgs[1])
|
||||||
|
predictions_id += tmp_id
|
||||||
|
predictions_list.append(tmp_pred)
|
||||||
|
|
||||||
predictions = np.concatenate(tuple(predictions_list), axis = 0)
|
predictions = np.concatenate(tuple(predictions_list), axis = 0)
|
||||||
predictions_id = [os.path.splitext(os.path.basename(id))[0] for sub_id in predictions_id for id in sub_id]
|
|
||||||
|
|
||||||
return predictions_id, predictions
|
return predictions_id, predictions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encoded_pictures_from_generator_by_step(generator, model):
|
||||||
|
for imgs in generator:
|
||||||
|
# tmp_id = [os.path.splitext(os.path.basename(id))[0] for sub_id in imgs[0] for id in sub_id]
|
||||||
|
tmp_id = [os.path.splitext(os.path.basename(id))[0] for id in imgs[0]]
|
||||||
|
tmp_pred = model.get_encoded_prediction(imgs[1])
|
||||||
|
yield (tmp_id, tmp_pred)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_np_picture(path, target_size = None, scale = 1):
|
def read_np_picture(path, target_size = None, scale = 1):
|
||||||
# img = PIL.Image.open(filename)
|
# img = PIL.Image.open(filename)
|
||||||
|
@ -79,11 +91,12 @@ class Tools:
|
||||||
return img_np
|
return img_np
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_directory_filenames(path):
|
def list_directory_filenames(path, pattern = ".*jpg$"):
|
||||||
filenames = os.listdir(path)
|
filenames = os.listdir(path)
|
||||||
np.random.seed(33213)
|
np.random.seed(33213)
|
||||||
np.random.shuffle(filenames)
|
np.random.shuffle(filenames)
|
||||||
filenames = [os.path.join(path,f) for f in filenames]
|
pattern_regex = re.compile(pattern)
|
||||||
|
filenames = [os.path.join(path,f) for f in filenames if pattern_regex.match(f)]
|
||||||
|
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
@ -138,4 +151,47 @@ class Tools:
|
||||||
linkage_matrix = np.column_stack([children, distance, no_of_observations]).astype(float)
|
linkage_matrix = np.column_stack([children, distance, no_of_observations]).astype(float)
|
||||||
|
|
||||||
# Plot the corresponding dendrogram
|
# Plot the corresponding dendrogram
|
||||||
dendrogram(linkage_matrix, **kwargs)
|
dendrogram(linkage_matrix, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_model(config, model_type, model_name):
|
||||||
|
"""
|
||||||
|
Load model according to config
|
||||||
|
"""
|
||||||
|
from iss.models import SimpleConvAutoEncoder, SimpleAutoEncoder
|
||||||
|
|
||||||
|
config.get('models')[model_type]['model_name'] = model_name
|
||||||
|
|
||||||
|
if model_type == 'simple_conv':
|
||||||
|
model = SimpleConvAutoEncoder(config.get('models')[model_type])
|
||||||
|
elif model_type == 'simple':
|
||||||
|
model = SimpleAutoEncoder(config.get('models')[model_type])
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
model_config = config.get('models')[model_type]
|
||||||
|
|
||||||
|
return model, model_config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_latent_representation(config, model, model_config, filenames, batch_size, n_batch, by_step):
|
||||||
|
"""
|
||||||
|
load images and predictions
|
||||||
|
"""
|
||||||
|
if by_step:
|
||||||
|
return Tools.load_latent_representation_by_step(config, model, model_config, filenames, batch_size, n_batch)
|
||||||
|
|
||||||
|
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
|
||||||
|
|
||||||
|
pictures_id, pictures_preds = Tools.encoded_pictures_from_generator(generator_imgs, model, by_step)
|
||||||
|
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], -1))
|
||||||
|
|
||||||
|
return pictures_id, intermediate_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_latent_representation_by_step(config, model, model_config, filenames, batch_size, n_batch):
|
||||||
|
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
|
||||||
|
|
||||||
|
for pictures_id, pictures_preds in Tools.encoded_pictures_from_generator(generator_imgs, model, True):
|
||||||
|
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], -1))
|
||||||
|
yield pictures_id, intermediate_output
|
||||||
|
|
Loading…
Reference in a new issue