Source code for immuneML.preprocessing.filters.DuplicateSequenceFilter

import copy
import os
from multiprocessing.pool import Pool
from pathlib import Path
from uuid import uuid4

import pandas as pd

from immuneML.data_model import bnp_util
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.bnp_util import write_yaml, read_yaml
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.environment.SequenceType import SequenceType
from immuneML.preprocessing.filters.CountAggregationFunction import CountAggregationFunction
from immuneML.preprocessing.filters.Filter import Filter
from immuneML.util.ParameterValidator import ParameterValidator
from scripts.specification_util import update_docs_per_mapping


class DuplicateSequenceFilter(Filter):
    """
    Collapses duplicate nucleotide or amino acid sequences within each repertoire in the given RepertoireDataset
    or within a SequenceDataset. This filter can be applied to Repertoires, RepertoireDatasets, and SequenceDatasets.

    Sequences are considered duplicates if the following fields are identical:

    - amino acid or nucleotide sequence (whichever is specified)

    - v and j genes (note that the full field including subgroup + gene is used for matching, i.e. V1 and V1-1 are not considered duplicates)

    - chain

    - region type

    For all other fields (the non-specified sequence type, custom lists, sequence identifier) only the first occurring
    value is kept.

    Note that this means the count value of a sequence with a given sequence identifier might not be the same as
    before removing duplicates, unless count_agg = FIRST is used.

    **Specification arguments:**

    - filter_sequence_type (:py:obj:`~immuneML.environment.SequenceType.SequenceType`): Whether the sequences should be collapsed on the nucleotide or amino acid level. Valid options are defined by the SequenceType enum.

    - region_type (str): which part of the sequence to examine; by default, this is IMGT_CDR3

    - batch_size (int): how many repertoires can be processed in parallel (only relevant for RepertoireDatasets)

    - count_agg (:py:obj:`~immuneML.preprocessing.filters.CountAggregationFunction.CountAggregationFunction`): determines how the sequence counts of duplicate sequences are aggregated. Valid options are defined by the CountAggregationFunction enum.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        preprocessing_sequences:
            my_preprocessing:
                - my_filter:
                    DuplicateSequenceFilter:
                        # required parameters:
                        filter_sequence_type: AMINO_ACID
                        # optional parameters (if not specified the values below will be used):
                        batch_size: 4
                        count_agg: SUM
                        region_type: IMGT_CDR3

    """
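    # Illustration of the matching rule above (hypothetical rows, assuming AIRR-style
    # column names with filter_sequence_type = AMINO_ACID and region_type = IMGT_CDR3):
    #   cdr3_aa=CASSLGF, v_call=TRBV1,   j_call=TRBJ1-1, locus=TRB   } collapsed into
    #   cdr3_aa=CASSLGF, v_call=TRBV1,   j_call=TRBJ1-1, locus=TRB   } a single row
    #   cdr3_aa=CASSLGF, v_call=TRBV1-1, j_call=TRBJ1-1, locus=TRB   -> kept separate (V1 != V1-1)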
    @classmethod
    def build_object(cls, **kwargs):
        location = cls.__name__
        ParameterValidator.assert_keys(kwargs.keys(), ["filter_sequence_type", "batch_size", "count_agg", "region_type"],
                                       location, "DuplicateSequenceFilter")
        ParameterValidator.assert_in_valid_list(kwargs["filter_sequence_type"].upper(), [item.name for item in SequenceType],
                                                location, "filter_sequence_type")
        ParameterValidator.assert_in_valid_list(kwargs["region_type"].upper(), [item.name for item in RegionType],
                                                location, "region_type")
        ParameterValidator.assert_in_valid_list(kwargs["count_agg"].upper(), [item.name for item in CountAggregationFunction],
                                                location, "count_agg")
        ParameterValidator.assert_type_and_value(kwargs["batch_size"], int, location, "batch_size", 1)

        return DuplicateSequenceFilter(filter_sequence_type=SequenceType[kwargs["filter_sequence_type"].upper()],
                                       region_type=RegionType[kwargs['region_type'].upper()],
                                       batch_size=kwargs["batch_size"],
                                       count_agg=CountAggregationFunction[kwargs["count_agg"].upper()])
    def __init__(self, filter_sequence_type: SequenceType, batch_size: int, count_agg: CountAggregationFunction,
                 result_path: Path = None, region_type: RegionType = RegionType.IMGT_CDR3):
        super().__init__(result_path)
        self.filter_sequence_type = filter_sequence_type
        self.region_type = region_type
        self.count_agg = count_agg
        self.batch_size = batch_size

        self.sequence_of_interest = bnp_util.get_sequence_field_name(self.region_type, self.filter_sequence_type)
        self.sequence_to_ignore = bnp_util.get_sequence_field_name(self.region_type,
                                                                   [t for t in SequenceType if t != self.filter_sequence_type][0])
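    # For example (an assumption about bnp_util's field naming, which follows the AIRR schema):
    # with region_type=IMGT_CDR3 and filter_sequence_type=AMINO_ACID, sequence_of_interest
    # would be the "cdr3_aa" column and sequence_to_ignore the nucleotide "cdr3" column.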
    def process_dataset(self, dataset, result_path: Path, number_of_processes=1):
        self.result_path = result_path if result_path is not None else self.result_path

        self.check_dataset_type(dataset, [RepertoireDataset, SequenceDataset], "DuplicateSequenceFilter")

        processed_dataset = dataset.clone()

        if isinstance(dataset, RepertoireDataset):
            with Pool(self.batch_size) as pool:
                repertoires = pool.map(self._process_repertoire, dataset.repertoires)
            processed_dataset.repertoires = repertoires
            return processed_dataset
        elif isinstance(dataset, SequenceDataset):
            processed_dataset = self._process_sequences(dataset)
            return processed_dataset
    def _prepare_group_by_field(self, columns):
        # group by all AIRR fields except the non-matched sequence type, counts and identifiers
        rep_fields = list(AIRRSequenceSet.get_field_type_dict().keys())
        groupby_fields = copy.deepcopy(rep_fields)
        if self.sequence_to_ignore in groupby_fields:
            groupby_fields.remove(self.sequence_to_ignore)
        groupby_fields.remove("duplicate_count")
        groupby_fields.remove("sequence_id")
        groupby_fields.remove("cell_id")

        # drop fields that are not present in this dataset
        for field in set(rep_fields).difference(set(columns)):
            if field in groupby_fields:
                groupby_fields.remove(field)

        return groupby_fields

    def _prepare_agg_dict(self, columns, custom_lists):
        # keep the first occurrence of all non-matched fields; aggregate counts with count_agg
        agg_dict = {"sequence_id": "first"}

        if self.sequence_to_ignore in columns:
            agg_dict[self.sequence_to_ignore] = "first"

        if "duplicate_count" in columns:
            agg_dict["duplicate_count"] = self.count_agg.value

        if "cell_id" in columns:
            agg_dict["cell_id"] = "first"

        for key in custom_lists:
            agg_dict[key] = "first"

        return agg_dict

    def _process_repertoire(self, repertoire: Repertoire) -> Repertoire:
        no_duplicates = self._filter_duplicates(repertoire)

        no_duplicates.to_csv(f"{self.result_path}/{repertoire.data_filename.stem}_filtered.tsv", sep='\t', index=False)
        write_yaml(Path(f"{self.result_path}/{repertoire.metadata_filename.stem}_filtered.yaml"), repertoire.metadata)

        return Repertoire(data_filename=Path(f"{self.result_path}/{repertoire.data_filename.stem}_filtered.tsv"),
                          metadata_filename=Path(f"{self.result_path}/{repertoire.metadata_filename.stem}_filtered.yaml"),
                          identifier=str(uuid4().hex), dynamic_fields=repertoire.dynamic_fields)

    def _process_sequences(self, dataset: SequenceDataset) -> SequenceDataset:
        no_duplicates = self._filter_duplicates(dataset)
        metadata = read_yaml(dataset.dataset_file)

        os.makedirs(self.result_path, exist_ok=True)
        no_duplicates.to_csv(f"{self.result_path}/{dataset.filename.stem}_filtered.tsv", sep='\t', index=False)
        write_yaml(Path(f"{self.result_path}/{dataset.dataset_file.stem}_filtered.yaml"), metadata)

        return SequenceDataset(name=dataset.name,
                               filename=Path(f"{self.result_path}/{dataset.filename.stem}_filtered.tsv"),
                               dataset_file=Path(f"{self.result_path}/{dataset.dataset_file.stem}_filtered.yaml"),
                               dynamic_fields=dataset.dynamic_fields)

    def _filter_duplicates(self, dataset):
        data = dataset.data.topandas()
        # a duplicate_count of -1 denotes a missing value; convert to NA before aggregating
        data['duplicate_count'].replace(-1, pd.NA, inplace=True)
        columns = data.columns

        groupby_fields = self._prepare_group_by_field(columns)
        custom_lists = list(dataset.dynamic_fields.keys())
        agg_dict = self._prepare_agg_dict(columns, custom_lists)

        no_duplicates = data.groupby(groupby_fields, sort=False).agg(agg_dict).reset_index()

        return no_duplicates
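    # A toy illustration of the collapse performed by _filter_duplicates (standalone pandas
    # sketch with hypothetical field names; run it outside this class to inspect the result):
    #
    #   import pandas as pd
    #   data = pd.DataFrame({"sequence_id": ["a", "b", "c"],
    #                        "cdr3_aa": ["CASSLGF", "CASSLGF", "CASSQF"],
    #                        "v_call": ["TRBV1", "TRBV1", "TRBV2"],
    #                        "duplicate_count": [3, 2, 5]})
    #   collapsed = data.groupby(["cdr3_aa", "v_call"], sort=False).agg(
    #       {"sequence_id": "first", "duplicate_count": "sum"}).reset_index()
    #   # rows "a" and "b" collapse into a single row with duplicate_count == 5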
    @staticmethod
    def get_documentation():
        doc = str(DuplicateSequenceFilter.__doc__)

        mapping = {
            "Valid options are defined by the CountAggregationFunction enum.": f"Valid values are: {[e.name.lower() for e in CountAggregationFunction]}.",
            "Valid options are defined by the SequenceType enum.": f"Valid values are: {[e.name.lower() for e in SequenceType]}."
        }
        doc = update_docs_per_mapping(doc, mapping)
        return doc
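
# A minimal end-to-end usage sketch (assumes `dataset` is an already-loaded
# RepertoireDataset or SequenceDataset; the result path is illustrative):
#
#   from pathlib import Path
#
#   dedup_filter = DuplicateSequenceFilter.build_object(filter_sequence_type="AMINO_ACID",
#                                                       region_type="IMGT_CDR3",
#                                                       count_agg="SUM", batch_size=4)
#   filtered_dataset = dedup_filter.process_dataset(dataset, Path("deduplicated/"))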