import logging
from pathlib import Path
from typing import Tuple
import pandas as pd
import plotly.express as px
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.diversity_encoding.ShannonDiversityEncoder import ShannonDiversityEncoder
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.PathBuilder import PathBuilder
class ShannonDiversityOverview(DataReport):
    """
    Computes Shannon diversity for each repertoire using the Shannon diversity encoder and plots the
    results as a bar chart (one bar per repertoire, ordered by diversity), optionally colored and faceted by labels.

    **Dataset type:**

    - Repertoire Dataset

    **Specification arguments:**

    - color_label (str): The label used to color the bars. Default is None.

    - facet_row_label (str): The label used to facet the plot into multiple rows.
      Default is None, meaning no row faceting.

    - facet_col_label (str): The label used to facet the plot into multiple columns.
      Default is None, meaning no column faceting.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
          reports:
            shannon_div_rep:
              ShannonDiversityOverview:
                color_label: disease_status

    """

    def __init__(self, dataset: RepertoireDataset = None, result_path: Path = None, name: str = None,
                 number_of_processes: int = 1, color_label: str = None, facet_row_label: str = None,
                 facet_col_label: str = None):
        super().__init__(dataset, result_path, name, number_of_processes)
        self.color_label = color_label
        self.facet_row_label = facet_row_label
        self.facet_col_label = facet_col_label

    @classmethod
    def build_object(cls, **kwargs):
        """Construct the report from its specification keyword arguments."""
        return cls(**kwargs)

    def check_prerequisites(self) -> bool:
        """Return True if the dataset is a RepertoireDataset; otherwise log a warning and return False."""
        valid = isinstance(self.dataset, RepertoireDataset)
        if not valid:
            logging.warning(f"ShannonDiversityOverview: Dataset must be of type RepertoireDataset, "
                            f"but got {type(self.dataset)}.")
        return valid

    def _generate(self) -> ReportResult:
        """Encode the dataset with Shannon diversity, export the values as a table and plot them."""
        # The encoder is handed self.result_path, so make sure the directory exists before encoding.
        PathBuilder.build(self.result_path)

        encoder = ShannonDiversityEncoder.build_object(self.dataset)
        encoded_dataset = encoder.encode(self.dataset, EncoderParams(self.result_path, encode_labels=False))

        df, table_output = self.prepare_data(encoded_dataset)
        figure_output = self._safe_plot(encoded_df=df)

        # NOTE(review): _safe_plot presumably returns None when plotting fails; never report a None figure.
        return ReportResult(name=self.name, info="Shannon diversity per repertoire",
                            output_figures=[figure_output] if figure_output is not None else [],
                            output_tables=[table_output])

    def prepare_data(self, encoded_dataset) -> Tuple[pd.DataFrame, ReportOutput]:
        """
        Build a per-repertoire table of Shannon diversity plus the requested metadata labels and write it to CSV.

        Args:
            encoded_dataset: the dataset encoded by ShannonDiversityEncoder; its encoded examples are flattened
                into one diversity value per repertoire.

        Returns:
            The table sorted by decreasing Shannon diversity, and the ReportOutput pointing to the written CSV.
        """
        labels = ['subject_id'] if 'subject_id' in self.dataset.labels.keys() else []
        labels.extend(label for label in (self.color_label, self.facet_row_label, self.facet_col_label)
                      if label is not None)
        # One label may serve several roles (e.g. color and facet); request each metadata column only once.
        labels = list(dict.fromkeys(labels))

        df = pd.DataFrame({'shannon_diversity': encoded_dataset.encoded_data.examples.flatten(),
                           'repertoire_id': encoded_dataset.get_example_ids(),
                           **self.dataset.get_metadata(labels)})
        df.sort_values(by='shannon_diversity', ascending=False, inplace=True)

        table_path = self.result_path / 'shannon_diversity.csv'
        df.to_csv(table_path, index=False)
        return df, ReportOutput(table_path, name='Shannon diversity')

    def _plot(self, encoded_df) -> ReportOutput:
        """
        Plot one bar per repertoire with its Shannon diversity, optionally colored/faceted by labels,
        and write the figure to 'shannon_diversity.html' in the result path.
        """
        facet_labels = [label for label in (self.facet_row_label, self.facet_col_label) if label]

        # Position of each repertoire within its facet (or overall); with the frame sorted by decreasing
        # diversity, as prepare_data produces it, this is the diversity rank used on the x axis.
        if facet_labels:
            repertoire_index = encoded_df.groupby(facet_labels).cumcount()
        else:
            repertoire_index = list(range(encoded_df.shape[0]))
        # assign() returns a new frame, so the caller's DataFrame is not modified
        plot_df = encoded_df.assign(repertoire_index=repertoire_index)

        has_subject_id = 'subject_id' in plot_df.columns
        # custom_data fixes the customdata column order referenced by the hover template below
        custom_data_cols = ['repertoire_id'] + (['subject_id'] if has_subject_id else [])

        fig = px.bar(plot_df, x='repertoire_index', y='shannon_diversity', facet_row=self.facet_row_label,
                     color=self.color_label, facet_col=self.facet_col_label,
                     color_discrete_sequence=self._get_color_sequence(plot_df), custom_data=custom_data_cols)
        fig.update_layout(template="plotly_white", xaxis_title='Repertoires sorted by Shannon diversity')

        # A single shared y-axis title drawn as an annotation, instead of one title per facet row.
        fig.add_annotation(
            text="Shannon diversity",
            xref="paper", yref="paper",
            x=-0.02, y=0.5,
            showarrow=False,
            textangle=-90,
            font=dict(size=14)
        )
        fig.update_layout(margin=dict(l=80))
        fig.update_yaxes(title='')
        fig.update_traces(hovertemplate=self._build_hover_template(has_subject_id))

        figure_path = PlotlyUtil.write_image_to_file(fig, self.result_path / 'shannon_diversity.html')
        return ReportOutput(figure_path, name='Shannon diversity per repertoire')

    def _get_color_sequence(self, encoded_df) -> list:
        """Choose a discrete palette based on how many distinct values the color label has."""
        unique_color_labels = encoded_df[self.color_label].nunique() if self.color_label else 0

        if unique_color_labels > 24:
            logging.warning(f"ShannonDiversityOverview: The color label '{self.color_label}' has "
                            f"{unique_color_labels} unique values, which may lead to indistinguishable colors in the "
                            f"plot. All bars will be colored with the first color in the Vivid palette.")
            return [px.colors.qualitative.Vivid[0]]

        return px.colors.qualitative.Vivid if unique_color_labels <= 10 else px.colors.qualitative.Dark24

    @staticmethod
    def _build_hover_template(has_subject_id: bool) -> str:
        """Build the hover text; customdata[0] is repertoire_id and customdata[1] (if present) is subject_id."""
        lines = ["Repertoire id: %{customdata[0]}<br>"]
        if has_subject_id:
            lines.append("Subject id: %{customdata[1]}<br>")
        lines.append("Shannon diversity: %{y}<extra></extra>")
        return "".join(lines)