Source code for immuneML.encodings.protein_embedding.ProteinEmbeddingEncoder

from abc import ABC, abstractmethod

import numpy as np

from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.EncodedData import EncodedData
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.bnp_util import get_sequence_field_name
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import SequenceDataset, ReceptorDataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.DatasetEncoder import DatasetEncoder
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.environment.SequenceType import SequenceType


class ProteinEmbeddingEncoder(DatasetEncoder, ABC):
    """
    Abstract base class for protein embedding encoders that handles dataset-type specific logic.

    The class dispatches on the dataset type (sequence, receptor or repertoire dataset), obtains the
    embeddings of the amino acid sequences of the configured region type and wraps them in EncodedData.

    Subclasses must implement build_object, _embed_sequence_set, _get_encoding_name, _get_model_link
    and _get_caching_params.
    """

    def __init__(self, region_type: RegionType, name: str = None, num_processes: int = 1, device: str = 'cpu'):
        """
        Store the encoder configuration.

        Args:
            region_type: region of the sequence whose amino acid field is embedded.
            name: name of the encoder instance, forwarded to the parent constructor.
            num_processes: stored on the instance; not used by this base class itself.
            device: stored on the instance; not used by this base class itself.
        """
        super().__init__(name)
        self.region_type = region_type
        self.num_processes = num_processes
        self.device = device

    @staticmethod
    @abstractmethod
    def build_object(dataset: Dataset, **params):
        """Build an encoder object for the given dataset and parameters; implemented by subclasses."""
        pass

    def encode(self, dataset: Dataset, params: EncoderParams) -> Dataset:
        """
        Encode the dataset with the encoding routine matching its type and memoize the result.

        Args:
            dataset: a SequenceDataset, ReceptorDataset or RepertoireDataset.
            params: encoder parameters (the label configuration is read from them).

        Returns:
            the encoded dataset as returned by CacheHandler.memo_by_params.

        Raises:
            RuntimeError: if the dataset is of none of the supported types.
        """
        # the caching parameters are computed before the type check, as in the original call order
        cache_params = self._get_caching_params(dataset, params)

        if isinstance(dataset, SequenceDataset):
            encode_func = self._encode_sequence_dataset
        elif isinstance(dataset, ReceptorDataset):
            encode_func = self._encode_receptor_dataset
        elif isinstance(dataset, RepertoireDataset):
            encode_func = self._encode_repertoire_dataset
        else:
            raise RuntimeError(f"{self.__class__.__name__}: {self.name}: invalid dataset type: {type(dataset)}.")

        return CacheHandler.memo_by_params(cache_params, lambda: encode_func(dataset, params))

    def _encode_sequence_dataset(self, dataset: SequenceDataset, params: EncoderParams):
        """
        Encode a sequence dataset: one example per sequence.

        Labels are read as attributes of dataset.data and example ids are taken from its sequence_id field.
        Returns a clone of the dataset with encoded_data set.
        """
        seq_field = get_sequence_field_name(self.region_type, SequenceType.AMINO_ACID)
        embeddings = self._embed_sequence_set(dataset.data, seq_field)

        labels = {label.name: getattr(dataset.data, label.name).tolist()
                  for label in params.label_config.get_label_objects()}

        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = EncodedData(examples=np.array(embeddings),
                                                   labels=labels,
                                                   example_ids=dataset.data.sequence_id.tolist(),
                                                   encoding=self._get_encoding_name())
        return encoded_dataset

    def _encode_receptor_dataset(self, dataset: ReceptorDataset, params: EncoderParams):
        """
        Encode a receptor dataset: one example per cell_id.

        The embeddings of the two chains of a receptor are concatenated, with the chain whose locus sorts
        first placed first. Labels are taken from the first non-missing value per cell_id and are ordered
        to match the example ids.

        Raises:
            AssertionError: if the data does not contain exactly two loci or if some cell_id does not have
                an embedding for both loci.
        """
        seq_field = get_sequence_field_name(region_type=self.region_type, sequence_type=SequenceType.AMINO_ACID)
        data = dataset.data

        loci = sorted(set(data.locus.tolist()))

        # NOTE: these checks are asserts (as in the original), so they are skipped when Python runs with -O
        assert len(loci) == 2, (
            f"{self.__class__.__name__}: {self.name}: to encode receptor dataset, it has to include "
            f"two different chains, but got: {loci} instead.")

        embeddings = self._embed_sequence_set(data, seq_field)

        cell_ids = data.cell_id.tolist()
        chain_types = data.locus.tolist()

        # embeddings are indexed by row position; split them per locus, keyed by cell_id
        chain1_embeddings = {}
        chain2_embeddings = {}
        for i, (cell_id, chain_type) in enumerate(zip(cell_ids, chain_types)):
            if chain_type == loci[0]:
                chain1_embeddings[cell_id] = embeddings[i]
            else:
                chain2_embeddings[cell_id] = embeddings[i]

        assert set(chain1_embeddings.keys()) == set(chain2_embeddings.keys()), \
            f"{self.__class__.__name__}: {self.name}: some receptors are missing one of the chains"

        # order of first appearance of each cell_id among the rows of the first locus
        receptor_ids = list(chain1_embeddings.keys())
        concatenated_embeddings = np.array([
            np.concatenate([chain1_embeddings[cell_id], chain2_embeddings[cell_id]])
            for cell_id in receptor_ids
        ])

        # groupby sorts the groups by cell_id by default, which generally differs from the order of
        # receptor_ids; select the rows in receptor_ids order so labels stay aligned with the examples
        label_names = params.label_config.get_labels_by_name()
        labels = (data.topandas().groupby('cell_id').first()[label_names]
                  .loc[receptor_ids]
                  .to_dict(orient='list'))

        encoded_dataset = dataset.clone()
        encoded_dataset.encoded_data = EncodedData(
            examples=concatenated_embeddings,
            labels=labels,
            example_ids=receptor_ids,
            encoding=self._get_encoding_name()
        )

        return encoded_dataset

    def _encode_repertoire_dataset(self, dataset: RepertoireDataset, params: EncoderParams) -> Dataset:
        """
        Encode a repertoire dataset: one example per repertoire.

        Each repertoire is represented by the average of the embeddings of its sequences. The per-repertoire
        result is memoized under a key made of the repertoire identifier, the encoder class name, the region
        type name and the model link.
        """
        seq_field = get_sequence_field_name(self.region_type, SequenceType.AMINO_ACID)

        examples = []
        for repertoire in dataset.repertoires:
            cache_key = (repertoire.identifier, self.__class__.__name__, self.region_type.name,
                         self._get_model_link())
            examples.append(CacheHandler.memo_by_params(
                cache_key,
                lambda: self._avg_sequence_set_embedding(
                    embedding=self._embed_sequence_set(repertoire.data, seq_field))))

        encoded_dataset = dataset.clone()
        labels = dataset.get_metadata(params.label_config.get_labels_by_name())
        encoded_dataset.encoded_data = EncodedData(examples=np.array(examples),
                                                   labels=labels,
                                                   example_ids=dataset.get_example_ids(),
                                                   encoding=self._get_encoding_name())
        return encoded_dataset

    def _avg_sequence_set_embedding(self, embedding: np.ndarray) -> np.ndarray:
        """Return the mean of the embedding array along its first axis."""
        return embedding.mean(axis=0)

    @abstractmethod
    def _embed_sequence_set(self, sequence_set: AIRRSequenceSet, seq_field: str) -> np.ndarray:
        """
        Compute the embeddings for the sequences stored under seq_field of the sequence set.

        Callers in this class index the result by row position and average it along the first axis, so it
        is expected to hold one embedding per sequence, in the order of the sequence set.
        """
        pass

    @abstractmethod
    def _get_encoding_name(self) -> str:
        """Return the value stored as the `encoding` of the produced EncodedData."""
        pass

    @abstractmethod
    def _get_model_link(self) -> str:
        """Return the model link; it is used as part of the per-repertoire cache key."""
        pass

    @abstractmethod
    def _get_caching_params(self, dataset, params: EncoderParams):
        """Return the parameters used as the cache key when memoizing the result of encode."""
        pass