Source code for immuneML.encodings.protein_embedding.ProteinEmbeddingEncoder

from abc import ABC, abstractmethod

import numpy as np
from sklearn.preprocessing import StandardScaler

from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.EncodedData import EncodedData
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.bnp_util import get_sequence_field_name
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import SequenceDataset, ReceptorDataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.DatasetEncoder import DatasetEncoder
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.preprocessing.FeatureScaler import FeatureScaler
from immuneML.environment.SequenceType import SequenceType
from immuneML.util.NumpyHelper import NumpyHelper


class ProteinEmbeddingEncoder(DatasetEncoder, ABC):
    """
    Abstract base class for protein embedding encoders that handles dataset-type-specific logic.
    Subclasses must implement the _embed_sequence_set, _get_encoding_name, and _get_model_link
    methods, as well as the build_object factory method.
    """

    def __init__(self, region_type: RegionType, name: str = None, num_processes: int = 1, device: str = 'cpu',
                 batch_size: int = 4096, scale_to_zero_mean: bool = True, scale_to_unit_variance: bool = True):
        super().__init__(name)
        self.region_type = region_type
        self.num_processes = num_processes
        self.device = device
        self.batch_size = batch_size
        self.scale_to_zero_mean = scale_to_zero_mean
        self.scale_to_unit_variance = scale_to_unit_variance

    @staticmethod
    @abstractmethod
    def build_object(dataset: Dataset, **params):
        pass

    def encode(self, dataset: Dataset, params: EncoderParams) -> Dataset:
        cache_params = self._get_caching_params(dataset, params)
        if isinstance(dataset, SequenceDataset):
            return CacheHandler.memo_by_params(cache_params, lambda: self._encode_sequence_dataset(dataset, params))
        elif isinstance(dataset, ReceptorDataset):
            return CacheHandler.memo_by_params(cache_params, lambda: self._encode_receptor_dataset(dataset, params))
        elif isinstance(dataset, RepertoireDataset):
            return CacheHandler.memo_by_params(cache_params, lambda: self._encode_repertoire_dataset(dataset, params))
        else:
            raise RuntimeError(f"{self.__class__.__name__}: {self.name}: invalid dataset type: {type(dataset)}.")

    def _encode_sequence_dataset(self, dataset: SequenceDataset, params: EncoderParams):
        seq_field = get_sequence_field_name(self.region_type, SequenceType.AMINO_ACID)

        embeddings = self._embed_sequence_set(dataset.data, seq_field)
        embeddings = self._scale_examples(dataset, embeddings, params)

        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = EncodedData(
            examples=embeddings,
            labels={label.name: getattr(dataset.data, label.name).tolist()
                    for label in params.label_config.get_label_objects()},
            example_ids=dataset.data.sequence_id.tolist(),
            encoding=self._get_encoding_name())
        return encoded_dataset

    def _encode_receptor_dataset(self, dataset: ReceptorDataset, params: EncoderParams):
        seq_field = get_sequence_field_name(region_type=self.region_type, sequence_type=SequenceType.AMINO_ACID)
        data = dataset.data
        loci = sorted(list(set(data.locus.tolist())))

        assert len(loci) == 2, (
            f"{self.__class__.__name__}: {self.name}: to encode a receptor dataset, it has to include "
            f"two different chains, but got: {loci} instead.")

        embeddings = self._embed_sequence_set(data, seq_field)

        cell_ids = data.cell_id.tolist()
        chain_types = data.locus.tolist()

        # group the per-sequence embeddings by cell: each receptor contributes one embedding per chain
        chain1_embeddings = {}
        chain2_embeddings = {}
        for i, (cell_id, chain_type) in enumerate(zip(cell_ids, chain_types)):
            if chain_type == loci[0]:
                chain1_embeddings[cell_id] = embeddings[i]
            else:
                chain2_embeddings[cell_id] = embeddings[i]

        assert set(chain1_embeddings.keys()) == set(chain2_embeddings.keys()), \
            f"{self.__class__.__name__}: {self.name}: some receptors are missing one of the chains"

        receptor_ids = list(chain1_embeddings.keys())
        # each receptor is represented by the concatenation of its two chain embeddings
        concatenated_embeddings = np.array([
            np.concatenate([chain1_embeddings[cell_id], chain2_embeddings[cell_id]])
            for cell_id in receptor_ids
        ])
        concatenated_embeddings = self._scale_examples(dataset, concatenated_embeddings, params)

        labels = (data.topandas().groupby('cell_id').first()[params.label_config.get_labels_by_name()]
                  .to_dict(orient='list'))

        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = EncodedData(
            examples=concatenated_embeddings,
            labels=labels,
            example_ids=receptor_ids,
            encoding=self._get_encoding_name()
        )
        return encoded_dataset

    def _encode_repertoire_dataset(self, dataset: RepertoireDataset, params: EncoderParams) -> Dataset:
        seq_field = get_sequence_field_name(self.region_type, SequenceType.AMINO_ACID)
        examples = []

        for repertoire in dataset.repertoires:
            # each repertoire is embedded as the average of its sequence embeddings, memoized per repertoire
            examples.append(CacheHandler.memo_by_params(
                (repertoire.identifier, self.__class__.__name__, self.region_type.name, self._get_model_link()),
                lambda: self._avg_sequence_set_embedding(
                    embedding=self._embed_sequence_set(repertoire.data, seq_field))))

        examples = self._scale_examples(dataset, np.array(examples), params)

        encoded_dataset = dataset.clone()
        labels = dataset.get_metadata(params.label_config.get_labels_by_name())
        encoded_dataset.encoded_data = EncodedData(examples=examples, labels=labels,
                                                   example_ids=dataset.get_example_ids(),
                                                   encoding=self._get_encoding_name())
        return encoded_dataset

    def _scale_examples(self, dataset: Dataset, examples: np.ndarray, params: EncoderParams) -> np.ndarray:
        # when learning the model, fit a new scaler on the examples; otherwise reuse the previously fitted one
        if params.learn_model:
            self.scaler = StandardScaler(with_mean=self.scale_to_zero_mean, with_std=self.scale_to_unit_variance)
            examples = CacheHandler.memo_by_params(
                self._get_caching_params(dataset, params, step='scaled'),
                lambda: FeatureScaler.standard_scale_fit(self.scaler, examples, with_mean=self.scale_to_zero_mean))
        else:
            examples = CacheHandler.memo_by_params(
                self._get_caching_params(dataset, params, step='scaled'),
                lambda: FeatureScaler.standard_scale(self.scaler, examples, with_mean=self.scale_to_zero_mean))
        return NumpyHelper.create_memmap_array_in_cache(examples.shape, examples)

    def _avg_sequence_set_embedding(self, embedding: np.ndarray) -> np.ndarray:
        return embedding.mean(axis=0)

    @abstractmethod
    def _embed_sequence_set(self, sequence_set: AIRRSequenceSet, seq_field: str) -> np.ndarray:
        pass

    @abstractmethod
    def _get_encoding_name(self) -> str:
        pass

    @abstractmethod
    def _get_model_link(self) -> str:
        pass

    def _get_caching_params(self, dataset, params: EncoderParams, step: str = None) -> tuple:
        return (dataset.identifier, tuple(params.label_config.get_labels_by_name()), self.scale_to_zero_mean,
                self.scale_to_unit_variance, step, self.region_type.name, self._get_encoding_name(),
                params.learn_model)
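

As an illustration of the contract above, the following is a minimal, hypothetical subclass (not part of immuneML) that embeds each sequence as a 20-dimensional vector of relative amino-acid frequencies. The class name ToyFrequencyEncoder and the frequency logic are invented for this sketch; it assumes that getattr(sequence_set, seq_field).tolist() returns plain amino-acid strings, analogous to the .tolist() calls on other AIRRSequenceSet fields in the base class.

import numpy as np

from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.encodings.protein_embedding.ProteinEmbeddingEncoder import ProteinEmbeddingEncoder

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


class ToyFrequencyEncoder(ProteinEmbeddingEncoder):
    """Hypothetical encoder: embeds each sequence by its relative amino-acid frequencies."""

    @staticmethod
    def build_object(dataset: Dataset, **params):
        # assumes params carries region_type and the other base-class constructor arguments
        return ToyFrequencyEncoder(**params)

    def _embed_sequence_set(self, sequence_set: AIRRSequenceSet, seq_field: str) -> np.ndarray:
        sequences = getattr(sequence_set, seq_field).tolist()
        embeddings = np.zeros((len(sequences), len(AMINO_ACIDS)))
        for i, seq in enumerate(sequences):
            for aa in seq:
                j = AMINO_ACIDS.find(aa)
                if j >= 0:
                    embeddings[i, j] += 1.
            if len(seq) > 0:
                embeddings[i] /= len(seq)  # counts -> relative frequencies
        return embeddings

    def _get_encoding_name(self) -> str:
        return "ToyFrequencyEncoder"

    def _get_model_link(self) -> str:
        # no external model is downloaded; this string only serves as part of the repertoire cache key
        return "toy-frequency-model"

With these four methods in place, encode() above dispatches to the sequence-, receptor-, or repertoire-specific routine, while scaling and caching are handled entirely by the base class.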