from multiprocessing.pool import Pool
from typing import List
import dill
import numpy as np
import pandas as pd
from immuneML.analysis.SequenceMatcher import SequenceMatcher
from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.data_model.EncodedData import EncodedData
from immuneML.data_model.SequenceSet import ReceptorSequence
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.encodings.DatasetEncoder import DatasetEncoder
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.reference_encoding.MatchedReferenceUtil import MatchedReferenceUtil
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.ReadsType import ReadsType
# (Sphinx "[docs]" link marker from the rendered source page; not part of the module.)
class MatchedSequencesEncoder(DatasetEncoder):
    """
    Encodes the dataset based on the matches between a RepertoireDataset and a reference sequence dataset.

    This encoding can be used in combination with the :ref:`Matches` report.

    When sum_matches and normalize are set to True, this encoder behaves as described in: Yao, Y. et al. ‘T cell receptor repertoire as a potential diagnostic marker for celiac disease’.
    Clinical Immunology Volume 222 (January 2021): 108621. `doi.org/10.1016/j.clim.2020.108621 <https://doi.org/10.1016/j.clim.2020.108621>`_

    **Dataset type:**

    - RepertoireDatasets

    **Specification arguments:**

    - reference (dict): A dictionary describing the reference dataset file. Import should be specified the same way as
      regular dataset import. It is only allowed to import a sequence dataset here (i.e., is_repertoire and paired are
      False by default, and are not allowed to be set to True).

    - max_edit_distance (int): The maximum edit distance between a target sequence (from the repertoire) and the
      reference sequence.

    - reads (:py:mod:`~immuneML.util.ReadsType`): Reads type signifies whether the counts of the sequences in the
      repertoire will be taken into account. If :py:mod:`~immuneML.util.ReadsType.UNIQUE`, only unique sequences
      (clonotypes) are counted, and if :py:mod:`~immuneML.util.ReadsType.ALL`, the sequence 'count' value is summed when
      determining the number of matches. The default value for reads is all.

    - sum_matches (bool): When sum_matches is False, the resulting encoded data matrix contains multiple columns with
      the number of matches per reference sequence. When sum_matches is true, all columns are summed together, meaning
      that there is only one aggregated sum of matches per repertoire in the encoded data.
      To use this encoder in combination with the :ref:`Matches` report, sum_matches must be set to False. When
      sum_matches is set to True, this encoder behaves as described by Yao, Y. et al. By default, sum_matches is False.

    - normalize (bool): If True, the sequence matches are divided by the total number of unique sequences in the
      repertoire (when reads = unique) or the total number of reads in the repertoire (when reads = all).

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            encodings:
                my_ms_encoding:
                    MatchedSequences:
                        reference:
                            format: VDJDB
                            params:
                                path: path/to/file.txt
                        max_edit_distance: 1

    """

    def __init__(self, max_edit_distance: int, reference: List[ReceptorSequence], reads: ReadsType, sum_matches: bool, normalize: bool,
                 name: str = None):
        """Store the already-validated encoder settings.

        ``reference`` is the list of reference sequences to match against (not the raw
        specification dict; see ``_prepare_parameters`` for the conversion).
        """
        super().__init__(name=name)
        self.max_edit_distance = max_edit_distance
        self.reference_sequences = reference
        self.reads = reads
        self.sum_matches = sum_matches
        self.normalize = normalize
        # One aggregated column when matches are summed, otherwise one column per reference sequence.
        self.feature_count = 1 if self.sum_matches else len(self.reference_sequences)

    @staticmethod
    def _prepare_parameters(max_edit_distance: int, reference: dict, reads: str, sum_matches: bool, normalize: bool,
                            name: str = None) -> dict:
        """Validate the raw specification values and convert them into ``__init__`` keyword arguments.

        Returns a dict whose keys match the ``__init__`` parameter names; ``reads`` is converted
        to a ``ReadsType`` member and ``reference`` to the prepared reference sequences.
        """
        location = "MatchedSequencesEncoder"

        ParameterValidator.assert_type_and_value(max_edit_distance, int, location, "max_edit_distance", min_inclusive=0)
        ParameterValidator.assert_type_and_value(sum_matches, bool, location, "sum_matches")
        ParameterValidator.assert_type_and_value(normalize, bool, location, "normalize")
        # Check the type first so that a non-string value is reported by the validator
        # instead of failing with an AttributeError on .upper() below.
        ParameterValidator.assert_type_and_value(reads, str, location, "reads")
        reads_name = reads.upper()
        ParameterValidator.assert_in_valid_list(reads_name, [item.name for item in ReadsType], location, "reads")

        reference_sequences = MatchedReferenceUtil.prepare_reference(reference_params=reference, location=location, paired=False)

        return {
            "max_edit_distance": max_edit_distance,
            "reference": reference_sequences,
            "reads": ReadsType[reads_name],
            "sum_matches": sum_matches,
            "normalize": normalize,
            "name": name
        }

    @staticmethod
    def build_object(dataset=None, **params):
        """Create an encoder from raw specification parameters.

        Raises:
            ValueError: if ``dataset`` is not a ``RepertoireDataset``.
        """
        if not isinstance(dataset, RepertoireDataset):
            raise ValueError("MatchedSequencesEncoder is not defined for dataset types which are not RepertoireDataset.")

        prepared_parameters = MatchedSequencesEncoder._prepare_parameters(**params)
        return MatchedSequencesEncoder(**prepared_parameters)

    def encode(self, dataset, params: EncoderParams):
        """Return the encoded dataset, computing it only if it is not already memoised by ``CacheHandler``."""
        cache_key = CacheHandler.generate_cache_key(self._prepare_caching_params(dataset, params))
        encoded_dataset = CacheHandler.memo(cache_key,
                                            lambda: self._encode_new_dataset(dataset, params))

        return encoded_dataset

    def _prepare_caching_params(self, dataset, params: EncoderParams):
        """Collect everything that identifies one encoding run, for use as the cache key input."""
        # Gene calls may be missing (None) - _get_sequence_desc handles that case explicitly -
        # so they are mapped to "" here rather than concatenated directly, which would raise.
        # When both calls are strings the resulting key is the same as plain concatenation.
        reference_descriptions = sorted(f"{seq.get_sequence()}{seq.v_call or ''}{seq.j_call or ''}"
                                        for seq in self.reference_sequences)

        encoding_params_desc = {"max_edit_distance": self.max_edit_distance,
                                "reference_sequences": reference_descriptions,
                                "reads": self.reads.name,
                                "sum_matches": self.sum_matches,
                                "normalize": self.normalize}

        return (("dataset_identifiers", tuple(dataset.get_example_ids())),
                ("dataset_metadata", dataset.metadata_file),
                ("dataset_type", dataset.__class__.__name__),
                ("labels", tuple(params.label_config.get_labels_by_name())),
                ("encoding", MatchedSequencesEncoder.__name__),
                ("learn_model", params.learn_model),
                ("encoding_params", encoding_params_desc),)

    def _encode_new_dataset(self, dataset, params: EncoderParams):
        """Compute the match matrix and attach it, as ``EncodedData``, to a clone of ``dataset``."""
        encoded_repertoires, labels = self._encode_repertoires(dataset, params)
        if self.normalize:
            encoded_repertoires = self._normalize(dataset, encoded_repertoires)

        # Per-reference annotations only make sense when there is one column per reference sequence.
        feature_annotations = None if self.sum_matches else self._get_feature_info()
        feature_names = [f"sum_of_{self.reads.value}_reads"] if self.sum_matches else list(feature_annotations["sequence_desc"])

        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = EncodedData(
            examples=encoded_repertoires,
            labels=labels,
            feature_names=feature_names,
            feature_annotations=feature_annotations,
            example_ids=[repertoire.identifier for repertoire in dataset.get_data()],
            encoding=MatchedSequencesEncoder.__name__,
            info={'sequence_type': params.sequence_type,
                  'region_type': params.region_type}
        )

        return encoded_dataset

    def _normalize(self, dataset, encoded_repertoires):
        """Divide every repertoire's row by that repertoire's total.

        The total is the element count when reads are UNIQUE, otherwise the sum of
        ``duplicate_count``. The totals form a column vector so the division broadcasts per row.
        NOTE(review): a repertoire whose total is 0 is not special-cased here - confirm that
        empty repertoires cannot reach this point.
        """
        if self.reads == ReadsType.UNIQUE:
            repertoire_totals = np.asarray([[repertoire.get_element_count() for repertoire in dataset.get_data()]]).T
        else:
            repertoire_totals = np.asarray([[sum(repertoire.data.duplicate_count) for repertoire in dataset.get_data()]]).T

        return encoded_repertoires / repertoire_totals

    def _get_feature_info(self) -> pd.DataFrame:
        """
        returns a pandas dataframe with one row per reference sequence, containing:
         - sequence id
         - locus
         - amino acid sequence
         - v call
         - j call
         - sequence description (see _get_sequence_desc)
        """
        # Build the rows straight from the reference sequences instead of pre-sizing a list from
        # feature_count, which is 1 when sum_matches is set and then no longer matches the
        # number of reference sequences.
        rows = [[sequence.sequence_id,
                 sequence.locus,
                 sequence.sequence_aa,
                 sequence.v_call,
                 sequence.j_call,
                 self._get_sequence_desc(sequence)]
                for sequence in self.reference_sequences]

        features = pd.DataFrame(rows, columns=["sequence_id", "locus", "sequence", "v_call", "j_call", "sequence_desc"])

        # The descriptions double as feature names; if any two collide, append the sequence id to all of them.
        if not features['sequence_desc'].is_unique:
            features['sequence_desc'] = [f"{desc}_{sequence_id}"
                                         for desc, sequence_id in zip(features['sequence_desc'], features['sequence_id'])]

        return features

    def _get_sequence_desc(self, sequence: ReceptorSequence) -> str:
        """Build a ``[v_call_]sequence[_j_call]`` label, leaving out gene calls that are None or empty."""
        desc = ""
        if sequence.v_call not in [None, ""]:
            desc += f"{sequence.v_call}_"

        # Fall back to `sequence` when the amino acid sequence is missing; None is treated
        # like "" here, in line with the v_call / j_call checks.
        desc += sequence.sequence_aa if sequence.sequence_aa not in [None, ""] else sequence.sequence

        if sequence.j_call not in [None, ""]:
            desc += f"_{sequence.j_call}"

        return desc

    def _encode_repertoires(self, dataset: RepertoireDataset, params):
        """Match every repertoire against the reference in a process pool.

        Returns:
            A tuple ``(matrix, labels)``: ``matrix`` has one row per repertoire, and ``labels``
            maps each label name to a per-repertoire list of metadata values, or is None when
            ``params.encode_labels`` is false.
        """
        label_names = list(params.label_config.get_labels_by_name())
        labels = {label: [] for label in label_names} if params.encode_labels else None

        with Pool(params.pool_size) as pool:
            encoded_repertoires = np.array(pool.map(self._get_repertoire_matches_to_reference,
                                                    [dill.dumps(rep) for rep in dataset.repertoires]))

        # labels is None when label encoding is disabled; without this guard the append below
        # would fail on the None object.
        if labels is not None:
            for repertoire in dataset.repertoires:
                for label_name in label_names:
                    labels[label_name].append(repertoire.metadata[label_name])

        return encoded_repertoires, labels

    def _get_repertoire_matches_to_reference(self, repertoire):
        """Return the match counts for one repertoire, memoised per repertoire and encoder settings.

        ``repertoire`` may be a dill-serialised repertoire (as sent by ``_encode_repertoires``)
        or a repertoire object.
        """
        if isinstance(repertoire, bytes):
            # NOTE: dill.loads executes arbitrary code for untrusted input; here the bytes are
            # the ones produced by dill.dumps in _encode_repertoires.
            repertoire = dill.loads(repertoire)

        return CacheHandler.memo_by_params(
            (("repertoire_identifier", repertoire.identifier),
             ("encoding", MatchedSequencesEncoder.__name__),
             ("readstype", self.reads.name),
             ("sum_matches", self.sum_matches),
             ("max_edit_distance", self.max_edit_distance),
             ("reference_sequences", tuple(
                 [(seq.locus, seq.sequence_aa, seq.v_call, seq.j_call) for seq in self.reference_sequences]))),
            lambda: self._compute_matches_to_reference(repertoire))

    def _compute_matches_to_reference(self, repertoire: Repertoire):
        """Count, for each reference sequence, the repertoire sequences matching it.

        Returns an integer array of length ``feature_count``: one count per reference sequence,
        or a single total when ``sum_matches`` is set. A match contributes 1 for UNIQUE reads
        and the repertoire sequence's ``duplicate_count`` otherwise.
        """
        matcher = SequenceMatcher()
        matches = np.zeros(self.feature_count, dtype=int)

        # Materialise once: the sequences are scanned again for every reference sequence,
        # which a one-shot iterator would not survive.
        rep_seqs = list(repertoire.sequences())
        count_unique = self.reads == ReadsType.UNIQUE

        for i, reference_seq in enumerate(self.reference_sequences):
            matches_idx = 0 if self.sum_matches else i
            for repertoire_seq in rep_seqs:
                if matcher.matches_sequence(reference_seq, repertoire_seq, max_distance=self.max_edit_distance):
                    matches[matches_idx] += 1 if count_unique else repertoire_seq.duplicate_count

        return matches