import warnings
from pathlib import Path
import pandas as pd
import plotly.express as px
from immuneML.data_model.dataset.ReceptorDataset import ReceptorDataset
from immuneML.data_model.dataset.RepertoireDataset import RepertoireDataset
from immuneML.data_model.dataset.SequenceDataset import SequenceDataset
from immuneML.data_model.receptor.receptor_sequence.ReceptorSequence import ReceptorSequence
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
from immuneML.util.PositionHelper import PositionHelper
class AminoAcidFrequencyDistribution(DataReport):
    """
    Generates a barplot showing the relative frequency of each amino acid at each position in the sequences of a dataset.

    Arguments:

        imgt_positions (bool): Whether to use IMGT positional numbering or sequence index numbering. When
        imgt_positions is True, IMGT positions are used, meaning sequences of unequal length are aligned according to
        their IMGT positions. By default imgt_positions is True.

        relative_frequency (bool): Whether to plot relative frequencies (true) or absolute counts (false) of the
        positional amino acids. By default, relative_frequency is True.

        split_by_label (bool): Whether to split the plots by a label. If set to true, the Dataset must either contain
        a single label, or alternatively the label of interest can be specified under 'label'. By default,
        split_by_label is False.

        label (str): if split_by_label is set to True, a label can be specified here. Setting a label while
        split_by_label is False issues a warning and switches split_by_label to True.

    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        my_aa_freq_report:
            AminoAcidFrequencyDistribution:
                relative_frequency: False
                split_by_label: True
                label: CMV

    """

    @classmethod
    def build_object(cls, **kwargs):
        """Validate the report parameters and construct the report.

        Expects the keys ``imgt_positions``, ``relative_frequency``, ``split_by_label`` and ``label`` in kwargs.
        If ``label`` is given while ``split_by_label`` is False, a warning is emitted and ``split_by_label`` is
        forced to True, because a label is only used when splitting by it.
        """
        location = AminoAcidFrequencyDistribution.__name__

        ParameterValidator.assert_type_and_value(kwargs["imgt_positions"], bool, location, "imgt_positions")
        ParameterValidator.assert_type_and_value(kwargs["relative_frequency"], bool, location, "relative_frequency")
        ParameterValidator.assert_type_and_value(kwargs["split_by_label"], bool, location, "split_by_label")

        if kwargs["label"] is not None:
            ParameterValidator.assert_type_and_value(kwargs["label"], str, location, "label")

            if kwargs["split_by_label"] is False:
                warnings.warn(f"{location}: label is set but split_by_label was False, setting split_by_label to True")
                kwargs["split_by_label"] = True

        return AminoAcidFrequencyDistribution(**kwargs)

    def __init__(self, dataset: SequenceDataset = None, imgt_positions: bool = None, relative_frequency: bool = None,
                 split_by_label: bool = None, label: str = None,
                 result_path: Path = None, number_of_processes: int = 1, name: str = None):
        super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name)
        self.imgt_positions = imgt_positions
        self.relative_frequency = relative_frequency
        self.split_by_label = split_by_label
        # stored as label_name; resolved lazily in _get_label_name when split_by_label is set but no label is given
        self.label_name = label

    def _generate(self) -> ReportResult:
        """Compute the frequency table, write it as CSV, plot it, and bundle both outputs into a ReportResult."""
        PathBuilder.build(self.result_path)

        freq_dist = self._get_plotting_data()

        results_table = self._write_results_table(freq_dist)
        report_output_fig = self._safe_plot(freq_dist=freq_dist)

        return ReportResult(name=self.name,
                            info="A barplot showing the relative frequency of each amino acid at each position in the sequences of a dataset.",
                            output_figures=None if report_output_fig is None else [report_output_fig],
                            output_tables=None if results_table is None else [results_table])

    def _get_plotting_data(self) -> pd.DataFrame:
        """Build the long-format frequency table for the dataset, dispatching on the dataset type.

        The returned frame has the columns "amino acid", "position", "count" and "relative frequency", plus
        "chain" for receptor datasets and "class" when split_by_label is True.

        Raises:
            NotImplementedError: if the dataset is not a SequenceDataset, ReceptorDataset or RepertoireDataset.
        """
        if isinstance(self.dataset, SequenceDataset):
            plotting_data = self._get_sequence_dataset_plotting_data()
        elif isinstance(self.dataset, ReceptorDataset):
            plotting_data = self._get_receptor_dataset_plotting_data()
        elif isinstance(self.dataset, RepertoireDataset):
            plotting_data = self._get_repertoire_dataset_plotting_data()
        else:
            # previously this fell through and failed with an UnboundLocalError on plotting_data
            raise NotImplementedError(f"{AminoAcidFrequencyDistribution.__name__}: unsupported dataset type "
                                      f"{type(self.dataset).__name__}; expected a SequenceDataset, ReceptorDataset "
                                      f"or RepertoireDataset.")

        if not self.split_by_label:
            # without a label every item was assigned the dummy class 0, so the column carries no information
            plotting_data = plotting_data.drop(columns=["class"])

        return plotting_data

    def _get_sequence_dataset_plotting_data(self) -> pd.DataFrame:
        """Count positional amino acids over all sequences of a SequenceDataset, per class."""
        return self._count_dict_per_class_to_df(self._count_aa_frequencies(self._sequence_class_iterator()))

    def _count_dict_per_class_to_df(self, raw_count_dict_per_class) -> pd.DataFrame:
        """Convert a {class: {position: {aa: count}}} dict into one DataFrame with a "class" column."""
        result_dfs = []

        for class_name, raw_count_dict in raw_count_dict_per_class.items():
            result_df = self._count_dict_to_df(raw_count_dict)
            result_df["class"] = class_name
            result_dfs.append(result_df)

        return pd.concat(result_dfs, ignore_index=True)

    def _sequence_class_iterator(self):
        """Yield (sequence, class) pairs; the class is the label value, or 0 when not splitting by label."""
        label_name = self._get_label_name()

        for sequence in self.dataset.get_data():
            if self.split_by_label:
                yield (sequence, sequence.get_attribute(label_name))
            else:
                yield (sequence, 0)

    def _get_receptor_dataset_plotting_data(self) -> pd.DataFrame:
        """Count positional amino acids separately for each receptor chain, per class.

        The set of chains is taken from the first receptor; _chain_class_iterator checks that all receptors
        have these chains.

        Raises:
            ValueError: if the dataset contains no receptors.
        """
        first_receptor = next(iter(self.dataset.get_data()), None)
        if first_receptor is None:
            raise ValueError(f"{AminoAcidFrequencyDistribution.__name__}: the receptor dataset contains no receptors.")

        result_dfs = []

        for chain in first_receptor.get_chains():
            raw_count_dict_per_class = self._count_aa_frequencies(self._chain_class_iterator(chain))

            for class_name, raw_count_dict in raw_count_dict_per_class.items():
                result_df = self._count_dict_to_df(raw_count_dict)
                result_df["chain"] = chain  # todo facet rows/cols when chain is available
                result_df["class"] = class_name
                result_dfs.append(result_df)

        return pd.concat(result_dfs, ignore_index=True)

    def _chain_class_iterator(self, chain):
        """Yield (chain sequence, class) pairs for the given chain of every receptor in the dataset."""
        label_name = self._get_label_name()

        for receptor in self.dataset.get_data():
            assert chain in receptor.get_chains(), f"{AminoAcidFrequencyDistribution.__name__}: All receptors in the dataset must contain the same chains. Expected {chain} but found {receptor.get_chains()}"

            if self.split_by_label:
                yield (receptor.get_chain(chain), receptor.get_attribute(label_name))
            else:
                yield (receptor.get_chain(chain), 0)

    def _get_repertoire_dataset_plotting_data(self) -> pd.DataFrame:
        """Count positional amino acids over all sequences of all repertoires.

        The class comes from repertoire metadata; all repertoires of one class are pooled.
        """
        raw_count_dict_per_class = {}
        label_name = self._get_label_name()

        if self.split_by_label:
            class_names = self.dataset.get_metadata([label_name])[label_name]
        else:
            class_names = [0] * self.dataset.get_example_count()

        for repertoire, class_name in zip(self.dataset.get_data(), class_names):
            # accumulates into the shared dict so counts from repertoires of the same class are summed
            self._count_aa_frequencies(self._repertoire_class_iterator(repertoire, class_name), raw_count_dict_per_class)

        return self._count_dict_per_class_to_df(raw_count_dict_per_class)

    def _repertoire_class_iterator(self, repertoire, class_name):
        """Yield (sequence, class_name) for every sequence object in the repertoire."""
        for sequence in repertoire.get_sequence_objects():
            yield (sequence, class_name)

    def _count_aa_frequencies(self, sequence_class_iterator, raw_count_dict=None):
        """Accumulate amino-acid counts per class and position.

        Args:
            sequence_class_iterator: iterable of (sequence, class_name) pairs.
            raw_count_dict: optional dict to accumulate into (mutated in place); a new dict is used when None.

        Returns:
            dict of the form {class_name: {position (str): {amino acid: count}}}. Every position is initialised
            with all letters of the amino-acid alphabet set to 0, so a character outside that alphabet raises
            KeyError.
        """
        raw_count_dict = {} if raw_count_dict is None else raw_count_dict
        # loop-invariant: fetched once instead of for every newly seen position
        alphabet = EnvironmentSettings.get_sequence_alphabet(SequenceType.AMINO_ACID)

        for sequence, class_name in sequence_class_iterator:
            seq_str = sequence.get_sequence(sequence_type=SequenceType.AMINO_ACID)
            seq_pos = self._get_positions(sequence)
            class_counts = raw_count_dict.setdefault(class_name, {})

            for aa, pos in zip(seq_str, seq_pos):
                if pos not in class_counts:
                    class_counts[pos] = {legal_aa: 0 for legal_aa in alphabet}
                class_counts[pos][aa] += 1

        return raw_count_dict

    def _count_dict_to_df(self, raw_count_dict) -> pd.DataFrame:
        """Flatten {position: {aa: count}} into a long DataFrame.

        The columns are "amino acid", "position", "count" and "relative frequency". The relative frequency is
        the count divided by the total count at that position.
        """
        df_dict = {"amino acid": [], "position": [], "count": [], "relative frequency": []}

        for pos, aa_counts in raw_count_dict.items():
            counts = list(aa_counts.values())
            # never zero: a position is only created when at least one amino acid was counted there
            total_count_for_pos = sum(counts)

            # one row per amino acid actually tracked (was hard-coded to 20, which breaks for other alphabet sizes)
            df_dict["position"].extend([pos] * len(counts))
            df_dict["amino acid"].extend(aa_counts.keys())
            df_dict["count"].extend(counts)
            df_dict["relative frequency"].extend([count / total_count_for_pos for count in counts])

        return pd.DataFrame(df_dict)

    def _get_positions(self, sequence: ReceptorSequence):
        """Return the per-residue position labels of a sequence as strings.

        Uses IMGT positions (from the sequence's "region_type" attribute) when imgt_positions is True, and
        0-based indices otherwise.
        """
        seq_length = len(sequence.get_sequence(SequenceType.AMINO_ACID))

        if self.imgt_positions:
            positions = PositionHelper.gen_imgt_positions_from_length(seq_length, sequence.get_attribute("region_type"))
        else:
            positions = range(seq_length)

        return [str(pos) for pos in positions]

    def _write_results_table(self, results_table):
        """Write the frequency table to amino_acid_frequency_distribution.csv in the result path."""
        file_path = self.result_path / "amino_acid_frequency_distribution.csv"
        results_table.to_csv(file_path, index=False)
        return ReportOutput(path=file_path, name="Table of amino acid frequencies")

    def _get_colors(self):
        """Return the fixed 20-colour palette used for the amino-acid bars."""
        return ['rgb(102, 197, 204)', 'rgb(179,222,105)', 'rgb(220, 176, 242)', 'rgb(217,217,217)',
                'rgb(141,211,199)', 'rgb(251,128,114)', 'rgb(158, 185, 243)', 'rgb(248, 156, 116)',
                'rgb(135, 197, 95)', 'rgb(254, 136, 177)', 'rgb(201, 219, 116)', 'rgb(255,237,111)',
                'rgb(180, 151, 231)', 'rgb(246, 207, 113)', 'rgb(190,186,218)', 'rgb(128,177,211)',
                'rgb(253,180,98)', 'rgb(252,205,229)', 'rgb(188,128,189)', 'rgb(204,235,197)', ]

    def _plot(self, freq_dist):
        """Render the stacked positional barplot to amino_acid_frequency_distribution.html.

        Facet columns are classes and facet rows are chains, when those columns are present. The input
        DataFrame is not modified.
        """
        # sort a copy so the caller's DataFrame is not reordered as a side effect
        freq_dist = freq_dist.sort_values(by=["amino acid"], ascending=False)
        y = "relative frequency" if self.relative_frequency else "count"

        figure = px.bar(freq_dist, x="position", y=y, color="amino acid", text="amino acid",
                        facet_col="class" if "class" in freq_dist.columns else None,
                        facet_row="chain" if "chain" in freq_dist.columns else None,
                        color_discrete_sequence=self._get_colors(),
                        labels={"position": "IMGT position" if self.imgt_positions else "Sequence index",
                                "count": "Count",
                                "relative frequency": "Relative frequency",
                                "amino acid": "Amino acid"}, template="plotly_white")

        # positions are strings, so impose a numeric ordering on the categorical x axis
        figure.update_xaxes(categoryorder='array', categoryarray=self._get_position_order(freq_dist["position"]))
        figure.update_layout(showlegend=False, yaxis={'categoryorder': 'category ascending'})

        if self.relative_frequency:
            figure.update_yaxes(tickformat=",.0%", range=[0, 1])

        file_path = self.result_path / "amino_acid_frequency_distribution.html"
        figure.write_html(str(file_path))

        return ReportOutput(path=file_path, name="Amino acid frequency distribution")

    def _get_position_order(self, positions):
        """Return the unique positions sorted numerically, as strings.

        Integral values are written without a decimal part (e.g. "5", not "5.0"). This is so they match the
        string labels produced for sequence-index positions; presumably it also covers integral IMGT positions.
        """
        return [str(int(pos)) if pos.is_integer() else str(pos) for pos in sorted(set(positions.astype(float)))]

    def _get_label_name(self):
        """Return the label to split by: the configured one, else the dataset's first label; None if not splitting."""
        if not self.split_by_label:
            return None
        if self.label_name is None:
            return list(self.dataset.get_label_names())[0]
        return self.label_name

    def check_prerequisites(self):
        """Return False (with a warning) when split_by_label is set but the label cannot be resolved unambiguously."""
        if self.split_by_label:
            label_names = self.dataset.get_label_names()

            if self.label_name is None:
                if len(label_names) != 1:
                    warnings.warn(
                        f"{AminoAcidFrequencyDistribution.__name__}: ambiguous label: split_by_label was set to True but no label name was specified, and the number of available labels is {len(label_names)}: {label_names}. Skipping this report...")
                    return False
            else:
                if self.label_name not in label_names:
                    warnings.warn(
                        f"{AminoAcidFrequencyDistribution.__name__}: the specified label name ({self.label_name}) was not available among the dataset labels: {label_names}. Skipping this report...")
                    return False

        return True