# Source code for immuneML.preprocessing.filters.DuplicateSequenceFilter

import copy
from multiprocessing.pool import Pool
from pathlib import Path
from uuid import uuid4

import pandas as pd

from immuneML.data_model import bnp_util
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.bnp_util import write_yaml
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.environment.SequenceType import SequenceType
from immuneML.preprocessing.filters.CountAggregationFunction import CountAggregationFunction
from immuneML.preprocessing.filters.Filter import Filter
from immuneML.util.ParameterValidator import ParameterValidator
from scripts.specification_util import update_docs_per_mapping


class DuplicateSequenceFilter(Filter):
    """
    Collapses duplicate nucleotide or amino acid sequences within each repertoire in the given RepertoireDataset.
    This filter can be applied to Repertoires and RepertoireDatasets.

    Sequences are considered duplicates if the following fields are identical:

    - amino acid or nucleotide sequence (whichever is specified)

    - v and j genes (note that the full field including subgroup + gene is used for matching, i.e. V1 and V1-1 are not considered duplicates)

    - chain

    - region type

    In practice, all standard AIRR fields present in the data are compared, except the non-specified sequence type,
    the count (duplicate_count), the sequence identifier (sequence_id) and cell_id. Rows with missing values in the
    compared fields are kept (missing values are treated as a value of their own) rather than dropped.

    For all other fields (the non-specified sequence type, custom lists, sequence identifier, cell id) only the first
    occurring value is kept.

    Note that this means the count value of a sequence with a given sequence identifier might not be the same as before
    removing duplicates, unless count_agg = FIRST is used.

    **Specification arguments:**

    - filter_sequence_type (:py:obj:`~immuneML.environment.SequenceType.SequenceType`): Whether the sequences should be collapsed on the nucleotide or amino acid level. Valid options are defined by the SequenceType enum.

    - region_type (str): which part of the sequence to examine, by default, this is IMGT_CDR3

    - count_agg (:py:obj:`~immuneML.preprocessing.filters.CountAggregationFunction.CountAggregationFunction`): determines how the sequence counts of duplicate sequences are aggregated. Valid options are defined by the CountAggregationFunction enum.

    - batch_size (int): number of worker processes used to process repertoires in parallel


    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        preprocessing_sequences:
            my_preprocessing:
                - my_filter:
                    DuplicateSequenceFilter:
                        # required parameters:
                        filter_sequence_type: AMINO_ACID
                        # optional parameters (if not specified the values below will be used):
                        batch_size: 4
                        count_agg: SUM
                        region_type: IMGT_CDR3

    """

    @classmethod
    def build_object(cls, **kwargs):
        """Validate the YAML parameters and construct a DuplicateSequenceFilter.

        The keys are checked with ParameterValidator.assert_keys against filter_sequence_type, batch_size, count_agg and
        region_type. The enum-valued parameters are matched case-insensitively (upper-cased before lookup).

        Returns:
            DuplicateSequenceFilter: the configured filter
        """
        location = cls.__name__
        ParameterValidator.assert_keys(kwargs.keys(), ["filter_sequence_type", "batch_size", "count_agg", "region_type"],
                                       location, "DuplicateSequenceFilter")

        sequence_type_name = kwargs["filter_sequence_type"].upper()
        region_type_name = kwargs["region_type"].upper()
        count_agg_name = kwargs["count_agg"].upper()

        ParameterValidator.assert_in_valid_list(sequence_type_name, [item.name for item in SequenceType], location,
                                                "filter_sequence_type")
        ParameterValidator.assert_in_valid_list(region_type_name, [item.name for item in RegionType], location,
                                                "region_type")
        ParameterValidator.assert_in_valid_list(count_agg_name, [item.name for item in CountAggregationFunction],
                                                location, "count_agg")
        ParameterValidator.assert_type_and_value(kwargs["batch_size"], int, location, "batch_size", 1)

        return DuplicateSequenceFilter(filter_sequence_type=SequenceType[sequence_type_name],
                                       region_type=RegionType[region_type_name],
                                       batch_size=kwargs["batch_size"],
                                       count_agg=CountAggregationFunction[count_agg_name])

    def __init__(self, filter_sequence_type: SequenceType, batch_size: int, count_agg: CountAggregationFunction,
                 result_path: Path = None, region_type: RegionType = RegionType.IMGT_CDR3):
        """
        Args:
            filter_sequence_type: the sequence level (nucleotide / amino acid) on which duplicates are detected
            batch_size: number of worker processes used by process_dataset
            count_agg: how duplicate_count values of collapsed rows are combined (its .value is passed to pandas agg)
            result_path: directory where the filtered repertoire files are written
            region_type: region whose sequence field is compared
        """
        super().__init__(result_path)
        self.filter_sequence_type = filter_sequence_type
        self.region_type = region_type
        self.count_agg = count_agg
        self.batch_size = batch_size
        # Column holding the sequence that defines duplicates.
        self.sequence_of_interest = bnp_util.get_sequence_field_name(self.region_type, self.filter_sequence_type)
        # Column for the other sequence type: not compared, only its first value per group is kept.
        other_sequence_type = next(t for t in SequenceType if t != self.filter_sequence_type)
        self.sequence_to_ignore = bnp_util.get_sequence_field_name(self.region_type, other_sequence_type)

    def process_dataset(self, dataset: RepertoireDataset, result_path: Path, number_of_processes=1) -> RepertoireDataset:
        """Collapse duplicate sequences in every repertoire of the dataset.

        Args:
            dataset: the RepertoireDataset to filter (it is cloned, not modified)
            result_path: output directory; if None, the result_path given at construction is used
            number_of_processes: currently not used; the pool size is taken from batch_size

        Returns:
            RepertoireDataset: a clone of the dataset whose repertoires point to the filtered files

        Raises:
            ValueError: if no result path is available
        """
        if result_path is not None:
            self.result_path = result_path
        self.check_dataset_type(dataset, [RepertoireDataset], "DuplicateSequenceFilter")

        if self.result_path is None:
            raise ValueError("DuplicateSequenceFilter: result_path must be set to write the filtered repertoires.")
        self.result_path = Path(self.result_path)
        self.result_path.mkdir(parents=True, exist_ok=True)

        processed_dataset = dataset.clone()
        # NOTE: the pool size comes from batch_size (the YAML parameter); number_of_processes is not used here.
        with Pool(self.batch_size) as pool:
            processed_dataset.repertoires = pool.map(self._process_repertoire, dataset.repertoires)
        return processed_dataset

    def _prepare_group_by_field(self, columns):
        """Return the AIRR fields (in AIRRSequenceSet order) that are present in `columns` and define duplicates.

        The non-selected sequence type, duplicate_count, sequence_id and cell_id are excluded: those are aggregated,
        not compared.
        """
        excluded = {self.sequence_to_ignore, "duplicate_count", "sequence_id", "cell_id"}
        present = set(columns)
        return [field for field in AIRRSequenceSet.get_field_type_dict()
                if field in present and field not in excluded]

    def _prepare_agg_dict(self, columns, custom_lists):
        """Build the pandas aggregation spec for the columns that are not grouped on.

        duplicate_count uses the configured count aggregation; every other aggregated column keeps its first value.
        """
        agg_dict = {"sequence_id": "first"}

        if self.sequence_to_ignore in columns:
            agg_dict[self.sequence_to_ignore] = "first"

        if "duplicate_count" in columns:
            agg_dict["duplicate_count"] = self.count_agg.value

        if "cell_id" in columns:
            agg_dict["cell_id"] = "first"

        for key in custom_lists:
            agg_dict[key] = "first"

        return agg_dict

    def _process_repertoire(self, repertoire: Repertoire) -> Repertoire:
        """Collapse duplicates in a single repertoire and write the result under self.result_path.

        Writes `<data stem>_filtered.tsv` and `<metadata stem>_filtered.yaml`, and returns a new Repertoire (with a
        fresh identifier) that points to them.
        """
        data = repertoire.data.topandas()
        if "duplicate_count" in data.columns:
            # Assign explicitly: an inplace replace on data['duplicate_count'] is chained assignment and does not
            # update `data` under pandas copy-on-write. -1 presumably marks an unknown count; converting it to NA
            # makes the aggregation treat it as missing rather than as a count of -1.
            data["duplicate_count"] = data["duplicate_count"].replace(-1, pd.NA)

        columns = data.columns
        groupby_fields = self._prepare_group_by_field(columns)
        custom_lists = list(repertoire.dynamic_fields.keys())
        agg_dict = self._prepare_agg_dict(columns, custom_lists)

        # dropna=False: rows with a missing value in any grouping column must be kept, not silently removed.
        no_duplicates = data.groupby(groupby_fields, sort=False, dropna=False).agg(agg_dict).reset_index()

        result_path = Path(self.result_path)
        data_path = result_path / f"{repertoire.data_filename.stem}_filtered.tsv"
        metadata_path = result_path / f"{repertoire.metadata_filename.stem}_filtered.yaml"

        no_duplicates.to_csv(data_path, sep='\t', index=False)
        write_yaml(metadata_path, repertoire.metadata)

        return Repertoire(data_filename=data_path, metadata_filename=metadata_path, identifier=uuid4().hex,
                          dynamic_fields=repertoire.dynamic_fields)

    @staticmethod
    def get_documentation():
        """Return the class docstring with enum placeholders replaced by the lists of valid values."""
        doc = str(DuplicateSequenceFilter.__doc__)

        mapping = {
            "Valid options are defined by the CountAggregationFunction enum.": f"Valid values are: {[e.name.lower() for e in CountAggregationFunction]}.",
            "Valid options are defined by the SequenceType enum.": f"Valid values are: {[e.name.lower() for e in SequenceType]}."
        }

        doc = update_docs_per_mapping(doc, mapping)
        return doc