Source code for immuneML.reports.clustering_reports.ExternalLabelMetricHeatmap

from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from immuneML.ml_metrics import ClusteringMetric
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.clustering_reports.ClusteringReport import ClusteringReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.instructions.clustering.ClusteringState import ClusteringState


class ExternalLabelMetricHeatmap(ClusteringReport):
    """
    This report creates heatmaps comparing clustering methods against external labels for each metric.
    For each external label and metric combination, it creates:

    1. A table showing the mean and standard deviation of metric values across splits for each combination
       of clustering method and external label

    2. A heatmap visualization where the color represents the mean value and the text shows mean±std

    The external labels and metrics are automatically determined from the clustering instruction specification.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        reports:
            my_external_label_metric_heatmap:
                ExternalLabelMetricHeatmap

    """

    @classmethod
    def build_object(cls, **kwargs):
        """Check the supplied keyword arguments against the allowed key list (``name``) and build the report."""
        ParameterValidator.assert_keys(list(kwargs.keys()), ['name'], ExternalLabelMetricHeatmap.__name__,
                                       ExternalLabelMetricHeatmap.__name__)
        return ExternalLabelMetricHeatmap(**kwargs)

    def __init__(self, name: str = None, state: ClusteringState = None, result_path: Path = None,
                 number_of_processes: int = 1):
        super().__init__(name, result_path, number_of_processes, state)
        self.desc = "External Label - Clustering Heatmap"

    def _generate(self) -> ReportResult:
        """
        Produce one heatmap and one CSV table for every metric in the clustering configuration that
        ``ClusteringMetric.is_external`` accepts.

        Returns:
            A ReportResult holding the generated tables and figures, or a ReportResult that only carries
            an explanatory ``info`` message when there are no external labels or no outputs.
        """
        self.result_path = PathBuilder.build(self.result_path / self.name)
        report_name = f"{self.desc} ({self.name})"

        external_labels = self.state.config.label_config.get_labels_by_name()
        if not external_labels:
            return ReportResult(
                name=report_name,
                info="No external labels were found in the clustering state's label configuration."
            )

        report_outputs = []
        for metric in self.state.config.metrics:
            if ClusteringMetric.is_external(metric):
                report_outputs.extend(self._process_metric_across_splits(metric, external_labels))

        if not report_outputs:
            return ReportResult(
                name=report_name,
                info="No results were generated. This could be because no metrics were computed."
            )

        return ReportResult(
            name=report_name,
            info="Heatmaps of metric values (mean±std across splits) for clustering methods versus external labels",
            output_tables=[output for output in report_outputs if output.path.suffix == '.csv'],
            output_figures=[output for output in report_outputs if output.path.suffix in ['.html', '.png']],
        )

    def _process_metric_across_splits(self, metric: str, external_labels: List[str]) -> List[ReportOutput]:
        """
        Aggregate one metric over all splits and export it as a heatmap and a CSV table.

        Args:
            metric: name of the metric; matched against the ``metric`` column of each performance table.
            external_labels: label names; each is used as a column name of the performance table.

        Returns:
            The heatmap output followed by the table output, or an empty list when there are no
            clustering items or the first split defines no clustering settings.
        """
        if not self.state.clustering_items:
            return []

        # The clustering settings of the first split define the heatmap rows for every split.
        setting_keys = list(self.state.clustering_items[0].items.keys())
        if not setting_keys:
            # Nothing to plot; continuing would fail in max() over an empty sequence.
            return []

        all_values = self._collect_metric_values(metric, setting_keys, external_labels)

        # Aggregate across splits (axis 0); the results have shape (n_settings, n_labels).
        mean_values = np.mean(all_values, axis=0)
        std_values = np.std(all_values, axis=0)
        text_annotations = self._format_mean_std(mean_values, std_values)

        fig = self._build_heatmap(metric, mean_values, text_annotations, setting_keys, external_labels)

        heatmap_path = self.result_path / f"{metric}_heatmap.html"
        plot_path = PlotlyUtil.write_image_to_file(fig, heatmap_path, len(setting_keys))
        heatmap_output = ReportOutput(
            path=plot_path,
            name=f"Heatmap for {metric.replace('_', ' ')} (mean±std across splits)"
        )

        table_output = self._write_table(metric, text_annotations, setting_keys, external_labels)

        return [heatmap_output, table_output]

    def _collect_metric_values(self, metric: str, setting_keys: List[str],
                               external_labels: List[str]) -> np.ndarray:
        """Gather the metric values of every split into a float array of shape (n_splits, n_settings, n_labels)."""
        all_values = []
        for clustering_results in self.state.clustering_items:
            split_values = []
            for setting_key in setting_keys:
                item_result = clustering_results.items[setting_key]
                performance_df = item_result.item.external_performance.get_df()
                # Select the rows of this metric once instead of once per label.
                metric_rows = performance_df[performance_df['metric'] == metric]
                split_values.append([self._first_value_or_nan(metric_rows, label) for label in external_labels])
            all_values.append(split_values)

        # dtype=float turns missing entries (None) into NaN so that formatting with ':.3f' cannot fail.
        return np.array(all_values, dtype=float)

    @staticmethod
    def _first_value_or_nan(metric_rows: pd.DataFrame, label: str):
        """Return the first value in column ``label``, or NaN when ``metric_rows`` has no rows."""
        values = metric_rows[label].values
        return values[0] if len(values) > 0 else np.nan

    @staticmethod
    def _format_mean_std(mean_values: np.ndarray, std_values: np.ndarray) -> np.ndarray:
        """Build an object array of 'mean±std' strings (three decimals) with the shape of ``mean_values``."""
        text_annotations = np.empty(mean_values.shape, dtype=object)
        for position in np.ndindex(mean_values.shape):
            text_annotations[position] = f"{mean_values[position]:.3f}±{std_values[position]:.3f}"
        return text_annotations

    @staticmethod
    def _build_heatmap(metric: str, mean_values: np.ndarray, text_annotations: np.ndarray,
                       setting_keys: List[str], external_labels: List[str]):
        """
        Create the plotly heatmap figure: cell color encodes the mean, cell text shows mean±std.

        ``setting_keys`` must be non-empty, since the left margin is derived from the longest wrapped key.
        """
        # Wrap long clustering setting names so they fit next to the y-axis.
        wrapped_setting_keys = [_wrap_label(key) for key in setting_keys]

        n_settings = len(setting_keys)
        n_labels = len(external_labels)

        # Figure size grows with the data but is clamped to a fixed pixel range.
        row_height = 40  # pixels per row
        col_width = 100  # pixels per column
        min_height, max_height = 400, 1200
        min_width, max_width = 600, 1600

        # Every additional wrapped line widens the left margin by 30 pixels.
        max_y_label_lines = max(key.count('<br>') + 1 for key in wrapped_setting_keys)
        left_margin = 50 + (max_y_label_lines - 1) * 30

        fig_height = max(min_height, min(max_height, n_settings * row_height + 150))
        fig_width = max(min_width, min(max_width, n_labels * col_width + left_margin + 100))

        # z[i][j] is drawn at (y[i], x[j]); mean_values has shape (n_settings, n_labels), which matches.
        fig = go.Figure(data=go.Heatmap(
            z=mean_values,
            x=external_labels,
            y=wrapped_setting_keys,
            colorscale='Blues',
            text=text_annotations,
            texttemplate='%{text}',
            hovertemplate=(metric + ' (mean): %{z:.3f}<br>external label: %{x}<br>'
                           'clustering setting: %{y}<extra></extra>'),
            hoverongaps=False
        ))

        # Use a smaller font for the cell text when there are many cells.
        total_cells = n_settings * n_labels
        if total_cells > 50:
            text_font_size = 9
        elif total_cells > 30:
            text_font_size = 10
        else:
            text_font_size = 11

        fig.update_traces(textfont=dict(size=text_font_size))

        fig.update_layout(
            template='plotly_white',
            width=fig_width,
            height=fig_height,
            margin=dict(l=left_margin, r=50, t=50, b=80),
            yaxis=dict(
                tickfont=dict(size=10),
                automargin=True
            ),
            xaxis=dict(
                tickfont=dict(size=10),
                tickangle=-45 if n_labels > 5 else 0,
                automargin=True
            )
        )
        return fig

    def _write_table(self, metric: str, text_annotations: np.ndarray, setting_keys: List[str],
                     external_labels: List[str]) -> ReportOutput:
        """Write the mean±std strings to ``<metric>_mean_std.csv`` using the unwrapped setting names."""
        df_combined = pd.DataFrame(text_annotations, index=setting_keys, columns=external_labels)
        # Name the index before resetting it so that a label column called 'index' is never renamed by mistake.
        df_combined = df_combined.rename_axis('clustering_setting').reset_index()

        table_path = self.result_path / f"{metric}_mean_std.csv"
        df_combined.to_csv(table_path, index=False)

        return ReportOutput(
            path=table_path,
            name=f"Metric values (mean±std) for {metric}"
        )

    def check_prerequisites(self):
        """Return True only if the state, its clustering items and its label configuration are all set."""
        if not self.state:
            return False

        if not self.state.clustering_items:
            return False

        if not self.state.config.label_config:
            return False

        return True
def _wrap_label(label: str, max_chars_per_line: int = 15) -> str: """ Wrap a label by splitting on underscores to create multi-line text. Attempts to keep each line under max_chars_per_line characters. """ if len(label) <= max_chars_per_line: return label parts = label.split('_') lines = [] current_line = "" for part in parts: if not current_line: current_line = part elif len(current_line) + 1 + len(part) <= max_chars_per_line: current_line += "_" + part else: lines.append(current_line) current_line = part if current_line: lines.append(current_line) return "<br>".join(lines)