Source code for immuneML.reports.data_reports.ShannonDiversityOverview

import logging
from pathlib import Path
from typing import Tuple

import pandas as pd
import plotly.express as px

from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.diversity_encoding.ShannonDiversityEncoder import ShannonDiversityEncoder
from immuneML.environment.LabelConfiguration import LabelConfiguration
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.PathBuilder import PathBuilder


[docs] class ShannonDiversityOverview(DataReport): """ Computes Shannon diversity for each repertoire using Shannon diversity encoder and plots the results in a histogram, optionally stratified by labels. **Dataset type:** - Repertoire Dataset **Specification arguments:** - color_label (str): The label used to color the histogram bars. Default is None. - facet_row_label (str): The label used to facet the histogram into multiple rows. Default is None, meaning no row faceting. - facet_col_label (str): The label used to facet the histogram into multiple columns. Default is None, meaning no column faceting. **YAML specification:** .. indent with spaces .. code-block:: yaml definitions: reports: shannon_div_rep: ShannonDiversityOverview: color_label: disease_status """ def __init__(self, dataset: RepertoireDataset = None, result_path: Path = None, name: str = None, number_of_processes: int = 1, color_label: str = None, facet_row_label: str = None, facet_col_label: str = None): super().__init__(dataset, result_path, name, number_of_processes) self.color_label = color_label self.facet_row_label = facet_row_label self.facet_col_label = facet_col_label
[docs] @classmethod def build_object(cls, **kwargs): return ShannonDiversityOverview(**kwargs)
[docs] def check_prerequisites(self) -> bool: valid = isinstance(self.dataset, RepertoireDataset) if not valid: logging.warning(f"ShannonDiversityOverview: Dataset must be of type RepertoireDataset, " f"but got {type(self.dataset)}.") return valid
def _generate(self) -> ReportResult: encoded_dataset = (ShannonDiversityEncoder.build_object(self.dataset) .encode(self.dataset, EncoderParams(self.result_path, encode_labels=False))) PathBuilder.build(self.result_path) df, table_output = self.prepare_data(encoded_dataset) figure_output = self._safe_plot(encoded_df=df) return ReportResult(name=self.name, info="Shannon diversity per repertoire", output_figures=[figure_output], output_tables=[table_output])
[docs] def prepare_data(self, encoded_dataset) -> Tuple[pd.DataFrame, ReportOutput]: labels = ['subject_id'] if 'subject_id' in self.dataset.labels.keys() else [] for label in [self.color_label, self.facet_row_label, self.facet_col_label]: if label is not None: labels.append(label) df = pd.DataFrame({'shannon_diversity': encoded_dataset.encoded_data.examples.flatten(), 'repertoire_id': encoded_dataset.get_example_ids(), **self.dataset.get_metadata(labels)}) df.sort_values(by='shannon_diversity', ascending=False, inplace=True) df.to_csv(self.result_path / 'shannon_diversity.csv', index=False) return df, ReportOutput(self.result_path / 'shannon_diversity.csv', name='Shannon diversity')
def _plot(self, encoded_df) -> ReportOutput: facet_labels = [] if self.facet_row_label: facet_labels.append(self.facet_row_label) if self.facet_col_label: facet_labels.append(self.facet_col_label) encoded_df['repertoire_index'] = encoded_df.groupby(facet_labels).cumcount() \ if len(facet_labels) > 0 else list(range(encoded_df.shape[0])) hover_data_cols = ['repertoire_id'] + (['subject_id'] if 'subject_id' in encoded_df.columns else []) fig = px.bar(encoded_df, x='repertoire_index', y='shannon_diversity', facet_row=self.facet_row_label, color=self.color_label, title='Shannon diversity per repertoire', facet_col=self.facet_col_label, color_discrete_sequence=px.colors.qualitative.Vivid, hover_data=hover_data_cols) fig.update_layout(template="plotly_white", yaxis_title='Shannon diversity', xaxis_title='Repertoires sorted by Shannon diversity') fig.update_traces( hovertemplate=( "Repertoire id: %{customdata[0]}<br>" + "Subject id: %{customdata[1]}<br>" if 'subject_id' in encoded_df.columns else "" + "Shannon diversity: %{y}<extra></extra>" ) ) figure_path = PlotlyUtil.write_image_to_file(fig, self.result_path / 'shannon_diversity.html') return ReportOutput(figure_path, name='Shannon diversity')