mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-05 23:23:11 +02:00
prepare clustering industrialisation
This commit is contained in:
parent
8592ee01ab
commit
45dbfd8db7
|
@ -1,4 +1,5 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
from iss.tools import Tools
|
from iss.tools import Tools
|
||||||
|
|
||||||
class AbstractClustering:
|
class AbstractClustering:
|
||||||
|
@ -6,13 +7,13 @@ class AbstractClustering:
|
||||||
def __init__(self, config, pictures_id, pictures_np):
|
def __init__(self, config, pictures_id, pictures_np):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.save_directory = os.path.join(self.config['save_directory'], '%s_%s_%s' % (self.config['model']['type'], self.config['model']['name'], self.config['version']))
|
||||||
self.pictures_id = pictures_id
|
self.pictures_id = pictures_id
|
||||||
self.pictures_np = pictures_np
|
self.pictures_np = pictures_np
|
||||||
self.final_labels = None
|
self.final_labels = None
|
||||||
self.colors = None
|
self.colors = None
|
||||||
|
|
||||||
if self.config['save_directory']:
|
Tools.create_dir_if_not_exists(self.save_directory)
|
||||||
Tools.create_dir_if_not_exists(self.config['save_directory'])
|
|
||||||
|
|
||||||
def compute_final_labels(self):
|
def compute_final_labels(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -20,19 +20,19 @@ class ClassicalClustering(AbstractClustering):
|
||||||
self.pca_fit = None
|
self.pca_fit = None
|
||||||
self.pca_args = self.config['PCA']
|
self.pca_args = self.config['PCA']
|
||||||
self.pca_reduction = None
|
self.pca_reduction = None
|
||||||
self.pca_save_name = "PCA_model_v%s.pkl" % (self.config['version'])
|
self.pca_save_name = "PCA_model.pkl"
|
||||||
|
|
||||||
self.kmeans_fit = None
|
self.kmeans_fit = None
|
||||||
self.kmeans_args = self.config['kmeans']
|
self.kmeans_args = self.config['kmeans']
|
||||||
self.kmeans_labels = None
|
self.kmeans_labels = None
|
||||||
self.kmeans_centers = []
|
self.kmeans_centers = []
|
||||||
self.kmeans_save_name = "kmeans_model_v%s.pkl" % (self.config['version'])
|
self.kmeans_save_name = "kmeans_model.pkl"
|
||||||
|
|
||||||
|
|
||||||
self.cah_fit = None
|
self.cah_fit = None
|
||||||
self.cah_args = self.config['CAH']
|
self.cah_args = self.config['CAH']
|
||||||
self.cah_labels = None
|
self.cah_labels = None
|
||||||
self.cah_save_name = "cah_model_v%s.pkl" % (self.config['version'])
|
self.cah_save_name = "cah_model.pkl"
|
||||||
|
|
||||||
self.tsne_fit = None
|
self.tsne_fit = None
|
||||||
self.tsne_args = self.config['TSNE']
|
self.tsne_args = self.config['TSNE']
|
||||||
|
@ -88,15 +88,15 @@ class ClassicalClustering(AbstractClustering):
|
||||||
|
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
Tools.create_dir_if_not_exists(self.config['save_directory'])
|
Tools.create_dir_if_not_exists(self.save_directory)
|
||||||
|
|
||||||
joblib.dump(self.pca_fit, os.path.join(self.config['save_directory'], self.pca_save_name))
|
joblib.dump(self.pca_fit, os.path.join(self.save_directory, self.pca_save_name))
|
||||||
joblib.dump(self.kmeans_fit, os.path.join(self.config['save_directory'], self.kmeans_save_name))
|
joblib.dump(self.kmeans_fit, os.path.join(self.save_directory, self.kmeans_save_name))
|
||||||
joblib.dump(self.cah_fit, os.path.join(self.config['save_directory'], self.cah_save_name))
|
joblib.dump(self.cah_fit, os.path.join(self.save_directory, self.cah_save_name))
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
self.pca_fit = joblib.load(os.path.join(self.config['save_directory'], self.pca_save_name))
|
self.pca_fit = joblib.load(os.path.join(self.save_directory, self.pca_save_name))
|
||||||
self.kmeans_fit = joblib.load(os.path.join(self.config['save_directory'], self.kmeans_save_name))
|
self.kmeans_fit = joblib.load(os.path.join(self.save_directory, self.kmeans_save_name))
|
||||||
self.cah_fit = joblib.load(os.path.join(self.config['save_directory'], self.cah_save_name))
|
self.cah_fit = joblib.load(os.path.join(self.save_directory, self.cah_save_name))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,12 +21,13 @@ class N2DClustering(AbstractClustering):
|
||||||
self.umap_args = self.config['umap']
|
self.umap_args = self.config['umap']
|
||||||
self.umap_fit = None
|
self.umap_fit = None
|
||||||
self.umap_embedding = None
|
self.umap_embedding = None
|
||||||
|
self.umap_save_name = 'UMAP_model.pkl'
|
||||||
|
|
||||||
self.kmeans_fit = None
|
self.kmeans_fit = None
|
||||||
self.kmeans_args = self.config['kmeans']
|
self.kmeans_args = self.config['kmeans']
|
||||||
self.kmeans_labels = None
|
self.kmeans_labels = None
|
||||||
self.kmeans_centers = []
|
self.kmeans_centers = []
|
||||||
self.kmeans_save_name = "kmeans_model_v%s.pkl" % (self.config['version'])
|
self.kmeans_save_name = "kmeans_model.pkl"
|
||||||
|
|
||||||
|
|
||||||
def compute_umap(self):
|
def compute_umap(self):
|
||||||
|
@ -50,3 +51,12 @@ class N2DClustering(AbstractClustering):
|
||||||
cluster in np.unique(self.final_labels)}
|
cluster in np.unique(self.final_labels)}
|
||||||
return self.silhouette_score_labels
|
return self.silhouette_score_labels
|
||||||
|
|
||||||
|
def save(self):
    """Persist the fitted UMAP and k-means models into ``self.save_directory``.

    The directory is created first if it does not exist; each model is then
    written with joblib under its configured file name.
    """
    target_dir = self.save_directory
    Tools.create_dir_if_not_exists(target_dir)

    # (fitted object, file name) pairs, dumped in the same order as before.
    artifacts = (
        (self.umap_fit, self.umap_save_name),
        (self.kmeans_fit, self.kmeans_save_name),
    )
    for fitted_model, file_name in artifacts:
        joblib.dump(fitted_model, os.path.join(target_dir, file_name))
|
||||||
|
|
||||||
|
def load(self):
    """Restore the fitted UMAP and k-means models from ``self.save_directory``.

    Reads back the files written by ``save``: ``self.umap_save_name`` for the
    UMAP model and ``self.kmeans_save_name`` for the k-means model.
    """
    # The UMAP model must be read from the same file name that save() writes
    # (umap_save_name); the previous pca_save_name lookup did not match it.
    self.umap_fit = joblib.load(os.path.join(self.save_directory, self.umap_save_name))
    self.kmeans_fit = joblib.load(os.path.join(self.save_directory, self.kmeans_save_name))
|
|
@ -10,177 +10,240 @@ from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
|
||||||
|
|
||||||
from iss.init_config import CONFIG
|
from iss.init_config import CONFIG
|
||||||
from iss.tools import Tools
|
from iss.tools import Tools
|
||||||
from iss.models import SimpleConvAutoEncoder
|
from iss.models import SimpleConvAutoEncoder, SimpleAutoEncoder
|
||||||
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
|
from iss.clustering import ClassicalClustering, AdvancedClustering, N2DClustering
|
||||||
|
|
||||||
## variable globales
|
|
||||||
|
|
||||||
_MODEL_TYPE = 'simple_conv'
|
|
||||||
_MODEL_NAME = 'model_colab'
|
|
||||||
_BATCH_SIZE = 496
|
|
||||||
_N_BATCH = 10
|
|
||||||
_DEBUG = True
|
_DEBUG = True
|
||||||
_CLUSTERING_TYPE = 'n2d'
|
|
||||||
_OUTPUT_IMAGE_WIDTH = 96
|
|
||||||
_OUTPUT_IMAGE_HEIGHT = 54
|
|
||||||
_MOSAIC_NROW = 10
|
|
||||||
_MOSAIC_NCOL_MAX = 10
|
|
||||||
|
|
||||||
|
|
||||||
## Charger le modèle
|
def load_model(config, clustering_type):
|
||||||
CONFIG.get('models')[_MODEL_TYPE]['model_name'] = _MODEL_NAME
|
"""
|
||||||
model = SimpleConvAutoEncoder(CONFIG.get('models')[_MODEL_TYPE])
|
Load model according to config
|
||||||
model_config = CONFIG.get('models')[_MODEL_TYPE]
|
"""
|
||||||
|
|
||||||
## Charger les images
|
model_type = config.get('clustering')[clustering_type]['model']['type']
|
||||||
filenames = Tools.list_directory_filenames(os.path.join(CONFIG.get('directory')['autoencoder']['train']))
|
model_name = config.get('clustering')[clustering_type]['model']['name']
|
||||||
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = _BATCH_SIZE, nb_batch = _N_BATCH)
|
config.get('models')[model_type]['model_name'] = model_name
|
||||||
|
|
||||||
pictures_id, pictures_preds = Tools.encoded_pictures_from_generator(generator_imgs, model)
|
if model_type == 'simple_conv':
|
||||||
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], model_config['latent_width']*model_config['latent_height']*model_config['latent_channel']))
|
model = SimpleConvAutoEncoder(config.get('models')[model_type])
|
||||||
|
elif model_type == 'simple':
|
||||||
|
model = SimpleAutoEncoder(config.get('models')[model_type])
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
model_config = config.get('models')[model_type]
|
||||||
|
|
||||||
|
return model, model_config
|
||||||
|
|
||||||
|
|
||||||
if _DEBUG:
|
def load_images(config, clustering_type, model, model_config, batch_size, n_batch):
|
||||||
for i, p_id in enumerate(pictures_id[:2]):
|
"""
|
||||||
print("%s: %s" % (p_id, pictures_preds[i]))
|
load images and predictions
|
||||||
print(len(pictures_id))
|
"""
|
||||||
print(len(intermediate_output))
|
model_type = config.get('clustering')[clustering_type]['model']['type']
|
||||||
|
filenames = Tools.list_directory_filenames(os.path.join(config.get('sampling')['autoencoder']['directory']['train']))
|
||||||
|
generator_imgs = Tools.generator_np_picture_from_filenames(filenames, target_size = (model_config['input_height'], model_config['input_width']), batch = batch_size, nb_batch = n_batch)
|
||||||
|
|
||||||
|
pictures_id, pictures_preds = Tools.encoded_pictures_from_generator(generator_imgs, model)
|
||||||
|
if model_type in ['simple_conv']:
|
||||||
|
intermediate_output = pictures_preds.reshape((pictures_preds.shape[0], model_config['latent_width']*model_config['latent_height']*model_config['latent_channel']))
|
||||||
|
else:
|
||||||
|
intermediate_output = pictures_preds
|
||||||
|
|
||||||
|
return pictures_id, intermediate_output
|
||||||
|
|
||||||
|
|
||||||
## Clustering
|
def run_clustering(config, clustering_type, pictures_id, intermediate_output):
|
||||||
if _CLUSTERING_TYPE == 'classical':
|
"""
|
||||||
if _DEBUG:
|
Apply clustering on images
|
||||||
print("Classical Clustering")
|
"""
|
||||||
clustering = ClassicalClustering(CONFIG.get('clustering')['classical'], pictures_id, intermediate_output)
|
|
||||||
clustering.compute_pca()
|
|
||||||
clustering.compute_kmeans()
|
|
||||||
clustering.compute_kmeans_centers()
|
|
||||||
clustering.compute_cah()
|
|
||||||
clustering.compute_final_labels()
|
|
||||||
clustering.compute_tsne()
|
|
||||||
clustering.compute_colors()
|
|
||||||
elif _CLUSTERING_TYPE == 'advanced':
|
|
||||||
if _DEBUG:
|
|
||||||
print("Advanced Clustering")
|
|
||||||
clustering = AdvancedClustering(CONFIG.get('clustering')['classical'], pictures_id, intermediate_output)
|
|
||||||
elif _CLUSTERING_TYPE == 'n2d':
|
|
||||||
if _DEBUG:
|
|
||||||
print("Not2Deep Clustering")
|
|
||||||
clustering = N2DClustering(CONFIG.get('clustering')['n2d'], pictures_id, intermediate_output)
|
|
||||||
clustering.compute_umap()
|
|
||||||
clustering.compute_kmeans()
|
|
||||||
clustering.compute_final_labels()
|
|
||||||
clustering.compute_colors()
|
|
||||||
|
|
||||||
silhouettes = clustering.compute_silhouette_score()
|
if clustering_type == 'classical':
|
||||||
clustering_res = clustering.get_results()
|
if _DEBUG:
|
||||||
|
print("Classical Clustering")
|
||||||
|
clustering = ClassicalClustering(config.get('clustering')['classical'], pictures_id, intermediate_output)
|
||||||
|
clustering.compute_pca()
|
||||||
|
clustering.compute_kmeans()
|
||||||
|
clustering.compute_kmeans_centers()
|
||||||
|
clustering.compute_cah()
|
||||||
|
clustering.compute_final_labels()
|
||||||
|
clustering.compute_tsne()
|
||||||
|
clustering.compute_colors()
|
||||||
|
elif clustering_type == 'advanced':
|
||||||
|
if _DEBUG:
|
||||||
|
print("Advanced Clustering")
|
||||||
|
clustering = AdvancedClustering(config.get('clustering')['classical'], pictures_id, intermediate_output)
|
||||||
|
elif clustering_type == 'n2d':
|
||||||
|
if _DEBUG:
|
||||||
|
print("Not2Deep Clustering")
|
||||||
|
clustering = N2DClustering(config.get('clustering')['n2d'], pictures_id, intermediate_output)
|
||||||
|
clustering.compute_umap()
|
||||||
|
clustering.compute_kmeans()
|
||||||
|
clustering.compute_final_labels()
|
||||||
|
clustering.compute_colors()
|
||||||
|
|
||||||
if _DEBUG:
|
return clustering
|
||||||
print(clustering_res[:2])
|
|
||||||
print(silhouettes)
|
|
||||||
|
|
||||||
|
|
||||||
if _CLUSTERING_TYPE in ['classical']:
|
def run_plots(config, clustering_type, clustering):
|
||||||
## Graphs of PCA and final clusters
|
"""
|
||||||
fig, ax = plt.subplots(figsize=(24, 14))
|
Plots specifics graphs
|
||||||
scatter = ax.scatter(clustering.pca_reduction[:, 0], clustering.pca_reduction[:, 1], c = clustering.colors)
|
"""
|
||||||
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
|
||||||
ax.add_artist(legend1)
|
|
||||||
plt.savefig(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'pca_clusters.png'))
|
|
||||||
|
|
||||||
if _CLUSTERING_TYPE in ['classical']:
|
if clustering_type in ['classical']:
|
||||||
## Graphs of TSNE and final clusters
|
## Graphs of PCA and final clusters
|
||||||
fig, ax = plt.subplots(figsize=(24, 14))
|
fig, ax = plt.subplots(figsize=(24, 14))
|
||||||
classes = clustering.final_labels
|
scatter = ax.scatter(clustering.pca_reduction[:, 0], clustering.pca_reduction[:, 1], c = clustering.colors)
|
||||||
scatter = ax.scatter(clustering.tsne_embedding[:, 0], clustering.tsne_embedding[:, 1], c = clustering.colors)
|
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
||||||
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
ax.add_artist(legend1)
|
||||||
ax.add_artist(legend1)
|
plt.savefig(os.path.join(clustering.save_directory, 'pca_clusters.png'))
|
||||||
plt.savefig(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'tsne_clusters.png'))
|
|
||||||
|
|
||||||
if _CLUSTERING_TYPE in ['n2d']:
|
if clustering_type in ['classical']:
|
||||||
## Graphs of TSNE and final clusters
|
## Graphs of TSNE and final clusters
|
||||||
fig, ax = plt.subplots(figsize=(24, 14))
|
fig, ax = plt.subplots(figsize=(24, 14))
|
||||||
classes = clustering.final_labels
|
classes = clustering.final_labels
|
||||||
scatter = ax.scatter(clustering.umap_embedding[:, 0], clustering.umap_embedding[:, 1], c = clustering.colors)
|
scatter = ax.scatter(clustering.tsne_embedding[:, 0], clustering.tsne_embedding[:, 1], c = clustering.colors)
|
||||||
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
||||||
ax.add_artist(legend1)
|
ax.add_artist(legend1)
|
||||||
plt.savefig(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'umap_clusters.png'))
|
plt.savefig(os.path.join(clustering.save_directory, 'tsne_clusters.png'))
|
||||||
|
|
||||||
if _CLUSTERING_TYPE in ['n2d']:
|
if clustering_type in ['n2d']:
|
||||||
filenames = [os.path.join(CONFIG.get('directory')['collections'], "%s.jpg" % one_res[0]) for one_res in clustering_res]
|
## Graphs of UMAP and final clusters
|
||||||
images_array = [Tools.read_np_picture(img_filename, target_size = (54, 96)) for img_filename in filenames]
|
fig, ax = plt.subplots(figsize=(24, 14))
|
||||||
base64_images = [Tools.base64_image(img) for img in images_array]
|
classes = clustering.final_labels
|
||||||
|
scatter = ax.scatter(clustering.umap_embedding[:, 0], clustering.umap_embedding[:, 1], c = clustering.colors)
|
||||||
|
legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
|
||||||
|
ax.add_artist(legend1)
|
||||||
|
plt.savefig(os.path.join(clustering.save_directory, 'umap_clusters.png'))
|
||||||
|
|
||||||
print(clustering.umap_embedding)
|
if clustering_type in ['n2d', 'classical']:
|
||||||
print(clustering.umap_embedding.shape)
|
filenames = [os.path.join(config.get('directory')['collections'], "%s.jpg" % one_res[0]) for one_res in clustering.get_results()]
|
||||||
|
images_array = [Tools.read_np_picture(img_filename, target_size = (54, 96)) for img_filename in filenames]
|
||||||
|
base64_images = [Tools.base64_image(img) for img in images_array]
|
||||||
|
|
||||||
x = clustering.umap_embedding[:, 0]
|
if clustering_type == 'n2d':
|
||||||
y = clustering.umap_embedding[:, 1]
|
x = clustering.umap_embedding[:, 0]
|
||||||
|
y = clustering.umap_embedding[:, 1]
|
||||||
|
html_file = 'umap_bokeh.html'
|
||||||
|
title = 'UMAP projection of iss clusters'
|
||||||
|
elif clustering_type == 'classical':
|
||||||
|
x = clustering.tsne_embedding[:, 0]
|
||||||
|
y = clustering.tsne_embedding[:, 1]
|
||||||
|
html_file = 'tsne_bokeh.html'
|
||||||
|
title = 't-SNE projection of iss clusters'
|
||||||
|
|
||||||
df = pd.DataFrame({'x': x, 'y': y})
|
df = pd.DataFrame({'x': x, 'y': y})
|
||||||
df['image'] = base64_images
|
df['image'] = base64_images
|
||||||
df['label'] = clustering.final_labels.astype(str)
|
df['label'] = clustering.final_labels.astype(str)
|
||||||
df['color'] = df['label'].apply(Tools.get_color_from_label)
|
df['color'] = df['label'].apply(Tools.get_color_from_label)
|
||||||
|
|
||||||
datasource = ColumnDataSource(df)
|
datasource = ColumnDataSource(df)
|
||||||
|
|
||||||
output_file(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'umap_bokeh.html'))
|
output_file(os.path.join(clustering.save_directory, html_file))
|
||||||
|
|
||||||
plot_figure = figure(
|
plot_figure = figure(
|
||||||
title='UMAP projection of iss clusters',
|
title=title,
|
||||||
# plot_width=1200,
|
# plot_width=1200,
|
||||||
# plot_height=1200,
|
# plot_height=1200,
|
||||||
tools=('pan, wheel_zoom, reset')
|
tools=('pan, wheel_zoom, reset')
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_figure.add_tools(HoverTool(tooltips="""
|
plot_figure.add_tools(HoverTool(tooltips="""
|
||||||
<div>
|
|
||||||
<div>
|
<div>
|
||||||
<img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
|
<div>
|
||||||
|
<img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span style='font-size: 16px'>Cluster:</span>
|
||||||
|
<span style='font-size: 18px'>@label</span>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
"""))
|
||||||
<span style='font-size: 16px'>Cluster:</span>
|
|
||||||
<span style='font-size: 18px'>@label</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
"""))
|
|
||||||
|
|
||||||
|
|
||||||
plot_figure.circle(
|
plot_figure.circle(
|
||||||
'x',
|
'x',
|
||||||
'y',
|
'y',
|
||||||
source=datasource,
|
source=datasource,
|
||||||
color=dict(field='color'),
|
color=dict(field='color'),
|
||||||
line_alpha=0.6,
|
line_alpha=0.6,
|
||||||
fill_alpha=0.6,
|
fill_alpha=0.6,
|
||||||
size=4
|
size=4
|
||||||
)
|
)
|
||||||
|
|
||||||
show(plot_figure)
|
show(plot_figure)
|
||||||
|
|
||||||
|
|
||||||
if _CLUSTERING_TYPE in ['classical']:
|
if clustering_type in ['classical']:
|
||||||
## Dendogram
|
## Dendrogram
|
||||||
fig, ax = plt.subplots(figsize=(24, 14))
|
fig, ax = plt.subplots(figsize=(24, 14))
|
||||||
plt.title('Hierarchical Clustering Dendrogram')
|
plt.title('Hierarchical Clustering Dendrogram')
|
||||||
Tools.plot_dendrogram(clustering.cah_fit, labels=clustering.cah_labels)
|
Tools.plot_dendrogram(clustering.cah_fit, labels=clustering.cah_labels)
|
||||||
plt.savefig(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'dendograms.png'))
|
plt.savefig(os.path.join(clustering.save_directory, 'dendograms.png'))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def plot_silhouette(config, clustering_type, clustering):
    """Plot the per-cluster silhouette scores as a bar chart and save it.

    Args:
        config: project configuration; not read by this function.
        clustering_type: clustering identifier; not read by this function.
        clustering: object exposing ``compute_silhouette_score()`` (returning
            a mapping of cluster -> score) and a ``save_directory`` attribute.

    Returns:
        The mapping returned by ``clustering.compute_silhouette_score()``.
    """
    silhouettes = clustering.compute_silhouette_score()

    # Materialise the dict views once so the bars, tick positions and tick
    # labels are all built from the same concrete sequences.
    cluster_ids = list(silhouettes.keys())
    scores = list(silhouettes.values())

    fig, ax = plt.subplots(figsize=(12, 7))
    ax.bar(cluster_ids, scores, align='center')
    ax.set_xticks(cluster_ids)
    ax.set_xticklabels(cluster_ids)
    fig.savefig(os.path.join(clustering.save_directory, 'silhouettes_score.png'))

    return silhouettes
|
||||||
|
|
||||||
|
|
||||||
## Silhouette
|
def plot_mosaics(config, clustering_type, clustering, output_image_width, output_image_height, mosaic_nrow, mosaic_ncol_max):
|
||||||
fig, ax = plt.subplots(figsize=(12, 7))
|
"""
|
||||||
ax.bar(silhouettes.keys(), silhouettes.values(), align='center')
|
Mosaic of each cluster
|
||||||
ax.set_xticks(list(silhouettes.keys()))
|
"""
|
||||||
ax.set_xticklabels(list(silhouettes.keys()))
|
clusters_id = np.unique(clustering.final_labels)
|
||||||
plt.savefig(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], 'silhouettes_score.png'))
|
clustering_res = clustering.get_results()
|
||||||
|
|
||||||
|
for cluster_id in clusters_id:
|
||||||
|
cluster_image_filenames = [os.path.join(config.get('directory')['collections'], "%s.jpg" % one_res[0]) for one_res in clustering_res if one_res[1] == cluster_id]
|
||||||
|
|
||||||
|
images_array = [Tools.read_np_picture(img_filename, target_size = (output_image_height, output_image_width)) for img_filename in cluster_image_filenames]
|
||||||
|
|
||||||
|
img = Tools.display_mosaic(images_array, nrow = mosaic_nrow, ncol_max = mosaic_ncol_max)
|
||||||
|
img.save(os.path.join(clustering.save_directory, "cluster_%s.png" % str(cluster_id).zfill(2)), "PNG")
|
||||||
|
|
||||||
|
return clusters_id
|
||||||
|
|
||||||
|
|
||||||
## Mosaic of each cluster
|
def main():
|
||||||
clusters_id = np.unique(clustering.final_labels)
|
_CLUSTERING_TYPE = 'classical'
|
||||||
for cluster_id in clusters_id:
|
_BATCH_SIZE = 496
|
||||||
cluster_image_filenames = [os.path.join(CONFIG.get('directory')['collections'], "%s.jpg" % one_res[0]) for one_res in clustering_res if one_res[1] == cluster_id]
|
_N_BATCH = 1
|
||||||
|
_PLOTS = True
|
||||||
|
_MOSAICS = True
|
||||||
|
_SILHOUETTE = True
|
||||||
|
_OUTPUT_IMAGE_WIDTH = 96
|
||||||
|
_OUTPUT_IMAGE_HEIGHT = 54
|
||||||
|
_MOSAIC_NROW = 10
|
||||||
|
_MOSAIC_NCOL_MAX = 10
|
||||||
|
|
||||||
images_array = [Tools.read_np_picture(img_filename, target_size = (_OUTPUT_IMAGE_HEIGHT, _OUTPUT_IMAGE_WIDTH)) for img_filename in cluster_image_filenames]
|
model, model_config = load_model(CONFIG, _CLUSTERING_TYPE)
|
||||||
|
pictures_id, intermediate_output = load_images(CONFIG, _CLUSTERING_TYPE, model, model_config, _BATCH_SIZE, _N_BATCH)
|
||||||
|
|
||||||
|
clustering = run_clustering(CONFIG, _CLUSTERING_TYPE, pictures_id, intermediate_output)
|
||||||
|
|
||||||
|
clustering.save()
|
||||||
|
|
||||||
img = Tools.display_mosaic(images_array, nrow = _MOSAIC_NROW, ncol_max = _MOSAIC_NCOL_MAX)
|
if _PLOTS:
|
||||||
img.save(os.path.join(CONFIG.get('clustering')[_CLUSTERING_TYPE]['save_directory'], "cluster_%s.png" % str(cluster_id).zfill(2)), "PNG")
|
run_plots(CONFIG, _CLUSTERING_TYPE, clustering)
|
||||||
|
|
||||||
|
if _SILHOUETTE:
|
||||||
|
plot_silhouette(CONFIG, _CLUSTERING_TYPE, clustering)
|
||||||
|
|
||||||
|
if _MOSAICS:
|
||||||
|
plot_mosaics(CONFIG, _CLUSTERING_TYPE, clustering, _OUTPUT_IMAGE_WIDTH, _OUTPUT_IMAGE_HEIGHT, _MOSAIC_NROW, _MOSAIC_NCOL_MAX)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
Loading…
Reference in a new issue