tests clustering + mosaic: failure + test facets

This commit is contained in:
Francois Vieille 2019-12-12 01:08:04 +01:00
parent ac1f75d28e
commit 4a95eb9d92
10 changed files with 172 additions and 35 deletions

Show file

@@ -69,6 +69,15 @@ exec_clustering:
$(PYTHON_INTERPRETER) -m iss.exec.clustering
#################################################################################
# OUTSIDE CONTAINER #
#################################################################################
maximize_test:
cp $(PROJECT_DIR)/data/raw/collections/20180211-130001.jpg $(PROJECT_DIR)/data/isr/input/sample/
docker run -v "$(PROJECT_DIR)/data/isr:/home/isr/data" -v "$(PROJECT_DIR)/../image-super-resolution/weights:/home/isr/weights" -v "$(PROJECT_DIR)/config/config_isr.yml:/home/isr/config.yml" -it isr -d -p -c config.yml
#################################################################################
# FLOYDHUB #
#################################################################################

Show file

@@ -31,6 +31,9 @@ class AbstractClustering:
def predict_embedding(self):
raise NotImplementedError
def predict_label(self):
raise NotImplementedError
def save(self):
raise NotImplementedError

Show file

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
import os
import numpy as np
import umap
import hdbscan
from iss.tools import Tools
from iss.clustering import AbstractClustering
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples
from sklearn.externals import joblib
class DBScanClustering(AbstractClustering):
"""
Cf: https://umap-learn.readthedocs.io/en/latest/clustering.html
"""
def __init__(self, config, pictures_id = None, pictures_np = None):
super().__init__(config, pictures_id, pictures_np)
self.umap_args = self.config['umap']
self.umap_fit = None
self.umap_embedding = None
self.umap_save_name = 'UMAP_model.pkl'
self.dbscan_fit = None
self.dbscan_args = self.config['dbscan']
self.dbscan_labels = None
self.dbscan_centers = []
self.dbscan_save_name = "dbscan_model.pkl"
def compute_umap(self):
self.umap_fit = umap.UMAP(**self.umap_args)
self.umap_embedding = self.umap_fit.fit_transform(self.pictures_np)
return self
def compute_dbscan(self):
self.dbscan_fit = hdbscan.HDBSCAN(**self.dbscan_args)
self.dbscan_fit.fit(self.umap_embedding)
self.dbscan_labels = self.dbscan_fit.labels_
return self
def compute_final_labels(self):
self.final_labels = self.dbscan_labels
return self
def compute_silhouette_score(self):
self.silhouette_score = silhouette_samples(self.pictures_np, self.final_labels)
self.silhouette_score_labels = {cluster: np.mean(self.silhouette_score[self.final_labels == cluster]) for
cluster in np.unique(self.final_labels)}
return self.silhouette_score_labels
def predict_embedding(self, pictures_np):
return self.umap_fit.transform(pictures_np)
def save(self):
Tools.create_dir_if_not_exists(self.save_directory)
joblib.dump(self.umap_fit, os.path.join(self.save_directory, self.umap_save_name))
joblib.dump(self.dbscan_fit, os.path.join(self.save_directory, self.dbscan_save_name))
def load(self):
self.umap_fit = joblib.load(os.path.join(self.save_directory, self.umap_save_name))
self.dbscan_fit = joblib.load(os.path.join(self.save_directory, self.dbscan_save_name))
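As a point of reference, a minimal usage sketch of the new DBScanClustering (names taken from this commit; the random latent vectors, their dimensionality, and the assumption that the config provides a clustering -> dbscan section with umap/dbscan keyword arguments are illustrative, not part of the change):

import numpy as np
from iss.init_config import CONFIG
from iss.clustering import DBScanClustering

# Stand-in ids and latent vectors; in the real pipeline these come from the
# autoencoder via Tools.load_latent_representation (see script_cluster.py below).
pictures_id = ["pic_%05d" % i for i in range(500)]
intermediate_output = np.random.rand(500, 64).astype(np.float32)

clustering = DBScanClustering(CONFIG.get('clustering')['dbscan'], pictures_id, intermediate_output)
clustering.compute_umap()           # fit UMAP on the latent vectors
clustering.compute_dbscan()         # run HDBSCAN on the UMAP embedding
clustering.compute_final_labels()   # final labels = HDBSCAN labels (-1 marks noise)
if len(np.unique(clustering.final_labels)) > 1:   # silhouette needs at least 2 labels
    print(clustering.compute_silhouette_score())  # mean silhouette per label
clustering.save()                   # joblib-dumps the UMAP and HDBSCAN models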

Show file

@@ -53,6 +53,9 @@ class N2DClustering(AbstractClustering):
def predict_embedding(self, pictures_np):
return self.umap_fit.transform(pictures_np)
def predict_label(self, pictures_embedding):
return self.kmeans_fit.predict(pictures_embedding)
def save(self):
Tools.create_dir_if_not_exists(self.save_directory)

Show file

@@ -1,4 +1,5 @@
from .AbstractClustering import AbstractClustering
from .ClassicalClustering import ClassicalClustering
from .AdvancedClustering import AdvancedClustering
from .N2DClustering import N2DClustering
from .DBScanClustering import DBScanClustering

Show file

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import numpy as np
class MysqlDataBaseManager:
@@ -45,6 +46,7 @@ CREATE TABLE IF NOT EXISTS `iss`.`pictures_embedding` (
`pictures_id` VARCHAR( 15 ) ,
`pictures_x` FLOAT(8, 4),
`pictures_y` FLOAT(8, 4),
`label` INT NULL,
`clustering_type` VARCHAR(15),
`clustering_version` VARCHAR(5),
`clustering_model_type` VARCHAR(15),
@@ -68,9 +70,17 @@ CREATE TABLE IF NOT EXISTS `iss`.`pictures_embedding` (
def insert_row_pictures_embedding(self, array):
sql_insert_template = "INSERT INTO `iss`.`pictures_embedding` (pictures_id, pictures_x, pictures_y, clustering_type, clustering_version, clustering_model_type, clustering_model_name) VALUES (%s, %s, %s, %s, %s, %s, %s);"
sql_insert_template = "INSERT INTO `iss`.`pictures_embedding` (pictures_id, pictures_x, pictures_y, label, clustering_type, clustering_version, clustering_model_type, clustering_model_name) VALUES (%s, %s, %s, %s, %s, %s, %s, %s);"
self.cursor.executemany(sql_insert_template, array)
self.conn.commit()
return self.cursor.rowcount
def select_close_embedding(self, x, y, limit):
sql_req = "SELECT pictures_id, SQRT(POWER(pictures_x - %s, 2) + POWER(pictures_y - %s, 2)) as distance FROM iss.pictures_embedding ORDER BY distance ASC LIMIT %s"
self.cursor.execute(sql_req, (float(np.round(x, 4)), float(np.round(y, 4)), limit))
return self.cursor.fetchall()
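A hedged sketch of the two MysqlDataBaseManager changes above: inserted rows now carry the extra label column (eight values per tuple, matching the new INSERT template), and select_close_embedding returns the stored embeddings nearest to a query point. The picture id and coordinates below are purely illustrative, and the pictures_embedding table is assumed to exist already:

from iss.init_config import CONFIG
from iss.tools import Tools

db_manager = Tools.create_db_manager(CONFIG)   # helper added to Tools in this commit
db_manager.insert_row_pictures_embedding([
    ('20180211-130001', 1.2345, -0.5678, 3, 'n2d', '1', 'simple_conv', 'model_colab'),
])
# (pictures_id, distance) pairs for the 10 embeddings closest to (1.2, -0.5)
closest = db_manager.select_close_embedding(1.2, -0.5, 10)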

Show file

@@ -6,24 +6,10 @@ import datetime as dt
import numpy as np
from iss.init_config import CONFIG
from iss.data.DataBaseManager import MysqlDataBaseManager
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering, DBScanClustering
from iss.tools import Tools
def create_db_manager(config):
CON_MYSQL = mysql.connector.connect(
host = config.get('mysql')['database']['server'],
user = config.get('mysql')['database']['user'],
passwd = config.get('mysql')['database']['password'],
database = config.get('mysql')['database']['name'],
port = config.get('mysql')['database']['port']
)
return MysqlDataBaseManager(CON_MYSQL, config)
def populate_locations(config, db_manager):
history = pd.read_csv(os.path.join(CONFIG.get("directory")['data_dir'], "raw", "history", "history.txt"), sep=";", names=['latitude', 'longitude', 'id', 'location'])
@@ -40,11 +26,7 @@ def populate_locations(config, db_manager):
def populate_embedding(config, db_manager, clustering_type, clustering_version, clustering_model_type, clustering_model_name, drop=False):
db_manager.create_pictures_embedding_table()
clustering_config = config.get('clustering')[clustering_type]
clustering_config['version'] = clustering_version
clustering_config['model']['type'] = clustering_model_type
clustering_config['model']['name'] = clustering_model_name
clustering, clustering_config = Tools.load_clustering(CONFIG, clustering_type, clustering_version, clustering_model_type, clustering_model_name)
if drop:
db_manager.drop_embedding_partition(clustering_type, clustering_version, clustering_model_type, clustering_model_name)
@@ -53,6 +35,8 @@ def populate_embedding(config, db_manager, clustering_type, clustering_version,
clustering = N2DClustering(clustering_config)
elif clustering_type == 'classical':
clustering = ClassicalClustering(clustering_config)
elif clustering_type == 'dbscan':
clustering = DBScanClustering(clustering_config)
else:
raise Exception
@@ -64,12 +48,14 @@ def populate_embedding(config, db_manager, clustering_type, clustering_version,
count = 0
for ids, latents in generator:
pictures_embedding = clustering.predict_embedding(latents)
pictures_label = clustering.predict_label(pictures_embedding)
rows = []
for i, id in enumerate(ids):
rows.append((
id,
float(np.round(pictures_embedding[i][0], 4)),
float(np.round(pictures_embedding[i][1], 4)),
int(pictures_label[i]),
clustering_type,
clustering_version,
clustering_model_type,
@@ -84,12 +70,24 @@ def populate_embedding(config, db_manager, clustering_type, clustering_version,
def main(action = 'populate_embedding'):
db_manager = create_db_manager(CONFIG)
db_manager = Tools.create_db_manager(CONFIG)
if action == 'population_locations':
populate_locations(CONFIG, db_manager)
elif action == 'populate_embedding':
populate_embedding(CONFIG, db_manager, 'n2d', 1, 'simple_conv', 'model_colab')
db_manager.create_pictures_embedding_table(False)
to_load = [
{'clustering_type': 'n2d', 'clustering_version': 1, 'clustering_model_type': 'simple_conv', 'clustering_model_name': 'model_colab', 'drop': False},
{'clustering_type': 'n2d', 'clustering_version': 2, 'clustering_model_type': 'simple_conv', 'clustering_model_name': 'model_colab', 'drop': False},
{'clustering_type': 'n2d', 'clustering_version': 3, 'clustering_model_type': 'simple_conv', 'clustering_model_name': 'model_colab', 'drop': False},
]
for kwargs in to_load:
try:
populate_embedding(CONFIG, db_manager, **kwargs)
except Exception as err:
print(err)
pass
else:
pass
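Illustrative only: now that populate_embedding handles the dbscan type, an entry like the one below could be appended to to_load once a DBSCAN clustering has been trained and saved (the version and model names here are hypothetical, reusing the ones already present in the list):

to_load.append({
    'clustering_type': 'dbscan', 'clustering_version': 1,
    'clustering_model_type': 'simple_conv',
    'clustering_model_name': 'model_colab', 'drop': False,
})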

Show file

@@ -11,7 +11,7 @@ from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from iss.init_config import CONFIG
from iss.tools import Tools
from iss.models import SimpleConvAutoEncoder, SimpleAutoEncoder
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering, DBScanClustering
_DEBUG = True
@@ -45,6 +45,14 @@ def run_clustering(config, clustering_type, pictures_id, intermediate_output):
clustering.compute_kmeans()
clustering.compute_final_labels()
clustering.compute_colors()
elif clustering_type == 'dbscan':
if _DEBUG:
print("HDBSCAN Clustering")
clustering = DBScanClustering(config.get('clustering')['dbscan'], pictures_id, intermediate_output)
clustering.compute_umap()
clustering.compute_dbscan()
clustering.compute_final_labels()
clustering.compute_colors()
return clustering
@@ -71,7 +79,7 @@ def run_plots(config, clustering_type, clustering):
ax.add_artist(legend1)
plt.savefig(os.path.join(clustering.save_directory, 'tsne_clusters.png'))
if clustering_type in ['n2d']:
if clustering_type in ['n2d', 'dbscan']:
## Graphs of TSNE and final clusters
fig, ax = plt.subplots(figsize=(24, 14))
classes = clustering.final_labels
@@ -80,12 +88,12 @@ def run_plots(config, clustering_type, clustering):
ax.add_artist(legend1)
plt.savefig(os.path.join(clustering.save_directory, 'umap_clusters.png'))
if clustering_type in ['n2d', 'classical']:
if clustering_type in ['n2d', 'classical', 'dbscan']:
filenames = [os.path.join(config.get('directory')['collections'], "%s.jpg" % one_res[0]) for one_res in clustering.get_results()]
images_array = [Tools.read_np_picture(img_filename, target_size = (54, 96)) for img_filename in filenames]
base64_images = [Tools.base64_image(img) for img in images_array]
if clustering_type == 'n2d':
if clustering_type in ['n2d', 'dbscan']:
x = clustering.umap_embedding[:, 0]
y = clustering.umap_embedding[:, 1]
html_file = 'umap_bokeh.html'
@@ -181,7 +189,7 @@ def plot_mosaics(config, clustering_type, clustering, output_image_width, output
def main():
_CLUSTERING_TYPE = 'n2d'
_BATCH_SIZE = 496
_N_BATCH = 10
_N_BATCH = 5
_PLOTS = True
_MOSAICS = True
_SILHOUETTE = True
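A short sketch of how the new branch is exercised end to end, assuming pictures_id and intermediate_output have already been computed as in main(); setting _CLUSTERING_TYPE = 'dbscan' is presumably the intended switch, the direct calls below just make the flow explicit:

# pictures_id / intermediate_output are assumed to come from the autoencoder step in main().
clustering = run_clustering(CONFIG, 'dbscan', pictures_id, intermediate_output)
run_plots(CONFIG, 'dbscan', clustering)   # 'dbscan' now also produces the umap_clusters / umap_bokeh outputs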

Show file

@@ -4,12 +4,15 @@ import PIL
import os
import re
import numpy as np
import mysql.connector
from io import BytesIO
import base64
from scipy.cluster.hierarchy import dendrogram
from keras_preprocessing.image.utils import load_img
import matplotlib as plt
from iss.data.DataBaseManager import MysqlDataBaseManager
class Tools:
@@ -172,16 +175,35 @@ class Tools:
model_config = config.get('models')[model_type]
return model, model_config
@staticmethod
def load_clustering(config, clustering_type, clustering_version, clustering_model_type, clustering_model_name):
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
clustering_config = config.get('clustering')[clustering_type]
clustering_config['version'] = clustering_version
clustering_config['model']['type'] = clustering_model_type
clustering_config['model']['name'] = clustering_model_name
if clustering_type == 'n2d':
clustering = N2DClustering(clustering_config)
elif clustering_type == 'classical':
clustering = ClassicalClustering(clustering_config)
else:
raise Exception
clustering.load()
return clustering, clustering_config
@staticmethod
def load_latent_representation(config, model, model_config, filenames, batch_size, n_batch, by_step):
def load_latent_representation(config, model, model_config, filenames, batch_size, n_batch, by_step, scale=1./255):
"""
load images and predictions
"""
if by_step:
return Tools.load_latent_representation_by_step(config, model, model_config, filenames, batch_size, n_batch)
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), scale=scale, batch = batch_size, nb_batch = n_batch)
pictures_id, pictures_preds = Tools.encoded_pictures_from_generator(generator_imgs, model, by_step)
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], -1))
@@ -189,9 +211,22 @@ class Tools:
return pictures_id, intermediate_output
@staticmethod
def load_latent_representation_by_step(config, model, model_config, filenames, batch_size, n_batch):
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
def load_latent_representation_by_step(config, model, model_config, filenames, batch_size, n_batch, scale=1./255):
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), scale=scale, batch = batch_size, nb_batch = n_batch)
for pictures_id, pictures_preds in Tools.encoded_pictures_from_generator(generator_imgs, model, True):
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], -1))
yield pictures_id, intermediate_output
@staticmethod
def create_db_manager(config):
CON_MYSQL = mysql.connector.connect(
host = config.get('mysql')['database']['server'],
user = config.get('mysql')['database']['user'],
passwd = config.get('mysql')['database']['password'],
database = config.get('mysql')['database']['name'],
port = config.get('mysql')['database']['port']
)
return MysqlDataBaseManager(CON_MYSQL, config)
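For illustration, the two Tools helpers touched above as they are used elsewhere in this commit (the arguments are the ones already appearing in the populate script; note load_clustering currently handles only the n2d and classical types):

from iss.init_config import CONFIG
from iss.tools import Tools

# Rebuild a saved clustering from its config partition.
clustering, clustering_config = Tools.load_clustering(
    CONFIG, 'n2d', 1, 'simple_conv', 'model_colab')

# MySQL connection handling, moved out of the exec script into Tools.
db_manager = Tools.create_db_manager(CONFIG)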

Show file

@@ -1,6 +1,7 @@
setuptools==40.8.0
Click==7.0
numpy==1.13.3
# numpy==1.13.3
numpy==1.17.4
pandas==0.23.4
tensorflow==1.12.0
Keras==2.2.4
@@ -10,4 +11,7 @@ python-dotenv==0.10.1
PyYAML==3.13
matplotlib>=3.1.0
umap-learn==0.3.10
bokeh==0.13.0
mysql-connector-python==8.0.18
hdbscan==0.8.24
facets_overview==1.0.0
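Optional sanity check for the dependencies added above, purely illustrative (the import names are assumed to be the usual module names these distributions install):

import numpy                # 1.17.4 replaces the old 1.13.3 pin
import hdbscan              # hdbscan==0.8.24
import mysql.connector      # mysql-connector-python==8.0.18
import facets_overview      # facets_overview==1.0.0
print(numpy.__version__, hdbscan.__version__)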