import logging
from pathlib import Path
import pandas as pd
import plotly.express as px
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.PathBuilder import PathBuilder
class RepertoireClonotypeSummary(DataReport):
    """
    Shows the number of distinct clonotypes per repertoire in a given dataset as a bar plot.

    A clonotype is counted as a distinct (CDR3 amino acid sequence, V call, J call) combination.

    **Specification arguments:**

    - color_label (str): the label to color the bar plot by (optional, default: None)

    - facet_label (str): the label to facet the bar plot by (optional, default: None)

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_clonotype_summary_rep:
                    RepertoireClonotypeSummary:
                        color_label: celiac
                        facet_label: hla

    """

    def __init__(self, dataset: Dataset = None, result_path: Path = None, name: str = None, number_of_processes: int = 1,
                 color_label: str = None, facet_label: str = None):
        """
        Store the plotting options; everything else is forwarded to the base report.

        Args:
            dataset: dataset the report is computed on.
            result_path: directory the CSV table and the figure are written to.
            name: user-defined name of the report instance.
            number_of_processes: forwarded unchanged to the base class.
            color_label: metadata label used to color the bars (optional).
            facet_label: metadata label used to split the plot into facet rows (optional).
        """
        super().__init__(dataset, result_path, name, number_of_processes)
        self.color_label = color_label
        self.facet_label = facet_label

    @classmethod
    def build_object(cls, **kwargs):
        """Create a report instance from keyword arguments (e.g. parsed specification parameters)."""
        return RepertoireClonotypeSummary(**kwargs)

    def _generate(self) -> ReportResult:
        """Create the result directory and produce the report output."""
        PathBuilder.build(self.result_path)
        # NOTE(review): _safe_plot is not defined in this class; presumably the base class
        # implements it as a guarded wrapper around _plot - confirm in DataReport/Report.
        return self._safe_plot()

    def add_labels(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Append the metadata columns for the configured color/facet labels to ``df``.

        Args:
            df: data frame with one row per repertoire.

        Returns:
            ``df`` itself when neither label is set, otherwise a new data frame with the
            metadata columns concatenated column-wise.
        """
        column_names = [label for label in (self.color_label, self.facet_label) if label]
        if not column_names:
            return df

        metadata = self.dataset.get_metadata(column_names, return_df=True)
        # Column-wise concat aligns rows on the index.
        # NOTE(review): assumes metadata rows share df's index and repertoire order - confirm.
        return pd.concat([df, metadata], axis=1)

    def _plot(self) -> ReportResult:
        """
        Count clonotypes per repertoire, export the table as CSV and draw the bar plot.

        Returns:
            ReportResult referencing the written figure and the CSV table.
        """
        clonotypes = pd.DataFrame({'clonotype_count': [self._get_clonotype_count(repertoire)
                                                       for repertoire in self.dataset.get_data()]})
        clonotypes['repertoire_id'] = self.dataset.get_example_ids()
        clonotypes = self.add_labels(clonotypes)
        clonotypes.sort_values(by='clonotype_count', ascending=False, inplace=True)

        # x position of each bar: rank by descending clonotype count, restarted within each
        # facet so that every facet row starts at 0.
        if self.facet_label:
            clonotypes['repertoire_index'] = clonotypes.groupby(self.facet_label).cumcount()
        else:
            clonotypes['repertoire_index'] = list(range(clonotypes.shape[0]))

        fig = px.bar(clonotypes, x='repertoire_index', y='clonotype_count', facet_row=self.facet_label,
                     color=self.color_label, title='Clonotype count per repertoire',
                     color_discrete_sequence=px.colors.qualitative.Vivid)
        fig.update_layout(template="plotly_white", yaxis_title='clonotype count',
                          xaxis_title='repertoires')

        if self.facet_label:
            self._annotate_facets(fig, clonotypes)

        table_path = self.result_path / 'clonotype_count_per_repertoire.csv'
        clonotypes.to_csv(table_path, index=False)

        # NOTE(review): the third argument is the number of examples in the dataset; how
        # PlotlyUtil uses it is not visible here.
        fig_path = PlotlyUtil.write_image_to_file(fig, self.result_path / 'clonotype_count_per_repertoire.html',
                                                  self.dataset.get_example_count())

        return ReportResult(name=self.name, info="Clonotype count per repertoire",
                            output_figures=[ReportOutput(fig_path, name='Clonotype count per repertoire')],
                            output_tables=[ReportOutput(table_path, name='Clonotype count per repertoire')])

    def _annotate_facets(self, fig, clonotypes: pd.DataFrame) -> None:
        """Add the group size to each facet title and replace per-facet y titles with one shared label."""
        facet_label_counts = {str(value): count
                              for value, count in clonotypes[self.facet_label].value_counts().to_dict().items()}

        # Facet titles have the form "<label>=<value>". Split only on the first '=' so that
        # values which themselves contain '=' are still matched to their counts.
        for annotation in fig.layout.annotations:
            group_label = annotation.text
            if '=' in group_label:
                group = group_label.split('=', 1)[1]
                count = facet_label_counts.get(group, 0)
                annotation.text = f"{group_label}<br>(n={count})"

        # A single vertical label, centered along the whole figure in paper coordinates (0-1)
        # and placed left of the plotting area, stands in for the per-facet y-axis titles.
        fig.add_annotation(
            text="clonotype counts",
            xref="paper", yref="paper",
            x=-0.07, y=0.5,
            showarrow=False,
            textangle=-90,
            font=dict(size=14)
        )
        fig.update_layout(margin=dict(l=80))  # left margin so the shared label has room
        fig.update_yaxes(title='')

    def _get_clonotype_count(self, repertoire: Repertoire) -> int:
        """
        Count the distinct (cdr3_aa, v_call, j_call) combinations in a repertoire.

        Logs a warning when the repertoire has more sequences than distinct combinations.

        Args:
            repertoire: repertoire whose sequences are counted.

        Returns:
            Number of distinct combinations.
        """
        sequences = repertoire.data.topandas()
        sequence_count = sequences.shape[0]

        # drop_duplicates (unlike groupby with its default dropna=True) keeps rows with a
        # missing value in one of the three columns, so such sequences are still counted and
        # the numbers reported in the warning below stay accurate.
        unique_sequence_count = sequences[['cdr3_aa', 'v_call', 'j_call']].drop_duplicates().shape[0]

        if sequence_count != unique_sequence_count:
            logging.warning(f"{RepertoireClonotypeSummary.__name__}: {self.name}: for repertoire {repertoire.identifier}, "
                            f"there are {sequence_count} sequences, but {unique_sequence_count} unique (CDR3 amino acid"
                            f" sequence, V call, J call) combinations.")

        return unique_sequence_count

    def check_prerequisites(self) -> bool:
        """Return True only for a RepertoireDataset; otherwise log a warning and return False."""
        if isinstance(self.dataset, RepertoireDataset):
            return True

        logging.warning(f"{RepertoireClonotypeSummary.__name__}: report can be generated only from "
                        f"RepertoireDataset. Skipping this report...")
        return False