10 changed files with 172 additions and 35 deletions
@ -0,0 +1,66 @@
|
||||
# -*- coding: utf-8 -*- |
||||
|
||||
import os |
||||
import numpy as np |
||||
import umap |
||||
import hdbscan |
||||
from iss.tools import Tools |
||||
from iss.clustering import AbstractClustering |
||||
from sklearn.cluster import KMeans |
||||
from sklearn.metrics import silhouette_samples |
||||
from sklearn.externals import joblib |
||||
|
||||
class DBScanClustering(AbstractClustering): |
||||
""" |
||||
Cf: https://umap-learn.readthedocs.io/en/latest/clustering.html |
||||
""" |
||||
|
||||
def __init__(self, config, pictures_id = None, pictures_np = None): |
||||
|
||||
super().__init__(config, pictures_id, pictures_np) |
||||
|
||||
self.umap_args = self.config['umap'] |
||||
self.umap_fit = None |
||||
self.umap_embedding = None |
||||
self.umap_save_name = 'UMAP_model.pkl' |
||||
|
||||
self.dbscan_fit = None |
||||
self.dbscan_args = self.config['dbscan'] |
||||
self.dbscan_labels = None |
||||
self.dbscan_centers = [] |
||||
self.dbscan_save_name = "dbscan_model.pkl" |
||||
|
||||
|
||||
def compute_umap(self): |
||||
self.umap_fit = umap.UMAP(**self.umap_args) |
||||
self.umap_embedding = self.umap_fit.fit_transform(self.pictures_np) |
||||
return self |
||||
|
||||
def compute_dbscan(self): |
||||
self.dbscan_fit = hdbscan.HDBSCAN(**self.dbscan_args) |
||||
self.dbscan_fit.fit(self.umap_embedding) |
||||
self.dbscan_labels = self.dbscan_fit.labels_ |
||||
return self |
||||
|
||||
def compute_final_labels(self): |
||||
self.final_labels = self.dbscan_labels |
||||
return self |
||||
|
||||
def compute_silhouette_score(self): |
||||
self.silhouette_score = silhouette_samples(self.pictures_np, self.final_labels) |
||||
self.silhouette_score_labels = {cluster: np.mean(self.silhouette_score[self.final_labels == cluster]) for |
||||
cluster in np.unique(self.final_labels)} |
||||
return self.silhouette_score_labels |
||||
|
||||
def predict_embedding(self, pictures_np): |
||||
return self.umap_fit.transform(pictures_np) |
||||
|
||||
def save(self): |
||||
Tools.create_dir_if_not_exists(self.save_directory) |
||||
|
||||
joblib.dump(self.umap_fit, os.path.join(self.save_directory, self.umap_save_name)) |
||||
joblib.dump(self.dbscan_fit, os.path.join(self.save_directory, self.dbscan_save_name)) |
||||
|
||||
def load(self): |
||||
self.umap_fit = joblib.load(os.path.join(self.save_directory, self.umap_save_name)) |
||||
self.dbscan_fit = joblib.load(os.path.join(self.save_directory, self.dbscan_save_name)) |
@ -1,4 +1,5 @@
|
||||
from .AbstractClustering import AbstractClustering |
||||
from .ClassicalClustering import ClassicalClustering |
||||
from .AdvancedClustering import AdvancedClustering |
||||
from .N2DClustering import N2DClustering |
||||
from .N2DClustering import N2DClustering |
||||
from .DBScanClustering import DBScanClustering |
Loading…
Reference in new issue