Source code for immuneML.workflows.instructions.clustering.ClusteringInstruction

import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import plotly.express as px

from immuneML.IO.ml_method.ClusteringExporter import ClusteringExporter
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.environment.LabelConfiguration import LabelConfiguration
from immuneML.environment.SequenceType import SequenceType
from immuneML.hyperparameter_optimization.clustering.StabilityLange import StabilityLange
from immuneML.hyperparameter_optimization.config.SampleConfig import SampleConfig
from immuneML.hyperparameter_optimization.config.SplitConfig import SplitConfig
from immuneML.hyperparameter_optimization.config.SplitType import SplitType
from immuneML.hyperparameter_optimization.core.HPUtil import HPUtil
from immuneML.ml_metrics.ClusteringMetric import is_internal, is_external, get_search_criterion
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.Report import Report
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.util.Logger import print_log, log_memory_usage
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.instructions.Instruction import Instruction
from immuneML.workflows.instructions.clustering import clustering_runner
from immuneML.workflows.instructions.clustering.ClusteringReportHandler import ClusteringReportHandler
from immuneML.workflows.instructions.clustering.ClusteringState import (
    ClusteringConfig, ClusteringState, ClusteringResultPerRun, StabilityConfig
)
from immuneML.workflows.instructions.clustering.clustering_run_model import ClusteringSetting
from immuneML.workflows.steps.DataSampler import DataSamplerParams, DataSampler


class ClusteringInstruction(Instruction):
    """
    Clustering instruction fits clustering methods to the provided encoded dataset and compares the combinations of
    clustering method with its hyperparameters, and encodings across a pre-defined set of metrics. It provides results
    either for the full discovery dataset or for multiple subsets of discovery data as way to assess the stability
    of different metrics (Liu et al., 2022; Dangl and Leisch, 2020; Lange et al. 2004). Finally, it
    provides options to include a set of reports to visualize the results.

    See also: :ref:`How to perform clustering analysis` for more details on the clustering procedure.

    References:

    Lange, T., Roth, V., Braun, M. L., & Buhmann, J. M. (2004). Stability-Based Validation of Clustering Solutions.
    Neural Computation, 16(6), 1299–1323. https://doi.org/10.1162/089976604773717621

    Dangl, R., & Leisch, F. (2020). Effects of Resampling in Determining the Number of Clusters in a Data Set.
    Journal of Classification, 37(3), 558–583. https://doi.org/10.1007/s00357-019-09328-2

    Liu, T., Yu, H., & Blair, R. H. (2022). Stability estimation for unsupervised clustering: A review. WIREs
    Computational Statistics, 14(6), e1575. https://doi.org/10.1002/wics.1575

    **Specification arguments:**

    - dataset (str): name of the dataset to be clustered

    - metrics (list): a list of metrics to use for comparison of clustering algorithms and encodings (it can include
      metrics for either internal evaluation if no labels are provided or metrics for external evaluation so that the
      clusters can be compared against a list of predefined labels); some of the supported metrics include
      adjusted_rand_score, completeness_score, homogeneity_score, silhouette_score; for the full list, see
      scikit-learn's documentation of clustering metrics at
      https://scikit-learn.org/stable/api/sklearn.metrics.html#module-sklearn.metrics.cluster.

    - labels (list): an optional list of labels to use for external evaluation of clustering

    - sample_config (SampleConfig): configuration describing how to construct the data subsets to estimate different
      clustering settings' performance with different internal and external validation indices; with parameters
      `percentage`, `split_count`, `random_seed`:

    .. indent with spaces
    .. code-block:: yaml

        sample_config: # make 5 subsets with 80% of the data each
            split_count: 5
            percentage: 0.8
            random_seed: 42

    - stability_config (StabilityConfig): configuration describing how to compute clustering stability;
      currently, clustering stability is computed following approach by Lange et al. (2004) and only takes the
      number of repetitions as a parameter. Other strategies to compute clustering stability will be added in the
      future.

    .. indent with spaces
    .. code-block:: yaml

        stability_config:
            split_count: 5 # number of times to repeat clustering for stability estimation
            random_seed: 12

    - clustering_settings (list): a list where each element represents a
      :py:obj:`~immuneML.workflows.instructions.clustering.clustering_run_model.ClusteringSetting`; a combination of
      encoding, optional dimensionality reduction algorithm, and the clustering algorithm that will be evaluated

    - reports (list): a list of reports to be run on the clustering results or the encoded data

    - number_of_processes (int): how many processes to use for parallelization

    - sequence_type (str): whether to do analysis on the amino_acid or nucleotide level; this value is used only if
      nothing is specified on the encoder level

    - region_type (str): which part of the receptor sequence to analyze (e.g., IMGT_CDR3); this value is used only if
      nothing is specified on the encoder level

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        instructions:
            my_clustering_instruction:
                type: Clustering
                dataset: d1
                metrics: [adjusted_rand_score, adjusted_mutual_info_score]
                labels: [epitope, v_call]
                sequence_type: amino_acid
                region_type: imgt_cdr3
                sample_config:
                    split_count: 5
                    percentage: 0.8
                    random_seed: 42
                stability_config:
                    split_count: 5
                    random_seed: 12
                clustering_settings:
                    - encoding: e1
                      dim_reduction: pca
                      method: k_means1
                    - encoding: e2
                      method: dbscan
                reports: [rep1, rep2]

    """

    def __init__(self, dataset: Dataset, metrics: List[str], clustering_settings: List[ClusteringSetting], name: str,
                 label_config: LabelConfiguration = None, reports: List[Report] = None,
                 number_of_processes: int = None, sample_config: SampleConfig = None,
                 stability_config: StabilityConfig = None, sequence_type: SequenceType = None,
                 region_type: RegionType = None):
        """Bundle the specification into a ClusteringConfig and create the instruction state and report handler."""
        config = ClusteringConfig(name=name, dataset=dataset, metrics=metrics,
                                  clustering_settings=clustering_settings, label_config=label_config,
                                  sample_config=sample_config, sequence_type=sequence_type,
                                  region_type=region_type, stability_config=stability_config)
        self.number_of_processes = number_of_processes
        self.state = ClusteringState(config=config, name=name)
        self.report_handler = ClusteringReportHandler(reports)

    def run(self, result_path: Path):
        """
        Main entry point: computes validation indices, estimates stability and refits the best settings.

        Args:
            result_path: parent directory; results are written to a subdirectory named after the instruction.

        Returns:
            the ClusteringState populated by the three steps.
        """
        self.state.result_path = PathBuilder.build(result_path / self.state.config.name)
        self._fix_max_processes()

        # Step 1: Compute validation indices
        self._compute_validation_indices()

        # Step 2: Estimate clustering stability
        self._compute_stability()

        # Step 3: Refit the best settings on full dataset
        self._refit_best_settings_on_full_dataset()

        return self.state

    def _compute_validation_indices(self):
        """Compute internal and external validation indices across all datasets and settings."""
        self._setup_paths()
        print_log(f"{self.__class__.__name__} ({self.state.name}): computing validation indices.")

        # 1. Construct datasets (subsampling)
        datasets = self._construct_datasets()

        # 2. Run clustering settings on each dataset and collect predictions
        all_results = self._run_clustering_on_all_datasets(datasets)

        # 3. Aggregate and export index results
        self._aggregate_internal_indices(all_results)
        self._aggregate_external_indices(all_results)

        # 4. Run any additional clustering reports
        self.report_handler.run_clustering_reports(self.state)

        print_log(f"{self.__class__.__name__} ({self.state.name}): computed validation indices.")

    def _refit_best_settings_on_full_dataset(self):
        """
        Pick the best clustering setting per metric, refit each on the full dataset and export it.

        For every CSV registered in ``state.metrics_performance_paths`` the per-setting mean over splits is
        computed; the setting with the highest (or lowest, depending on the metric's search criterion) mean wins.
        Each distinct winning setting is refitted once, gets a cluster classifier trained, and is exported as a zip.
        Predictions of all refitted settings are written to a single CSV file.
        """
        print_log(f"{self.__class__.__name__} ({self.state.name}): refitting best settings.")
        path = PathBuilder.build(self.state.result_path / "refitted_best_settings")

        # maps setting key -> names of the metrics for which that setting was the best one
        best_settings = defaultdict(list)

        for metric_name, metric_path in self.state.metrics_performance_paths.items():
            mean_per_setting = pd.read_csv(metric_path).drop(columns=['split_id']).mean(axis=0)

            # Without any value there is no meaningful optimum: idxmax/idxmin would yield NaN or raise
            # (depending on the pandas version) and the failure would only surface later as an unknown key.
            if mean_per_setting.isna().all():
                logging.warning(f"ClusteringInstruction: no values available for metric {metric_path.stem}; "
                                f"skipping it when selecting the best settings.")
                continue

            # metric names for external indices have the form '<metric>__<label>'
            if metric_name == 'stability_lange' or get_search_criterion(metric_name.split("__")[0]) == max:
                best_setting_key = mean_per_setting.idxmax()
            else:
                best_setting_key = mean_per_setting.idxmin()

            best_settings[best_setting_key].append(metric_path.stem)
            print_log(f"Best setting for metric {metric_path.stem} is {best_setting_key}.")

        predictions_df = self._init_predictions_df(self.state.config.dataset)

        for best_setting_key, per_metrics in best_settings.items():
            cl_setting = self.state.config.get_cl_setting_by_key(best_setting_key)

            cl_item_res, predictions_df = clustering_runner.run_setting(
                dataset=self.state.config.dataset, cl_setting=cl_setting, path=path,
                predictions_df=predictions_df, metrics=[], label_config=self.state.config.label_config,
                number_of_processes=self.number_of_processes, sequence_type=self.state.config.sequence_type,
                evaluate=False, region_type=self.state.config.region_type, state=self.state)

            cl_item_res.item.classifier = clustering_runner.train_cluster_classifier(cl_item_res.item)
            self.state.optimal_settings_on_discovery[best_setting_key] = cl_item_res

            # Export the best setting as a zip file
            setting_path = path / best_setting_key
            zip_path = ClusteringExporter.export_zip(cl_item_res.item, setting_path, best_setting_key)
            self.state.best_settings_zip_paths[best_setting_key] = {'path': zip_path, 'metrics': per_metrics}
            logging.info(f"ClusteringInstruction: exported best setting {best_setting_key} to: {zip_path}")

        final_predictions_path = path / "best_settings_predictions_full_dataset.csv"
        predictions_df.to_csv(final_predictions_path, index=False)
        self.state.final_predictions_path = final_predictions_path

        print_log(f"{self.__class__.__name__} ({self.state.name}): refitted best settings.")

    def _construct_datasets(self) -> List[Dataset]:
        """Construct subsampled datasets for validation and store them on the state."""
        paths = [PathBuilder.build(f"{self.state.result_path}/split_{run_id + 1}/")
                 for run_id in range(self.state.config.sample_config.split_count)]

        self.state.subsampled_datasets = DataSampler.run(
            DataSamplerParams(self.state.config.dataset, self.state.config.sample_config, paths)
        )
        return self.state.subsampled_datasets

    def _run_clustering_on_all_datasets(self, datasets: List[Dataset]) -> List[Dict]:
        """
        Run every clustering setting on each dataset and collect results.

        Returns:
            one dict per dataset with keys 'run_id', 'clustering_items' and 'predictions_df'.
        """
        all_results = []

        for run_id, dataset in enumerate(datasets):
            print_log(f"Running clustering for split {run_id + 1}.")

            path = self.state.result_path / f"validation_indices/split_{run_id + 1}"
            predictions_df = self._init_predictions_df(dataset)

            clustering_items, predictions_df = clustering_runner.run_all_settings(
                dataset=dataset, path=path, run_id=run_id, predictions_df=predictions_df, state=self.state,
                report_handler=self.report_handler, number_of_processes=self.number_of_processes,
                sequence_type=self.state.config.sequence_type, region_type=self.state.config.region_type,
                clustering_settings=self.state.config.clustering_settings, metrics=self.state.config.metrics,
                label_config=self.state.config.label_config
            )

            # predictions_paths is filled by _setup_paths, which must have run before this method
            predictions_df.to_csv(self.state.predictions_paths[run_id], index=False)

            cl_result = ClusteringResultPerRun(run_id, clustering_items)
            self.state.add_cl_result_per_run(run_id, cl_result)

            all_results.append({
                'run_id': run_id,
                'clustering_items': clustering_items,
                'predictions_df': predictions_df
            })

            log_memory_usage(f"discovery in split {run_id + 1}", f"Clustering instruction {self.state.name}")

        return all_results

    def _aggregate_internal_indices(self, all_results: List[Dict]):
        """
        Aggregate internal indices across all datasets and clustering settings.
        Produces one CSV per internal index, then generates boxplots.
        Creates ReportResult objects and adds them to state.
        """
        internal_metrics = [m for m in self.state.config.metrics if is_internal(m)]
        if not internal_metrics:
            return

        indices_path = PathBuilder.build(self.state.result_path / 'validation_indices' / 'internal')
        figures, tables = [], []

        for metric in internal_metrics:
            metric_data = self._collect_internal_metric_data(all_results, metric)
            csv_path = indices_path / f'{metric}.csv'
            metric_data.to_csv(csv_path, index=False)

            # registered so that the best setting for this metric can be selected when refitting
            self.state.metrics_performance_paths[metric] = csv_path

            tables.append(ReportOutput(path=csv_path, name=f'{metric} values per split'))
            figures.append(self._create_internal_index_boxplot(metric_data, metric, indices_path))

        # Create ReportResult and add to state
        report_result = ReportResult(
            name='Internal Validation Indices',
            info=f'Internal validation indices ({", ".join(internal_metrics).replace(" ", "")}) computed '
                 f'across all clustering settings and data splits describe the quality of the clustering with respect to'
                 f' different criteria relying solely on the data used for clustering.',
            output_figures=figures,
            output_tables=tables
        )
        self.state.clustering_report_results.append(report_result)

    def _collect_internal_metric_data(self, all_results: List[Dict], metric: str) -> pd.DataFrame:
        """
        Collect internal metric values for all datasets and clustering settings.

        Returns:
            a DataFrame with one row per split: a 1-based 'split_id' column and one column per setting key
            (None where the setting has no internal performance or the metric is missing).
        """
        rows = []

        for result in all_results:
            row = {'split_id': result['run_id'] + 1}

            for setting_key, cl_item_result in result['clustering_items'].items():
                internal_perf = cl_item_result.item.internal_performance
                value = None
                if internal_perf is not None:
                    perf_df = internal_perf.get_df()
                    if metric in perf_df.columns:
                        value = perf_df.loc[0, metric]
                row[setting_key] = value

            rows.append(row)

        return pd.DataFrame(rows)

    def _create_index_boxplot(self, metric_data: pd.DataFrame, metric: str, plot_file: Path,
                              name: str) -> ReportOutput:
        """
        Draw one box per clustering setting showing the metric values across splits.

        Shared implementation for the internal and external index boxplots, which differ only in the
        output file and the display name.

        Args:
            metric_data: wide table with a 'split_id' column and one column per clustering setting.
            metric: metric name, used as the value column / y-axis label.
            plot_file: file the figure is written to.
            name: display name of the resulting ReportOutput.
        """
        melted = metric_data.melt(id_vars=['split_id'], var_name='clustering_setting', value_name=metric)

        fig = px.box(melted, x='clustering_setting', y=metric,
                     labels={'clustering_setting': 'clustering setting', metric: metric},
                     color='clustering_setting', points='all',
                     color_discrete_sequence=px.colors.qualitative.Vivid)
        fig.update_layout(template='plotly_white', showlegend=False)

        plot_path = PlotlyUtil.write_image_to_file(fig, plot_file, melted.shape[0])
        return ReportOutput(path=plot_path, name=name)

    def _create_internal_index_boxplot(self, metric_data: pd.DataFrame, metric: str,
                                       output_path: Path) -> ReportOutput:
        """Create boxplot for internal index across clustering settings."""
        return self._create_index_boxplot(metric_data, metric, output_path / f'{metric}_boxplot.html',
                                          f'{metric} across clustering settings')

    def _aggregate_external_indices(self, all_results: List[Dict]):
        """
        Aggregate external indices across all datasets, labels, and clustering settings.
        Produces one CSV per (label, external_index) combination, then generates boxplots.
        Creates ReportResult objects and adds them to state.
        """
        external_metrics = [m for m in self.state.config.metrics if is_external(m)]
        labels = self._get_label_names()

        if not external_metrics or not labels:
            return

        indices_path = PathBuilder.build(self.state.result_path / 'validation_indices/external')
        figures, tables = [], []

        for label in labels:
            for metric in external_metrics:
                metric_data = self._collect_external_metric_data(all_results, label, metric)
                csv_path = indices_path / f'{metric}__{label}.csv'
                metric_data.to_csv(csv_path, index=False)

                # the '<metric>__<label>' key is split on '__' again when the best settings are selected
                self.state.metrics_performance_paths[f"{metric}__{label}"] = csv_path

                tables.append(ReportOutput(path=csv_path, name=f'{metric} values for label {label}'))
                figures.append(self._create_external_index_boxplot(metric_data, label, metric, indices_path))

        report_result = ReportResult(
            name='External Validation Indices',
            info=f'External validation indices ({", ".join(external_metrics)}) are computed '
                 f'with respect to labels {", ".join(labels)} across all clustering '
                 f'settings and data splits measuring the agreement between the clustering results and '
                 f'the provided labels.',
            output_figures=figures,
            output_tables=tables
        )
        self.state.clustering_report_results.append(report_result)

    def _collect_external_metric_data(self, all_results: List[Dict], label: str, metric: str) -> pd.DataFrame:
        """
        Collect external metric values for a specific label across all datasets and settings.

        Returns:
            a DataFrame with one row per split: a 1-based 'split_id' column and one column per setting key
            (None where no value is available for the given metric and label).
        """
        rows = []

        for result in all_results:
            row = {'split_id': result['run_id'] + 1}

            for setting_key, cl_item_result in result['clustering_items'].items():
                external_perf = cl_item_result.item.external_performance
                value = None
                if external_perf is not None:
                    perf_df = external_perf.get_df()
                    # External performance is stored with metrics as rows, labels as columns
                    metric_row = perf_df[perf_df['metric'] == metric]
                    if not metric_row.empty and label in metric_row.columns:
                        value = metric_row[label].values[0]
                row[setting_key] = value

            rows.append(row)

        return pd.DataFrame(rows)

    def _create_external_index_boxplot(self, metric_data: pd.DataFrame, label: str, metric: str,
                                       output_path: Path) -> ReportOutput:
        """Create boxplot for external index per label, grouped by clustering setting."""
        return self._create_index_boxplot(metric_data, metric, output_path / f'{label}_{metric}_boxplot.html',
                                          f'{metric} for label "{label}" across clustering settings')

    def _compute_stability(self):
        """
        Compute clustering stability using the Lange et al. approach.

        Seeds Python's and NumPy's global random generators with the stability config's seed, splits the
        dataset into discovery/tuning pairs, runs StabilityLange on them and registers the resulting report
        and performance file on the state under the key 'stability_lange'.
        """
        print_log(f"{self.__class__.__name__} ({self.state.name}): computing stability.")

        random.seed(self.state.config.stability_config.random_seed)
        np.random.seed(self.state.config.stability_config.random_seed)

        # NOTE(review): 0.5 is presumably the fraction of examples assigned to the discovery part of each
        # random split - confirm against SplitConfig.
        discovery_datasets, tuning_datasets = HPUtil.split_data(
            self.state.config.dataset,
            SplitConfig(SplitType.RANDOM, self.state.config.stability_config.split_count, 0.5),
            PathBuilder.build(self.state.result_path / 'stability'),
            LabelConfiguration()
        )

        stab_lange = StabilityLange(
            discovery_datasets, tuning_datasets, self.state.config.clustering_settings,
            self.state.result_path / 'stability', self.number_of_processes,
            self.state.config.sequence_type, self.state.config.region_type
        )

        report_result, stability_path = stab_lange.run()
        self.state.clustering_report_results.append(report_result)
        self.state.metrics_performance_paths['stability_lange'] = stability_path

        print_log(f"{self.__class__.__name__} ({self.state.name}): computed stability.")

    def _setup_paths(self):
        """Initialize result paths: one predictions.csv location per split (directories are created)."""
        self.state.predictions_paths = [
            PathBuilder.build(self.state.result_path / f"validation_indices/split_{run_id + 1}") / 'predictions.csv'
            for run_id in range(self.state.config.sample_config.split_count)
        ]

    def _get_label_names(self) -> List[str]:
        """Return the configured label names, or an empty list when no label configuration was provided."""
        label_config = self.state.config.label_config
        return label_config.get_labels_by_name() if label_config else []

    def _init_predictions_df(self, dataset: Dataset) -> pd.DataFrame:
        """
        Initialize predictions DataFrame with labels and example IDs.

        When labels are configured, the frame starts from the dataset metadata for those labels; otherwise
        (including when no label configuration was given at all) it starts as an empty frame with one row
        per example. In both cases an 'example_id' column is added.
        """
        label_names = self._get_label_names()

        if len(label_names) > 0:
            predictions_df = dataset.get_metadata(label_names, return_df=True)
        else:
            predictions_df = pd.DataFrame(index=range(dataset.get_example_count()))

        predictions_df['example_id'] = dataset.get_example_ids()
        return predictions_df

    def _fix_max_processes(self):
        """
        Configure thread limits for parallelization.

        Does nothing when ``number_of_processes`` is not set. Otherwise limits torch's thread count (if torch
        is installed) and sets the OMP/OpenBLAS/MKL thread-count environment variables.
        """
        if not self.number_of_processes:
            return

        try:
            import torch
            torch.set_num_threads(self.number_of_processes)
        except ImportError:
            # torch is optional; the environment variables below are still applied
            pass

        for env_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
            os.environ[env_var] = str(self.number_of_processes)