mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-05 23:23:11 +02:00
66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import numpy as np
|
|
import umap
|
|
import hdbscan
|
|
from iss.tools import Tools
|
|
from iss.clustering import AbstractClustering
|
|
from sklearn.cluster import KMeans
|
|
from sklearn.metrics import silhouette_samples
|
|
from sklearn.externals import joblib
|
|
|
|
class DBScanClustering(AbstractClustering):
|
|
"""
|
|
Cf: https://umap-learn.readthedocs.io/en/latest/clustering.html
|
|
"""
|
|
|
|
def __init__(self, config, pictures_id = None, pictures_np = None):
|
|
|
|
super().__init__(config, pictures_id, pictures_np)
|
|
|
|
self.umap_args = self.config['umap']
|
|
self.umap_fit = None
|
|
self.umap_embedding = None
|
|
self.umap_save_name = 'UMAP_model.pkl'
|
|
|
|
self.dbscan_fit = None
|
|
self.dbscan_args = self.config['dbscan']
|
|
self.dbscan_labels = None
|
|
self.dbscan_centers = []
|
|
self.dbscan_save_name = "dbscan_model.pkl"
|
|
|
|
|
|
def compute_umap(self):
|
|
self.umap_fit = umap.UMAP(**self.umap_args)
|
|
self.umap_embedding = self.umap_fit.fit_transform(self.pictures_np)
|
|
return self
|
|
|
|
def compute_dbscan(self):
|
|
self.dbscan_fit = hdbscan.HDBSCAN(**self.dbscan_args)
|
|
self.dbscan_fit.fit(self.umap_embedding)
|
|
self.dbscan_labels = self.dbscan_fit.labels_
|
|
return self
|
|
|
|
def compute_final_labels(self):
|
|
self.final_labels = self.dbscan_labels
|
|
return self
|
|
|
|
def compute_silhouette_score(self):
|
|
self.silhouette_score = silhouette_samples(self.pictures_np, self.final_labels)
|
|
self.silhouette_score_labels = {cluster: np.mean(self.silhouette_score[self.final_labels == cluster]) for
|
|
cluster in np.unique(self.final_labels)}
|
|
return self.silhouette_score_labels
|
|
|
|
def predict_embedding(self, pictures_np):
|
|
return self.umap_fit.transform(pictures_np)
|
|
|
|
def save(self):
|
|
Tools.create_dir_if_not_exists(self.save_directory)
|
|
|
|
joblib.dump(self.umap_fit, os.path.join(self.save_directory, self.umap_save_name))
|
|
joblib.dump(self.dbscan_fit, os.path.join(self.save_directory, self.dbscan_save_name))
|
|
|
|
def load(self):
|
|
self.umap_fit = joblib.load(os.path.join(self.save_directory, self.umap_save_name))
|
|
self.dbscan_fit = joblib.load(os.path.join(self.save_directory, self.dbscan_save_name)) |