Source code for immuneML.reports.data_reports.CompAIRRClusteringReport

import copy
import logging
import os
import subprocess
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

from immuneML.data_model.SequenceSet import Repertoire
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder


class CompAIRRClusteringReport(DataReport):
    """
    A report that uses CompAIRR to compute repertoire distances based on sequence overlap and performs
    hierarchical clustering on the resulting distance matrix. The clustering results are visualized using
    a dendrogram, colored by a specified label.

    The report assumes that CompAIRR (https://github.com/uio-bmi/compairr) has been installed.

    .. note::

        This is an experimental feature.

    **Specification arguments**:

    - labels (list): The list of labels to highlight below the dendrogram. The labels should be present in
      the dataset.

    - compairr_path (str): Path to the CompAIRR executable.

    - indels (bool): Whether to allow insertions/deletions when matching sequences (default: False)

    - ignore_counts (bool): Whether to ignore sequence counts when computing overlap (default: False)

    - ignore_genes (bool): Whether to ignore V/J gene assignments when matching sequences (default: False)

    - threads (int): Number of threads to use for CompAIRR computation (default: 4)

    - linkage_method (str): The linkage method to use for hierarchical clustering (default: 'single'),
      passed as 'method' to scipy.cluster.hierarchy.linkage

    - is_cdr3 (bool): Whether the sequences represent CDR3s (default: True)

    - clustering_criterion (str): The criterion to use for clustering (default: 'distance'), as defined in
      scipy.cluster.hierarchy.fcluster; valid values are 'distance', 'maxclust', 'monocrit',
      'maxclust_monocrit'

    - clustering_threshold (float): The threshold for the clustering algorithm (default: 0.5), mapped to
      't' parameter in scipy.cluster.hierarchy.fcluster

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_compairr_clustering_report:
                    CompAIRRClusteringReport:
                        labels: [disease]
                        compairr_path: /path/to/compairr
                        indels: false
                        ignore_counts: true
                        ignore_genes: true
                        threads: 4
                        linkage_method: single
                        is_cdr3: true
                        clustering_criterion: distance
                        clustering_threshold: 0.5

    """

    @classmethod
    def build_object(cls, **kwargs):
        """
        Validate the specification parameters and create the report.

        Raises:
            AssertionError: if `clustering_threshold` is not a number (other failures are raised by
                ParameterValidator).
        """
        location = CompAIRRClusteringReport.__name__
        valid_keys = ['labels', 'compairr_path', 'indels', 'ignore_counts', 'ignore_genes', 'threads',
                      'linkage_method', 'is_cdr3', 'clustering_threshold', 'clustering_criterion', 'name']

        ParameterValidator.assert_keys(kwargs, valid_keys, location, location, True)
        ParameterValidator.assert_type_and_value(kwargs['labels'], list, location, 'labels')
        # linkage_method is forwarded to scipy's linkage(), so reject unknown methods up front
        ParameterValidator.assert_in_valid_list(kwargs['linkage_method'],
                                                ['single', 'complete', 'average', 'weighted', 'centroid',
                                                 'median', 'ward'],
                                                location, 'linkage_method')
        # NOTE(review): scipy's fcluster needs an extra `monocrit` vector for 'monocrit' and
        # 'maxclust_monocrit'; this report never supplies one, so those two probably fail at run time - confirm.
        ParameterValidator.assert_in_valid_list(kwargs['clustering_criterion'],
                                                ['distance', 'maxclust', 'monocrit', 'maxclust_monocrit'],
                                                location, 'clustering_criterion')

        threshold = kwargs['clustering_threshold']
        # bool is a subclass of int, so exclude it explicitly
        assert isinstance(threshold, (int, float)) and not isinstance(threshold, bool), \
            'clustering_threshold must be a number'

        return CompAIRRClusteringReport(**kwargs)

    def __init__(self, dataset: RepertoireDataset = None, result_path: Path = None, labels: List[str] = None,
                 compairr_path: str = None, indels: bool = False, ignore_counts: bool = False,
                 ignore_genes: bool = False, threads: int = 4, linkage_method: str = 'single',
                 is_cdr3: bool = True, name: str = None, clustering_threshold: float = 0.5,
                 clustering_criterion: str = 'distance'):
        super().__init__(dataset=dataset, result_path=result_path, name=name)
        # a missing label list is treated as "no labels" so that the rest of the report can iterate over it
        self.labels = labels if labels is not None else []
        self.linkage_method = linkage_method
        self.compairr_path = compairr_path
        self.indels = indels
        self.ignore_counts = ignore_counts
        self.ignore_genes = ignore_genes
        self.threads = threads
        self.is_cdr3 = is_cdr3
        self.clustering_threshold = clustering_threshold
        self.clustering_criterion = clustering_criterion

    def check_prerequisites(self) -> bool:
        """Return True only if a CompAIRR path was provided and the dataset is a RepertoireDataset."""
        if not self.compairr_path:
            logging.warning("CompAIRR path not provided. CompAIRR must be installed and available in the system PATH.")
            return False

        if not isinstance(self.dataset, RepertoireDataset):
            logging.warning("CompAIRRClusteringReport requires a RepertoireDataset as input.")
            return False

        return True

    def _generate(self) -> ReportResult:
        """
        Run the full pipeline: pairwise CompAIRR similarities -> distances -> hierarchical clustering ->
        tables (distance, similarity, cluster assignments) and figures (dendrogram, similarity heatmap).
        """
        PathBuilder.build(self.result_path)
        self.check_label()

        similarity_matrix = self.compare_repertoires()
        distance_matrix = self.compute_distance_matrix(similarity_matrix)

        # linkage() interprets a 2D array as observation vectors, not as pairwise distances, so the square
        # matrix has to be converted to condensed form first; checks are skipped because the matrix is
        # symmetric with a zero diagonal by construction (see compare_repertoires)
        condensed_distances = squareform(distance_matrix.values, checks=False)
        linkage_matrix = linkage(condensed_distances, method=self.linkage_method)
        predictions = fcluster(linkage_matrix, self.clustering_threshold, criterion=self.clustering_criterion)

        metadata = self.dataset.get_metadata(self.labels + ['subject_id'], return_df=True)

        # Save outputs
        distance_output = self._store_matrix(distance_matrix, 'distance_matrix')
        similarity_output = self._store_matrix(similarity_matrix, 'similarity_matrix')
        similarity_heatmap = self._make_similarity_heatmap(similarity_matrix, 'similarity')
        dendrogram_output = self._create_dendrogram(distance_matrix, metadata, linkage_matrix)
        cluster_output = self._store_cluster_assignments(self.dataset.get_example_ids(), predictions)

        return ReportResult(
            name=self.name,
            info="Hierarchical clustering of repertoires based on CompAIRR sequence overlap",
            output_figures=[dendrogram_output, similarity_heatmap],
            output_tables=[distance_output, similarity_output, cluster_output]
        )

    def compute_distance_matrix(self, similarity_matrix) -> pd.DataFrame:
        """Convert a similarity matrix to a distance matrix (1 - similarity), keeping index and columns."""
        return pd.DataFrame(data=1 - similarity_matrix.values,
                            columns=similarity_matrix.columns,
                            index=similarity_matrix.index)

    def compare_repertoires(self) -> pd.DataFrame:
        """
        Compute the pairwise similarity of all repertoires in the dataset by calling CompAIRR once per pair.

        Returns:
            a symmetric DataFrame indexed by repertoire identifier in both dimensions, with 1 on the diagonal

        Temporary repertoire files written to `result_path` are removed even if CompAIRR fails.
        """
        repertoires = self.dataset.repertoires
        n_repertoires = len(repertoires)
        similarity_matrix = self._init_similarity_matrix(n_repertoires)

        # Compare repertoires pairwise (upper triangle only; _run_compairr fills both symmetric cells)
        for i in range(n_repertoires):
            rep1 = repertoires[i]
            tmp_rep_file1 = self.result_path / f"{rep1.identifier}_tmp.tsv"

            try:
                self._make_clean_rep_files(rep1, tmp_rep_file1)
                similarity_matrix.loc[rep1.identifier, rep1.identifier] = 1

                for j in range(i + 1, n_repertoires):
                    rep2 = repertoires[j]
                    tmp_rep_file2 = self.result_path / f"{rep2.identifier}_tmp.tsv"

                    try:
                        self._make_clean_rep_files(rep2, tmp_rep_file2)
                        similarity_matrix = self._run_compairr(rep1, rep2, tmp_rep_file1, tmp_rep_file2,
                                                               similarity_matrix)
                    finally:
                        self._remove_file(tmp_rep_file2)
            finally:
                self._remove_file(tmp_rep_file1)

            logging.info(f"CompAIRRClusteringReport: Finished processing repertoire "
                         f"{rep1.identifier} ({i + 1}/{n_repertoires})")

        return similarity_matrix

    def check_label(self):
        """
        Check that every requested label exists in the dataset.

        Raises:
            ValueError: listing the labels that are missing and the labels that are available.
        """
        available_labels = self.dataset.get_label_names()
        missing_labels = [label for label in self.labels if label not in available_labels]

        if missing_labels:
            raise ValueError(
                f"Labels '{missing_labels}' not found in dataset metadata. Available labels: {available_labels}")

    @staticmethod
    def _remove_file(path: Path):
        """Delete `path` if it exists; used for cleanup, so a missing file is not an error."""
        if path.is_file():
            os.remove(str(path))

    def _init_similarity_matrix(self, n_repertoires: int) -> pd.DataFrame:
        """Create an all-zero (n_repertoires x n_repertoires) matrix labelled by repertoire identifiers."""
        identifiers = [rep.identifier for rep in self.dataset.repertoires]
        return pd.DataFrame(np.zeros((n_repertoires, n_repertoires)), index=identifiers, columns=identifiers)

    def _make_clean_rep_files(self, rep: Repertoire, tmp_file_path: Path):
        """
        Write a reduced copy of the repertoire (cdr3_aa, v_call, j_call, duplicate_count) to `tmp_file_path`
        as TSV, replacing duplicate counts that are -1 or missing by 1.
        """
        df = pd.read_csv(rep.data_filename, sep="\t")[['cdr3_aa', 'v_call', 'j_call', 'duplicate_count']]
        counts = df['duplicate_count']

        if (counts == -1).any() or counts.isna().any():
            logging.warning(f"CompAIRRClusteringReport: Duplicate count -1 or NA found in repertoire {rep.identifier}. "
                            "Setting to 1, but it should be rechecked.")

        df['duplicate_count'] = counts.replace(-1, 1).fillna(1)
        df.to_csv(tmp_file_path, sep="\t", index=False)

    def _run_compairr(self, rep1: Repertoire, rep2: Repertoire, tmp_rep_file1: Path, tmp_rep_file2: Path,
                      similarity_matrix: pd.DataFrame):
        """
        Run CompAIRR on one pair of repertoire files and store the resulting Jaccard similarity in both
        symmetric cells of `similarity_matrix`.

        Returns:
            the (in-place updated) similarity matrix

        Raises:
            RuntimeError: if CompAIRR exits with a non-zero return code.
        """
        output_file = self.result_path / f"overlap_{rep1.identifier}_{rep2.identifier}.tsv"

        cmd_args = [str(self.compairr_path), "--matrix", str(tmp_rep_file1), str(tmp_rep_file2),
                    "-o", str(output_file)]

        optional_flags = [(self.indels, "-i"), (self.ignore_counts, "--ignore-counts"),
                          (self.ignore_genes, "--ignore-genes"), (self.is_cdr3, "--cdr3")]
        cmd_args.extend(flag for enabled, flag in optional_flags if enabled)

        cmd_args.extend(["-t", str(self.threads)])
        cmd_args.extend(["-s", 'Jaccard'])

        # the command is passed as a list (no shell), so identifiers and paths are not shell-interpreted
        result = subprocess.run(cmd_args, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(
                f"CompAIRRClusteringReport: Error running CompAIRR for repertoires {rep1.identifier} and "
                f"{rep2.identifier}:\n{result.stderr}")

        if output_file.is_file():
            # the similarity is read from the first data row, second column of the CompAIRR output
            similarity = pd.read_csv(output_file, sep="\t").iloc[0, 1]
            similarity_matrix.loc[rep1.identifier, rep2.identifier] = similarity
            similarity_matrix.loc[rep2.identifier, rep1.identifier] = similarity
            os.remove(str(output_file))
        else:
            # keep the best-effort behaviour (similarity stays 0), but make it visible
            logging.warning(f"CompAIRRClusteringReport: CompAIRR produced no output file for repertoires "
                            f"{rep1.identifier} and {rep2.identifier}; similarity is left at 0.")

        return similarity_matrix

    def _store_matrix(self, matrix: pd.DataFrame, name: str) -> ReportOutput:
        """Write `matrix` to `<result_path>/<name>.tsv` and wrap the path in a ReportOutput."""
        output_path = self.result_path / f"{name}.tsv"
        matrix.to_csv(output_path, sep="\t")
        return ReportOutput(path=output_path, name=name)

    def _make_similarity_heatmap(self, matrix: pd.DataFrame, name: str) -> ReportOutput:
        """
        Plot the strictly lower triangle of `matrix` as an annotated heatmap (the diagonal and upper
        triangle are blanked out), labelled by subject id when available and by position otherwise.
        """
        metadata = self.dataset.get_metadata(field_names=None, return_df=True)
        if 'subject_id' in metadata.columns:
            subject_ids = metadata['subject_id'].tolist()
        else:
            subject_ids = [str(el) for el in range(matrix.shape[0])]

        # hide the diagonal and everything above it
        mask = np.zeros_like(matrix, dtype=bool)
        mask[np.triu_indices_from(mask)] = True

        matrix_values = copy.deepcopy(matrix.values)
        matrix_values[mask] = np.nan

        matrix_text = matrix_values.round(3).astype(str)
        matrix_text = np.where(np.isnan(matrix_values), '', matrix_text)

        # the last column and the first row are fully masked, so their tick labels are blanked
        fig = go.Figure(data=go.Heatmap(z=matrix_values, x=subject_ids[:-1] + [""], y=[""] + subject_ids[1:],
                                        colorscale='Darkmint', hoverongaps=False, text=matrix_text,
                                        texttemplate="%{text}"))
        fig.update_layout(template="plotly_white", title=f"{name} matrix", xaxis_title="repertoire",
                          yaxis_title='repertoire')
        fig.update_xaxes(showgrid=False)
        fig.update_yaxes(showgrid=False, autorange='reversed')

        file_path = PlotlyUtil.write_image_to_file(fig, self.result_path / f"{name}_heatmap.html",
                                                   matrix_text.shape[0])

        return ReportOutput(path=file_path, name=f"{name} heatmap")

    def _store_cluster_assignments(self, repertoire_ids, clusters) -> ReportOutput:
        """
        Write the cluster id of every repertoire (plus the subject id, if the metadata has one) to
        `<result_path>/cluster_assignments.tsv`.
        """
        output_path = self.result_path / "cluster_assignments.tsv"

        cluster_df_data = {'repertoire_id': repertoire_ids, 'cluster': clusters}

        metadata = self.dataset.get_metadata(field_names=None, return_df=True)
        if 'subject_id' in metadata.columns:
            cluster_df_data['subject_id'] = metadata['subject_id'].tolist()

        pd.DataFrame(cluster_df_data).to_csv(output_path, sep="\t", index=False)

        return ReportOutput(path=output_path, name="Cluster Assignments")

    def _create_dendrogram(self, distance_matrix, metadata: pd.DataFrame, linkage_matrix) -> ReportOutput:
        """
        Build the dendrogram figure: the tree on top, one colored annotation row per label below it, and
        the distance heatmap (reordered to the dendrogram leaf order) at the bottom.

        Args:
            distance_matrix: square DataFrame of pairwise repertoire distances
            metadata: DataFrame with one row per repertoire containing the label columns and, optionally,
                'subject_id'; it is not modified
            linkage_matrix: precomputed linkage matrix that defines the tree

        Returns:
            ReportOutput pointing to the written figure
        """
        output_path = self.result_path / "dendrogram.html"

        if 'subject_id' in metadata.columns:
            repertoire_ids = metadata['subject_id'].tolist()
        else:
            repertoire_ids = self.dataset.get_example_ids()

        # work on a copy indexed by the leaf ids so that the caller's DataFrame stays untouched
        metadata = metadata.assign(subject_id=repertoire_ids).set_index('subject_id')

        # create_dendrogram insists on a data array; it is only a placeholder here because both the
        # distance and the linkage function ignore their input and return the precomputed results
        n = distance_matrix.shape[0]
        dummy_data = np.zeros((n, 2))

        fig = ff.create_dendrogram(dummy_data, orientation='bottom', labels=repertoire_ids,
                                   linkagefun=lambda x: linkage_matrix,
                                   distfun=lambda x: distance_matrix.values)

        # Move the tree to its own axis and remove trace numbers from dendrogram hover
        for trace in fig['data']:
            trace['yaxis'] = 'y2'
            trace['hovertemplate'] = None
            trace['showlegend'] = False

        # Get the ordering of leaves from the dendrogram and convert it to indices
        label_to_idx = {rep_id: idx for idx, rep_id in enumerate(repertoire_ids)}
        dendro_leaves = [label_to_idx[rep_id] for rep_id in fig['layout']['xaxis']['ticktext']]
        ordered_ids = [repertoire_ids[idx] for idx in dendro_leaves]

        self._add_label_annotations(fig, metadata, ordered_ids)

        # Reorder the distance matrix according to dendrogram leaves
        heat_data = distance_matrix.values[dendro_leaves, :]
        heat_data = heat_data[:, dendro_leaves]

        fig.add_trace(go.Heatmap(
            x=fig['layout']['xaxis']['tickvals'],
            y=fig['layout']['xaxis']['tickvals'],
            z=heat_data,
            colorscale=px.colors.sequential.Blues,
            showlegend=False,
            showscale=False,
            hovertemplate='X: %{x}<br>Y: %{y}<br>Distance: %{z}<extra></extra>'
        ))

        hidden_axis = {'mirror': False, 'showgrid': False, 'showline': False, 'zeroline': False, 'ticks': ""}

        fig.update_layout({
            'height': 700 + len(self.labels) * 20,  # Adjust height based on number of labels
            'autosize': True,
            'showlegend': False,
            'hovermode': 'closest',
            'template': 'plotly_white',
            'xaxis': {'domain': [0, 1], **hidden_axis},
            # main heatmap
            'yaxis': {'domain': [0, 0.7], 'showticklabels': False, **hidden_axis},
            # top dendrogram
            'yaxis2': {'domain': [0.85, 1], 'showticklabels': False, **hidden_axis}
        })

        output_path = PlotlyUtil.write_image_to_file(fig, output_path, metadata.shape[0])

        return ReportOutput(path=output_path, name="hierarchical_clustering_of_repertoires")

    def _add_label_annotations(self, fig, metadata: pd.DataFrame, ordered_ids: list):
        """
        Add one single-row heatmap per label between the distance heatmap and the dendrogram; each row
        colors the repertoires (in dendrogram leaf order) by their value of that label.

        Does nothing when no labels were requested.
        """
        if not self.labels:
            # nothing to annotate; this also avoids dividing by zero below
            return

        annotation_start = 0.72
        annotation_end = 0.82
        annotation_gap = 0
        n_labels = len(self.labels)
        annotation_height = (annotation_end - annotation_start - (n_labels - 1) * annotation_gap) / n_labels

        for i, label in enumerate(self.labels):
            # y and y2 are taken by the distance heatmap and the dendrogram, so label axes start at y3
            axis_number = i + 3
            domain_start = annotation_start + i * (annotation_height + annotation_gap)
            y_domain = [domain_start, domain_start + annotation_height]

            annotation_values = [metadata.loc[rep_id, label] for rep_id in ordered_ids]

            # map every distinct label value to an integer so that it can be used as heatmap z-value
            unique_values = sorted(metadata[label].unique().tolist())
            value_to_num = {value: j for j, value in enumerate(unique_values)}
            annotation_numeric = [value_to_num[value] for value in annotation_values]

            fig.add_trace(go.Heatmap(
                x=fig['layout']['xaxis']['tickvals'],
                y=[label],
                z=np.array(annotation_numeric).reshape(1, -1),
                showlegend=False,
                showscale=False,
                colorscale=px.colors.qualitative.Plotly[1:],
                hovertemplate=label + ': %{text}<br>Repertoire: %{x}<extra></extra>',
                text=[annotation_values],
                yaxis=f'y{axis_number}'
            ))

            fig.update_layout(**{
                f'yaxis{axis_number}': dict(domain=y_domain, mirror=False, showgrid=False, showline=False,
                                            zeroline=False)
            })