Source code for immuneML.encodings.kmer_frequency.KmerFreqRepertoireEncoder

from multiprocessing.pool import Pool

import dill
import numpy as np

from immuneML.caching.CacheHandler import CacheHandler
from immuneML.caching.CacheObjectType import CacheObjectType
from immuneML.data_model.SequenceSet import Repertoire
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.kmer_frequency.BNPSequenceEncodingStrategies import (
    dispatch_encoding, get_v_genes, kmer_weights, seq_field, V_GENE_ENCODING_TYPES,
)
from immuneML.encodings.kmer_frequency.KmerFrequencyEncoder import KmerFrequencyEncoder


class KmerFreqRepertoireEncoder(KmerFrequencyEncoder):
    """K-mer frequency encoder for repertoire datasets.

    Each repertoire is encoded independently in a process pool. Per-repertoire results are
    memoized through ``CacheHandler``. ``_encode_examples`` then merges the per-repertoire
    k-mers (and optional per-k-mer weights) into flat arrays indexed by repertoire row.
    """

    def _encode_locus(self, dataset) -> bool:
        """Return True if the repertoires of ``dataset`` together contain more than one locus.

        When this is True, the per-sequence locus labels are passed on to ``dispatch_encoding``.
        An empty dataset contains no loci and yields False.
        """
        seen_loci = set()
        for repertoire in dataset.repertoires:
            seen_loci.update(repertoire.data.locus.tolist())
            # A second locus settles the answer; the remaining repertoires need not be inspected.
            if len(seen_loci) > 1:
                return True
        return False

    def _encode_new_dataset(self, dataset, params: EncoderParams):
        """Encode ``dataset`` and return a clone of it carrying the encoded data."""
        encoded_data = self._encode_data(dataset, params)
        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = encoded_data
        return encoded_dataset

    def _encode_examples(self, dataset, params: EncoderParams):
        """Encode every repertoire of ``dataset`` and merge the results into flat arrays.

        :return: tuple ``(flat_kmers, row_ids, weights, repertoire_ids, encoded_labels)`` where

            - ``flat_kmers``: the k-mers of all repertoires, concatenated in repertoire order;
            - ``row_ids``: for each entry of ``flat_kmers``, the index of its repertoire;
            - ``weights``: per-k-mer weights aligned with ``flat_kmers``, or None when no
              repertoire produced weights;
            - ``repertoire_ids``: list of repertoire identifiers, one per row;
            - ``encoded_labels``: ``{label_name: [value per repertoire]}`` when
              ``params.encode_labels`` is set, otherwise None.
        """
        encode_locus = self._encode_locus(dataset)
        # Repertoires are serialized with dill before being shipped to the pool workers;
        # get_encoded_repertoire deserializes them again.
        arguments = [(dill.dumps(rep), params, encode_locus) for rep in dataset.repertoires]

        if not arguments:
            # Nothing to encode: return empty results instead of failing while unpacking an
            # empty result list (and avoid starting a process pool for no work).
            encoded_labels = ({name: [] for name in params.label_config.get_labels_by_name()}
                              if params.encode_labels else None)
            return np.empty(0, dtype=str), np.arange(0), None, [], encoded_labels

        with Pool(params.pool_size) as pool:
            encoded = pool.starmap(self.get_encoded_repertoire, arguments)

        all_flat_kmers, all_weights, repertoire_ids, labels_list = zip(*encoded)

        kmer_counts = np.array([len(kmers) for kmers in all_flat_kmers], dtype=np.intp)
        flat_kmers = np.concatenate(all_flat_kmers) if kmer_counts.any() else np.empty(0, dtype=str)
        # row_ids[i] is the index of the repertoire that flat_kmers[i] came from.
        row_ids = np.repeat(np.arange(len(all_flat_kmers)), kmer_counts)
        weights = self._concatenate_weights(all_flat_kmers, all_weights)

        encoded_labels = None
        if params.encode_labels:
            encoded_labels = {name: [labels[name] for labels in labels_list] for name in labels_list[0]}

        return flat_kmers, row_ids, weights, list(repertoire_ids), encoded_labels

    @staticmethod
    def _concatenate_weights(all_flat_kmers, all_weights):
        """Concatenate per-repertoire weights into one array aligned with the flat k-mers.

        Repertoires without weights contribute a weight of 1.0 per k-mer. Returns None when
        no repertoire produced weights, or when there are no weight entries at all.
        """
        if all(weights is None for weights in all_weights):
            return None
        parts = [np.ones(len(kmers), dtype=float) if weights is None else weights
                 for kmers, weights in zip(all_flat_kmers, all_weights)]
        return np.concatenate(parts) if any(len(part) for part in parts) else None

    def get_encoded_repertoire(self, repertoire, params: EncoderParams, encode_locus: bool):
        """Encode a single repertoire, memoizing the result through ``CacheHandler``.

        :param repertoire: a ``Repertoire``, or its dill-serialized form (as sent to pool workers)
        :param params: encoder parameters; ``params.model`` is overwritten with ``vars(self)``
        :param encode_locus: whether per-sequence locus labels are passed to the k-mer encoding
        :return: ``(flat_kmers, weights, repertoire_identifier, labels)`` as produced by
            ``_encode_single_repertoire``
        """
        rep = repertoire if isinstance(repertoire, Repertoire) else dill.loads(repertoire)
        params.model = vars(self)
        # Every input that changes the result must be part of the cache key. ``encode_locus`` is
        # derived from the whole dataset (not only this repertoire), and ``encode_labels``
        # decides whether the cached result contains labels at all.
        cache_params = (("encoding_model", params.model),
                        ("type", "kmer_encoding"),
                        ("labels", params.label_config.get_labels_by_name()),
                        ("encode_labels", params.encode_labels),
                        ("encode_locus", encode_locus),
                        ("repertoire_id", rep.identifier))
        return CacheHandler.memo_by_params(cache_params,
                                           lambda: self._encode_single_repertoire(rep, params, encode_locus),
                                           CacheObjectType.ENCODING_STEP)

    def _encode_single_repertoire(self, rep: Repertoire, params: EncoderParams, encode_locus: bool):
        """Extract the k-mers (and their weights and labels) of one repertoire.

        :return: ``(flat_kmers, weights, rep.identifier, labels)``. ``flat_kmers`` and ``row_ids``
            come from ``dispatch_encoding``, and ``weights`` from ``kmer_weights(data, self.reads,
            row_ids)``. ``labels`` maps each configured label name to the repertoire's metadata
            value, or is None when ``params.encode_labels`` is not set.
        """
        data = rep.data
        seq_array = getattr(data, seq_field(self.region_type, self.sequence_type))
        locus_labels = data.locus.tolist() if encode_locus else None
        # V genes are only needed by the encoding strategies that use them.
        v_genes = get_v_genes(data) if self.sequence_encoding in V_GENE_ENCODING_TYPES else None

        flat_kmers, row_ids = dispatch_encoding(
            seq_array, self.sequence_encoding, self.k, self.k_left, self.k_right,
            self.min_gap, self.max_gap, self.region_type, v_genes, locus_labels,
        )

        labels = None
        if params.encode_labels:
            labels = {name: rep.metadata[name] for name in params.label_config.get_labels_by_name()}

        return flat_kmers, kmer_weights(data, self.reads, row_ids), rep.identifier, labels