from collections import Counter
from pathlib import Path
import pandas as pd
import plotly.express as px
from pandas import DataFrame
from immuneML.data_model import bnp_util, AIRRSequenceSet
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import ReceptorDataset, SequenceDataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.environment.SequenceType import SequenceType
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
[docs]
class SequenceLengthDistribution(DataReport):
"""
Generates a histogram of the lengths of the sequences in a dataset.
**Specification arguments:**
- sequence_type (str): whether to check the length of amino acid or nucleotide sequences; default value is 'amino_acid'
- region_type (str): which part of the sequence to examine; e.g., IMGT_CDR3
- split_by_label (bool): whether to split the plots by a label. If set to True, the dataset must either contain a
single label, or the label of interest must be specified under 'label'. By default,
split_by_label is False.
- label (str): the label to split by when split_by_label is True; if it is not specified, the first label of the
dataset is used.
- plot_frequencies (bool): if set to True, the plot will show the frequencies of the sequence lengths instead of the
counts. By default, plot_frequencies is False.
**YAML specification:**
.. indent with spaces
.. code-block:: yaml
definitions:
reports:
my_sld_report:
SequenceLengthDistribution:
sequence_type: amino_acid
region_type: IMGT_CDR3
label: label_1
split_by_label: True
plot_frequencies: True
"""
[docs]
@classmethod
def build_object(cls, **kwargs):
    """
    Builds the report from specification arguments.

    Runs the ParameterValidator checks for the sequence type and the region type, then replaces the
    'sequence_type' and 'region_type' string values by the SequenceType and RegionType members with the
    matching (upper-cased) names.

    Args:
        **kwargs: the report arguments; must contain 'sequence_type' and 'region_type' as strings

    Returns:
        a SequenceLengthDistribution object created from the converted arguments
    """
    ParameterValidator.assert_sequence_type(kwargs)
    ParameterValidator.assert_region_type(kwargs)

    # copy first so that the caller's arguments keep their original string values
    params = dict(kwargs)
    params['sequence_type'] = SequenceType[kwargs['sequence_type'].upper()]
    params['region_type'] = RegionType[kwargs['region_type'].upper()]

    return SequenceLengthDistribution(**params)
def __init__(self, dataset: Dataset = None, batch_size: int = 1, result_path: Path = None,
             number_of_processes: int = 1, region_type: RegionType = RegionType.IMGT_CDR3,
             sequence_type: SequenceType = SequenceType.AMINO_ACID, name: str = None, label: str = None,
             split_by_label: bool = False, plot_frequencies: bool = False):
    """
    Stores the report configuration.

    Args:
        dataset: the dataset to analyse; passed on to the parent class
        batch_size: stored on the object as batch_size
        result_path: where the report output is written; passed on to the parent class
        number_of_processes: passed on to the parent class
        region_type: which region of the sequences to measure
        sequence_type: whether amino acid or nucleotide sequences are measured
        name: name of the report; passed on to the parent class
        label: name of the label to split by; kept as label_name
        split_by_label: whether the lengths are reported separately per label class
        plot_frequencies: whether frequencies are plotted instead of counts
    """
    super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name)

    # which sequences are measured
    self.region_type = region_type
    self.sequence_type = sequence_type

    # how the results are grouped and shown
    self.split_by_label = split_by_label
    self.label_name = label
    self.plot_frequencies = plot_frequencies

    self.batch_size = batch_size
[docs]
def check_prerequisites(self):
    """
    Tells whether the report can be run; no condition is checked here, so the answer is always True.
    """
    can_run = True
    return can_run
def _generate(self) -> ReportResult:
    """
    Computes the sequence length table, writes it to 'sequence_length_distribution.csv' in the result path
    and makes the plot.

    Note that label_name is overwritten here with the label resolved by _get_label_name().

    Returns:
        a ReportResult with the csv table and, unless _safe_plot returned None, the figure
    """
    PathBuilder.build(self.result_path)
    self.label_name = self._get_label_name()

    lengths_df = self._get_sequence_lengths_df()
    table_path = self.result_path / 'sequence_length_distribution.csv'
    lengths_df.to_csv(table_path, index=False)

    figure_output = self._safe_plot(df=lengths_df, output_written=False)
    if figure_output is None:
        figures = None
    else:
        figures = [figure_output]

    table_output = ReportOutput(table_path, 'lengths of sequences in the dataset')

    return ReportResult(name=self.name,
                        info="A histogram of the lengths of the sequences in a dataset.",
                        output_figures=figures,
                        output_tables=[table_output])
def _get_sequence_lengths_df(self) -> DataFrame:
    """
    Builds the table of sequence length counts for the dataset, using the helper that matches the dataset type.

    When plot_frequencies is set, a 'frequencies' column is added: counts are normalised within each label
    class (and within each chain, if a 'chain' column is present) when split_by_label is set, and over the
    whole table otherwise.

    Returns:
        a DataFrame with 'counts' and 'sequence_lengths' columns, plus the label column, 'chain' and
        'frequencies' where applicable

    Raises:
        TypeError: if the dataset is not a RepertoireDataset, SequenceDataset or ReceptorDataset
    """
    if isinstance(self.dataset, RepertoireDataset):
        sequence_lengths_df = self._get_sequence_lengths_df_repertoire_dataset()
    elif isinstance(self.dataset, SequenceDataset):
        sequence_lengths_df = self._get_sequence_lengths_df_sequence_dataset()
    elif isinstance(self.dataset, ReceptorDataset):
        sequence_lengths_df = self._get_sequence_lengths_df_receptor_dataset()
    else:
        # fail with a clear message instead of an UnboundLocalError on the lines below
        raise TypeError(f"{type(self).__name__}: unsupported dataset type {type(self.dataset).__name__}; "
                        f"expected a RepertoireDataset, SequenceDataset or ReceptorDataset.")

    # _count_dict_per_class_to_df always adds a column named after label_name, also when not splitting
    if not self.split_by_label and self.label_name in sequence_lengths_df.columns:
        sequence_lengths_df.drop(columns=[self.label_name], inplace=True)

    if self.plot_frequencies:
        counts = sequence_lengths_df['counts']
        if self.split_by_label:
            group_keys = [self.label_name]
            if 'chain' in sequence_lengths_df.columns:
                group_keys.append('chain')
            # total count of the group each row belongs to, aligned with the rows of the table
            totals = sequence_lengths_df.groupby(group_keys)['counts'].transform('sum')
        else:
            totals = counts.sum()
        sequence_lengths_df['frequencies'] = counts / totals

    return sequence_lengths_df
def _get_sequence_lengths_df_repertoire_dataset(self):
    """
    Counts sequence lengths over all repertoires of the dataset.

    When split_by_label is set, each repertoire is counted under the value of the label found in the dataset
    metadata; otherwise every repertoire is counted under the placeholder class 0.

    Returns:
        the counts as a DataFrame, as produced by _count_dict_per_class_to_df
    """
    label_name = self._get_label_name()

    class_names = (self.dataset.get_metadata([label_name])[label_name] if self.split_by_label
                   else [0] * self.dataset.get_example_count())

    counts_per_class = {}
    for repertoire, class_name in zip(self.dataset.get_data(), class_names):
        # the dict is filled in place, so counts accumulate across repertoires
        self._count_seq_lengths(repertoire.data, counts_per_class, class_name)

    return self._count_dict_per_class_to_df(counts_per_class)
def _get_sequence_lengths_df_sequence_dataset(self):
    """
    Counts sequence lengths over the data of a sequence dataset and returns them as a DataFrame.

    No class name is passed on, so when splitting by label the class is read from each sequence.
    """
    counts_per_class = self._count_seq_lengths(self.dataset.data, class_name=None)
    return self._count_dict_per_class_to_df(counts_per_class)
def _count_dict_per_class_to_df(self, raw_count_dict_per_class):
result_dfs = []
for class_name, raw_count_dict in raw_count_dict_per_class.items():
result_df = self._count_dict_to_df(raw_count_dict)
result_df[self.label_name] = class_name
result_dfs.append(result_df)
return pd.concat(result_dfs)
def _count_seq_lengths(self, data: AIRRSequenceSet, raw_count_dict=None, class_name: str = None):
    """
    Counts how often each sequence length occurs in the data, grouped by class.

    Args:
        data: the sequences to go through; iterated with to_iter()
        raw_count_dict: optional dict of {class name: {sequence length: count}} to add to; it is updated in
            place, which allows accumulating counts over several calls
        class_name: the class to count every sequence under; if None, the class is read from the label_name
            attribute of each sequence

    Returns:
        the (updated) dict of {class name: {sequence length: count}}; when split_by_label is not set or no
        label name is available, everything is counted under the placeholder class 0
    """
    raw_count_dict = {} if raw_count_dict is None else raw_count_dict

    # use the configured sequence type (not a fixed amino acid field), in line with the other counting
    # methods of this report; the field name is the same for all items, so it is looked up once
    field_name = bnp_util.get_sequence_field_name(self.region_type, self.sequence_type)
    use_labels = self.split_by_label and self.label_name is not None

    for item in data.to_iter():
        if not use_labels:
            cls_name = 0
        elif class_name is None:
            cls_name = getattr(item, self.label_name)
        else:
            cls_name = class_name

        length = len(getattr(item, field_name))
        class_counts = raw_count_dict.setdefault(cls_name, {})
        class_counts[length] = class_counts.get(length, 0) + 1

    return raw_count_dict
def _count_dict_to_df(self, count_dict):
return pd.DataFrame({"counts": list(count_dict.values()),
'sequence_lengths': list(count_dict.keys())})
def _get_dataset_chains(self):
return next(self.dataset.get_data()).get_chains()
def _get_sequence_lengths_df_receptor_dataset(self):
    """
    Counts sequence lengths per chain (locus) of a receptor dataset.

    When split_by_label is set and a label name is available, the lengths of each chain are counted per
    label class; otherwise they are counted per chain only. Chains are processed in the order in which they
    first appear in the data, so that the row order of the result is the same on every run.

    Returns:
        a DataFrame with 'counts', 'sequence_lengths' and 'chain' columns, plus the label column when
        the counts are split by label
    """
    data = self.dataset.data
    loci = data.locus.tolist()
    # dict keys keep insertion order, unlike a set whose iteration order is arbitrary
    chains = list(dict.fromkeys(loci))
    label_name = self._get_label_name()
    field_name = bnp_util.get_sequence_field_name(self.region_type, self.sequence_type)

    dfs = []
    for chain in chains:
        chain_data = data[[locus == chain for locus in loci]]
        if self.split_by_label and label_name:
            chain_df = self._count_dict_per_class_to_df(self._count_seq_lengths(chain_data, class_name=None))
            chain_df["chain"] = chain
        else:
            chain_counter = Counter(getattr(chain_data, field_name).lengths.tolist())
            chain_df = pd.DataFrame({'counts': list(chain_counter.values()),
                                     'sequence_lengths': list(chain_counter.keys()),
                                     'chain': chain})
        dfs.append(chain_df)

    return pd.concat(dfs)
def _count_in_repertoire(self, repertoire: Repertoire) -> Counter:
    """
    Counts the sequence lengths in one repertoire for the configured region type and sequence type.

    Returns:
        a Counter mapping each sequence length to the number of sequences with that length
    """
    field_name = bnp_util.get_sequence_field_name(self.region_type, self.sequence_type)
    sequences = getattr(repertoire.data, field_name)
    return Counter(sequences.lengths.tolist())
def _plot(self, df: pd.DataFrame) -> ReportOutput:
    """
    Draws the sequence lengths as a bar chart and writes it to 'sequence_length_distribution.html' in the
    result path.

    The y axis shows 'frequencies' if plot_frequencies is set and 'counts' otherwise. The plot gets one
    column per label class if the label column is present in df, and one row per chain for receptor datasets.

    Args:
        df: the table of sequence lengths to plot

    Returns:
        a ReportOutput pointing to the html file
    """
    y_column = "frequencies" if self.plot_frequencies else "counts"
    column_facet = self.label_name if self.label_name in df.columns else None
    row_facet = "chain" if isinstance(self.dataset, ReceptorDataset) else None

    figure = px.bar(df, x="sequence_lengths", y=y_column, facet_col=column_facet, facet_row=row_facet)
    figure.update_layout(template="plotly_white")
    figure.update_traces(marker_color=px.colors.diverging.Tealrose[0])

    # set the font size of the annotations in the figure layout
    for annotation in figure.layout.annotations:
        annotation['font'] = {'size': 16}

    PathBuilder.build(self.result_path)
    file_path = self.result_path / "sequence_length_distribution.html"
    figure.write_html(str(file_path))

    return ReportOutput(path=file_path, name="Sequence length distribution plot")
def _get_label_name(self):
if self.split_by_label:
if self.label_name is None:
return list(self.dataset.get_label_names())[0]
else:
return self.label_name
else:
return None