Source code for immuneML.reports.clustering_reports.ExternalLabelMetricHeatmap

from pathlib import Path
from typing import List

import pandas as pd
import plotly.graph_objects as go

from immuneML.ml_metrics import ClusteringMetric
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.clustering_reports.ClusteringReport import ClusteringReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.instructions.clustering.ClusteringState import ClusteringState, ClusteringResultPerRun


[docs] class ExternalLabelMetricHeatmap(ClusteringReport): """ This report creates heatmaps comparing clustering methods against external labels for each metric. For each external label and metric combination, it creates: 1. A table showing the metric values for each combination of clustering method and external label 2. A heatmap visualization of these values The external labels and metrics are automatically determined from the clustering instruction specification. **YAML specification:** .. indent with spaces .. code-block:: yaml reports: my_external_label_metric_heatmap: ExternalLabelMetricHeatmap """
[docs] @classmethod def build_object(cls, **kwargs): ParameterValidator.assert_keys(list(kwargs.keys()), ['name'], ExternalLabelMetricHeatmap.__name__, ExternalLabelMetricHeatmap.__name__) return ExternalLabelMetricHeatmap(**kwargs)
def __init__(self, name: str = None, state: ClusteringState = None, result_path: Path = None, number_of_processes: int = 1): super().__init__(name, result_path, number_of_processes, state) self.desc = "External Label - Clustering Heatmap" def _generate(self) -> ReportResult: self.result_path = PathBuilder.build(self.result_path / self.name) report_outputs = [] external_labels = self.state.config.label_config.get_labels_by_name() if not external_labels or len(external_labels) == 0: return ReportResult( name=f"{self.desc} ({self.name})", info="No external labels were found in the clustering state's label configuration." ) metrics = [metric for metric in self.state.config.metrics if ClusteringMetric.is_external(metric)] for metric in metrics: # For each split in the clustering results for split_idx, clustering_results in enumerate(self.state.clustering_items): # Process discovery results if clustering_results.discovery: report_outputs.extend(self._process_analysis_results( clustering_results.discovery, f"discovery_split_{split_idx + 1}", metric )) # Process method-based validation results if available if clustering_results.method_based_validation: report_outputs.extend(self._process_analysis_results( clustering_results.method_based_validation, f"method_based_validation_split_{split_idx + 1}", metric )) # Process result-based validation results if available if clustering_results.result_based_validation: report_outputs.extend(self._process_analysis_results( clustering_results.result_based_validation, f"result_based_validation_split_{split_idx + 1}", metric )) if not report_outputs: return ReportResult( name=f"{self.desc} ({self.name})", info="No results were generated. This could be because no metrics were computed." ) return ReportResult( name=f"{self.desc} ({self.name})", info="Heatmaps of metric values for clustering methods versus external labels", output_tables=[output for output in report_outputs if output.path.suffix == '.csv'], output_figures=[output for output in report_outputs if output.path.suffix in ['.html', '.png']], ) def _process_analysis_results(self, analysis_results: ClusteringResultPerRun, analysis_name: str, metric: ClusteringMetric) -> List[ReportOutput]: outputs = [] external_labels = self.state.config.label_config.get_labels_by_name() # Create a DataFrame with metric values df = pd.DataFrame(0, index=list(analysis_results.items.keys()), columns=external_labels) for setting_key, item_result in analysis_results.items.items(): performance_df = item_result.item.external_performance.get_df() for label in external_labels: df.loc[setting_key, label] = performance_df[performance_df['metric'] == metric][label].values[0] # Create heatmap fig = go.Figure(data=go.Heatmap( z=df.values, x=df.columns, # label values y=df.index, # clustering methodsa colorscale='Darkmint', text=df.values, texttemplate='%{text:.3f}', hovertemplate=metric + ': %{z:.3f}<br>external label: %{x}<br>clustering setting: %{y}<extra></extra>', textfont={"size": 15}, hoverongaps=False )) fig.update_layout( template='plotly_white' ) # Save heatmap heatmap_path = self.result_path / f"{analysis_name}_{metric}_heatmap.html" plot_path = PlotlyUtil.write_image_to_file(fig, heatmap_path, df.shape[0]) outputs.append(ReportOutput( path=plot_path, name=f"Heatmap for {metric.replace('_', ' ')} ({analysis_name.replace('_', ' ')})" )) # Save metric table table_path = self.result_path / f"{analysis_name}_{metric}.csv" df.reset_index().rename(columns={'index': 'clustering_setting'}).to_csv(table_path, index=False) outputs.append(ReportOutput( path=table_path, name=f"Metric values for {metric} ({analysis_name.replace('_', ' ')})" )) return outputs
[docs] def check_prerequisites(self): if not self.state: return False if not self.state.clustering_items: return False if not self.state.config.label_config: return False return True