2019-03-31 18:05:09 +02:00
|
|
|
# -*- coding: utf-8 -*-
|
2019-12-08 02:24:20 +01:00
|
|
|
import os
|
2019-11-16 18:30:08 +01:00
|
|
|
from iss.tools import Tools
|
2019-03-31 18:05:09 +02:00
|
|
|
|
|
|
|
class AbstractClustering:
    """Base class for clustering pictures.

    This base class does three things:

    * stores the inputs;
    * creates a save directory derived from the config;
    * provides helpers (``get_results``, ``compute_colors``) that work on
      ``self.final_labels`` once a subclass has filled it in.

    Every other method here raises ``NotImplementedError`` and is meant to
    be overridden by a concrete subclass.
    """

    def __init__(self, config, pictures_id, pictures_np):
        """Store the inputs and create the save directory.

        Args:
            config: Mapping with at least the keys ``'save_directory'`` and
                ``'version'``, plus ``'model'``. ``'model'`` is itself a
                mapping with ``'type'`` and ``'name'``.
            pictures_id: Sequence of picture identifiers. Presumably aligned
                index-for-index with ``pictures_np`` (TODO confirm).
            pictures_np: Sequence of per-picture data. Presumably numpy
                arrays, judging by the name (TODO confirm).
        """
        self.config = config
        # Sub-directory named "<model type>_<model name>_<version>" under
        # the configured save directory.
        self.save_directory = os.path.join(
            self.config['save_directory'],
            '%s_%s_%s' % (
                self.config['model']['type'],
                self.config['model']['name'],
                self.config['version'],
            ),
        )
        self.pictures_id = pictures_id
        self.pictures_np = pictures_np
        # Unset until labels are assigned (presumably by a subclass's
        # compute_final_labels) and until compute_colors runs.
        self.final_labels = None
        self.colors = None

        Tools.create_dir_if_not_exists(self.save_directory)

    def _require_final_labels(self):
        """Return ``self.final_labels``, failing clearly if it is unset.

        Raises:
            RuntimeError: if ``final_labels`` is still ``None``.
        """
        if self.final_labels is None:
            # Without this guard, zip()/set() would fail later with an
            # unhelpful "NoneType is not iterable" TypeError.
            raise RuntimeError(
                'final_labels is not set; call compute_final_labels() first'
            )
        return self.final_labels

    def compute_final_labels(self):
        """Subclass hook: assign ``self.final_labels``.

        Not implemented in the base class. ``get_results`` and
        ``compute_colors`` rely on ``self.final_labels`` having been set.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def get_results(self):
        """Pair every picture with its label.

        Returns:
            A list of ``(picture_id, label, picture_np)`` tuples. As with
            ``zip``, the result is truncated to the shortest of the three
            sequences.

        Raises:
            RuntimeError: if ``final_labels`` has not been set yet.
        """
        labels = self._require_final_labels()
        return list(zip(self.pictures_id, labels, self.pictures_np))

    def compute_silhouette_score(self):
        """Subclass hook (by name: the silhouette score of the clustering).

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def compute_colors(self):
        """Compute one color per entry of ``self.final_labels``.

        Each label, together with the number of distinct labels, is passed
        to ``Tools.get_color_from_label``. The resulting list is stored in
        ``self.colors``, aligned with ``final_labels``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            RuntimeError: if ``final_labels`` has not been set yet.
        """
        labels = self._require_final_labels()
        n_classes = len(set(labels))
        self.colors = [Tools.get_color_from_label(label, n_classes) for label in labels]
        return self

    def predict_embedding(self):
        """Subclass hook; not implemented in the base class.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def predict_label(self):
        """Subclass hook; not implemented in the base class.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def save(self):
        """Subclass hook for persisting the model.

        Presumably writes under ``self.save_directory`` (TODO confirm).

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def load(self):
        """Subclass hook for restoring a previously saved model.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
|