import numpy as np
from immuneML.data_model.datasets.ElementDataset import ReceptorDataset
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.kmer_frequency.BNPSequenceEncodingStrategies import (
dispatch_encoding, get_v_genes, kmer_weights, seq_field, V_GENE_ENCODING_TYPES,
)
from immuneML.encodings.kmer_frequency.KmerFrequencyEncoder import KmerFrequencyEncoder
from immuneML.util.EncoderHelper import EncoderHelper
def _encode_chain(data, chain_mask: np.ndarray, locus_name: str,
                  encoder: KmerFrequencyEncoder) -> tuple:
    """Encode the sequences of a single chain with the encoder's settings.

    Args:
        data: object exposing the sequence field (named by ``seq_field``)
            as an attribute that can be indexed with ``chain_mask``.
        chain_mask: array selecting the entries of this chain; its sum is
            used as the number of selected entries (so presumably a boolean
            mask -- confirm against the caller).
        locus_name: locus label repeated once per selected entry.
        encoder: supplies the encoding parameters (``region_type``,
            ``sequence_type``, ``sequence_encoding``, ``k``, ``k_left``,
            ``k_right``, ``min_gap``, ``max_gap``, ``reads``).

    Returns:
        ``(flat_kmers, row_ids, weights)`` where the first two come from
        ``dispatch_encoding`` and ``weights`` is the result of
        ``kmer_weights``.
    """
    field_name = seq_field(encoder.region_type, encoder.sequence_type)
    chain_sequences = getattr(data, field_name)[chain_mask]

    n_selected = int(chain_mask.sum())
    locus_per_sequence = [locus_name] * n_selected

    # V genes are only looked up for encodings that actually use them.
    if encoder.sequence_encoding in V_GENE_ENCODING_TYPES:
        v_genes = get_v_genes(data, chain_mask)
    else:
        v_genes = None

    flat_kmers, row_ids = dispatch_encoding(
        chain_sequences,
        encoder.sequence_encoding,
        encoder.k,
        encoder.k_left,
        encoder.k_right,
        encoder.min_gap,
        encoder.max_gap,
        encoder.region_type,
        v_genes,
        locus_per_sequence,
    )
    weights = kmer_weights(data, encoder.reads, row_ids, chain_mask)
    return flat_kmers, row_ids, weights
# [docs]  (appears to be a Sphinx source-link marker left over from HTML extraction; not code)
class KmerFreqReceptorEncoder(KmerFrequencyEncoder):
    """KmerFrequencyEncoder variant for ``ReceptorDataset`` inputs.

    Both chains of every receptor are encoded with ``_encode_chain`` and the
    per-chain results are concatenated into one flat representation.
    """

    def _encode_locus(self, dataset):
        """Return ``True`` unconditionally; ``dataset`` is not inspected."""
        return True

    def _encode_new_dataset(self, dataset, params: EncoderParams):
        """Return a clone of ``dataset`` carrying the freshly encoded data.

        The encoding is computed by ``self._encode_data`` before the dataset
        is cloned; the input dataset's ``encoded_data`` is not assigned here.
        """
        encoding_result = self._encode_data(dataset, params)
        dataset_copy = dataset.clone()
        dataset_copy.encoded_data = encoding_result
        return dataset_copy

    def _encode_examples(self, dataset: ReceptorDataset, params: EncoderParams):
        """Encode both receptor chains and merge the per-chain results.

        Returns:
            ``(flat_kmers, row_ids, weights, receptor_ids, labels)``.
            ``flat_kmers`` and ``row_ids`` are the per-chain arrays
            concatenated unchanged (no offset is added to the row ids), or
            empty arrays when no chain produced any k-mers. ``weights`` is
            ``None`` when no chain contributed weights. ``labels`` is
            ``None`` unless ``params.encode_labels`` is set.
        """
        data = dataset.data
        receptor_ids, loci, mask1, mask2 = EncoderHelper.get_receptor_chain_masks(dataset)

        kmer_parts, row_id_parts, weight_parts = [], [], []
        for chain_mask, locus_name in zip((mask1, mask2), (loci[0], loci[1])):
            chain_kmers, chain_rows, chain_weights = _encode_chain(
                data, chain_mask, locus_name, self
            )
            # A chain that yielded no k-mers contributes nothing at all.
            if len(chain_kmers) == 0:
                continue
            kmer_parts.append(chain_kmers)
            row_id_parts.append(chain_rows)
            if chain_weights is not None:
                weight_parts.append(chain_weights)

        if kmer_parts:
            flat_kmers = np.concatenate(kmer_parts)
        else:
            flat_kmers = np.empty(0, dtype=str)

        if row_id_parts:
            row_ids = np.concatenate(row_id_parts)
        else:
            row_ids = np.empty(0, dtype=np.intp)

        if weight_parts:
            weights = np.concatenate(weight_parts)
        else:
            weights = None

        if params.encode_labels:
            labels = EncoderHelper.encode_element_dataset_labels(dataset, params.label_config)
        else:
            labels = None

        return flat_kmers, row_ids, weights, receptor_ids, labels