# Source code for immuneML.encodings.onehot.OneHotEncoder

import abc
import math

import numpy as np

from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model import bnp_util
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.encodings.DatasetEncoder import DatasetEncoder
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.util.EncoderHelper import EncoderHelper
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.ReflectionHandler import ReflectionHandler


class OneHotEncoder(DatasetEncoder):
    r"""
    One-hot encoding for repertoires, sequences or receptors. In one-hot encoding, each alphabet character
    (amino acid or nucleotide) is replaced by a sparse vector with one 1 and the rest zeroes. The position of the
    1 represents the alphabet character.

    **Dataset type:**

    - SequenceDatasets

    - ReceptorDatasets

    - RepertoireDatasets


    **Specification arguments:**

    - use_positional_info (bool): whether to include features representing the positional information.
      If True, three additional feature vectors will be added, representing the sequence start, sequence middle
      and sequence end. The values in these features are scaled between 0 and 1. A graphical representation of
      the values of these vectors is given below.

      .. code-block:: console

          Value of sequence start:         Value of sequence middle:        Value of sequence end:

          1 \                              1    /‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾\         1                          /
             \                                 /                   \                                  /
              \                               /                     \                                /
          0    \_____________________      0 /                       \      0  _____________________/
            <----sequence length---->        <----sequence length---->         <----sequence length---->

    - distance_to_seq_middle (int): only applies when use_positional_info is True. This is the distance from the edge
      of the CDR3 sequence (IMGT positions 105 and 117) to the portion of the sequence that is considered 'middle'.
      For example: if distance_to_seq_middle is 6 (default), all IMGT positions in the interval [111, 112)
      receive positional value 1.
      When using nucleotide sequences: note that the distance is measured in (amino acid) IMGT positions; it is
      multiplied by 3 internally to obtain the distance in nucleotides.
      If the complete sequence length is smaller than 2 * distance_to_seq_middle, the minimum value of the
      'start' and 'end' vectors will not reach 0, and the maximum value of the 'middle' vector will not reach 1.
      A graphical representation of the positional vectors with a too short sequence is given below:

      .. code-block:: console

          Value of sequence start         Value of sequence middle        Value of sequence end:
          with very short sequence:       with very short sequence:       with very short sequence:

               1 \                               1                                 1    /
                  \                                                                    /
                   \                                /\                                /
               0                                 0 /  \                            0
                 <->                               <-->                               <->

    - flatten (bool): whether to flatten the final onehot matrix to a 2-dimensional matrix [examples, other_dims_combined]
      This must be set to True when using onehot encoding in combination with scikit-learn ML methods (inheriting
      :py:obj:`~source.ml_methods.SklearnMethod.SklearnMethod`), such as :ref:`LogisticRegression`, :ref:`SVM`,
      :ref:`SVC`, :ref:`RandomForestClassifier` and :ref:`KNN`.

    - sequence_type: whether to use nucleotide or amino acid sequence for encoding. Valid values are 'nucleotide' and
      'amino_acid'.


    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            encodings:
                one_hot_vanilla:
                    OneHot:
                        use_positional_info: False
                        flatten: False
                        sequence_type: amino_acid

                one_hot_positional:
                    OneHot:
                        use_positional_info: True
                        distance_to_seq_middle: 3
                        flatten: False
                        sequence_type: nucleotide

    """

    # Maps the class name of the dataset to the name of the concrete encoder class that build_object instantiates.
    dataset_mapping = {
        "RepertoireDataset": "OneHotRepertoireEncoder",
        "SequenceDataset": "OneHotSequenceEncoder",
        "ReceptorDataset": "OneHotReceptorEncoder"
    }

    def __init__(self, use_positional_info: bool, distance_to_seq_middle: int, flatten: bool, name: str = None,
                 sequence_type: SequenceType = None):
        """
        Store the encoder settings and precompute the linear ramps used for the positional features.

        Args:
            use_positional_info: whether the start/mid/end positional features are appended to the one-hot vectors.
            distance_to_seq_middle: ramp length in amino acid positions, or None when positional info is not used.
            flatten: stored on the instance; not used by the methods of this base class.
            name: passed on to the DatasetEncoder constructor.
            sequence_type: determines the alphabet and whether distance_to_seq_middle is converted to nucleotides.
        """
        super().__init__(name=name)
        self.use_positional_info = use_positional_info
        self.distance_to_seq_middle = distance_to_seq_middle
        self.flatten = flatten
        self.sequence_type = sequence_type
        self.alphabet = EnvironmentSettings.get_sequence_alphabet(self.sequence_type)

        # The distance is specified in amino acid positions. It has to be converted to nucleotides BEFORE the ramps
        # are built: the weight helpers below rely on len(pos_increasing) == distance_to_seq_middle, otherwise the
        # start/mid/end vectors do not add up to the sequence length.
        if self.sequence_type == SequenceType.NUCLEOTIDE and self.distance_to_seq_middle is not None:
            self.distance_to_seq_middle = self.distance_to_seq_middle * 3

        ramp_len = self.distance_to_seq_middle
        if ramp_len:
            # Values 0, 1/d, ..., (d-1)/d and the same values in reverse order.
            self.pos_increasing = [1 / ramp_len * i for i in range(ramp_len)]
            self.pos_decreasing = self.pos_increasing[::-1]
        else:
            self.pos_increasing = None
            self.pos_decreasing = None

        self.onehot_dimensions = self.alphabet + ["start", "mid", "end"] if self.use_positional_info else self.alphabet  # todo test this

    @staticmethod
    def _prepare_parameters(use_positional_info: bool, distance_to_seq_middle: int, flatten: bool, sequence_type: str,
                            name: str = None):
        """
        Validate the user-supplied parameters and convert them to the keyword arguments of the constructor.

        distance_to_seq_middle is only validated (int, at least 1) when use_positional_info is True; otherwise it is
        replaced by None. sequence_type is matched case-insensitively against the names of SequenceType.

        Returns:
            dict with the keys use_positional_info, distance_to_seq_middle, flatten, sequence_type and name.
        """
        location = OneHotEncoder.__name__

        ParameterValidator.assert_type_and_value(use_positional_info, bool, location, "use_positional_info")
        if use_positional_info:
            ParameterValidator.assert_type_and_value(distance_to_seq_middle, int, location, "distance_to_seq_middle",
                                                     min_inclusive=1)
        else:
            distance_to_seq_middle = None

        ParameterValidator.assert_type_and_value(flatten, bool, location, "flatten")
        ParameterValidator.assert_type_and_value(sequence_type, str, location, 'sequence_type')
        ParameterValidator.assert_in_valid_list(sequence_type.upper(), [item.name for item in SequenceType], location,
                                                'sequence_type')

        return {"use_positional_info": use_positional_info,
                "distance_to_seq_middle": distance_to_seq_middle,
                "flatten": flatten,
                "sequence_type": SequenceType[sequence_type.upper()],
                "name": name}

    @staticmethod
    def build_object(dataset=None, **params):
        """
        Create the encoder subclass listed in dataset_mapping for the class of the given dataset.

        The dataset type is checked through EncoderHelper, the parameters are validated by _prepare_parameters and
        the concrete class is looked up by name through ReflectionHandler.
        """
        EncoderHelper.check_dataset_type_available_in_mapping(dataset, OneHotEncoder)
        prepared_params = OneHotEncoder._prepare_parameters(**params)
        encoder_class = ReflectionHandler.get_class_by_name(OneHotEncoder.dataset_mapping[dataset.__class__.__name__],
                                                            "onehot/")
        return encoder_class(**prepared_params)

    def encode(self, dataset, params: EncoderParams):
        """
        Return the encoded dataset, obtained through CacheHandler.memo_by_params.

        The key comes from _prepare_caching_params and the value is produced by _encode_new_dataset; presumably the
        cache handler only calls the latter when no result is stored for the key.
        """
        encoded_dataset = CacheHandler.memo_by_params(self._prepare_caching_params(dataset, params),
                                                      lambda: self._encode_new_dataset(dataset, params))
        return encoded_dataset

    def _prepare_caching_params(self, dataset, params: EncoderParams):
        """Build the cache key from the dataset identity, its example ids and type, the labels and all encoder attributes."""
        return (("dataset_identifier", dataset.identifier),
                ("example_identifiers", tuple(dataset.get_example_ids())),
                ("dataset_type", dataset.__class__.__name__),
                ("labels", tuple(params.label_config.get_labels_by_name())),
                ("encoding", OneHotEncoder.__name__),
                ("encoding_params", tuple(vars(self).items())))

    @abc.abstractmethod
    def _encode_new_dataset(self, dataset, params: EncoderParams):
        """Encode the dataset without consulting the cache; implemented by the dataset-specific subclasses."""
        pass

    def _encode_sequence_list(self, sequences: AIRRSequenceSet, params: EncoderParams, pad_n_sequences: int = None,
                              pad_sequence_len: int = None):
        """
        One-hot encode the sequences and zero-pad the result to a fixed size.

        Args:
            sequences: sequence set from which the field chosen by params.region_type / params.sequence_type is read.
            params: encoder parameters supplying region_type and sequence_type.
            pad_n_sequences: number of rows in the output; defaults to the number of sequences when None.
            pad_sequence_len: number of positions per sequence in the output; defaults to the longest sequence
                when None.

        Returns:
            array of shape (pad_n_sequences, pad_sequence_len, n_features), where n_features is the size of the field
            alphabet without '*' and 'X', plus 3 when use_positional_info is set.
        """
        # Get sequence field based on sequence_type
        sequence_field = bnp_util.get_sequence_field_name(params.region_type, params.sequence_type)

        # Extract sequences from AIRRSequenceSet
        sequence_array = getattr(sequences, sequence_field)
        sequence_alphabet = "".join(AIRRSequenceSet.get_field_type_dict()[sequence_field].get_alphabet()) \
            .replace("*", "").replace("X", "")

        # The signature advertises None defaults, so fall back to "no extra padding" instead of failing on arithmetic
        # with None further down.
        if pad_n_sequences is None:
            pad_n_sequences = len(sequence_array)
        if pad_sequence_len is None:
            pad_sequence_len = int(max(sequence_array.lengths, default=0))

        # noinspection PyTypeChecker
        encoded_sequences = np.array([
            np.pad(np.array((sequence_array[i][..., np.newaxis] == sequence_alphabet).tolist()),
                   [[0, pad_sequence_len - sequence_array.lengths[i]], [0, 0]],
                   mode='constant', constant_values=False)
            for i in range(len(sequence_array))
        ])

        # Append all-zero examples until the requested number of sequences is reached.
        if encoded_sequences.shape[0] != pad_n_sequences:
            encoded_sequences = np.concatenate(
                [encoded_sequences,
                 np.zeros((pad_n_sequences - encoded_sequences.shape[0], pad_sequence_len, len(sequence_alphabet)))],
                axis=0)

        # Add positional encoding if needed
        if self.use_positional_info:
            pos_info = ([self._get_imgt_position_weights(seq_len, pad_length=pad_sequence_len).T
                         for seq_len in sequence_array.lengths]
                        + [[[0, 0, 0] for _ in range(pad_sequence_len)]
                           for _ in range(pad_n_sequences - len(sequence_array))])
            pos_info = np.stack(pos_info)

            # Combine one-hot encoding with positional information
            encoded_sequences = np.concatenate([encoded_sequences, pos_info], axis=-1)

        return encoded_sequences

    def _get_imgt_position_weights(self, seq_length, pad_length=None):
        """
        Return the start, mid and end weights of a sequence as an array with one row per feature.

        The end weights are the start weights in reverse order. When pad_length is given, zeros are appended along
        the position axis so that each row has pad_length entries.
        """
        start_weights = self._get_imgt_start_weights(seq_length)
        mid_weights = self._get_imgt_mid_weights(seq_length)
        end_weights = start_weights[::-1]

        weights = np.array([start_weights, mid_weights, end_weights])

        if pad_length is not None:
            weights = np.pad(weights, pad_width=((0, 0), (0, pad_length - seq_length)))

        return weights

    def _get_imgt_mid_weights(self, seq_length):
        """
        Return seq_length weights that ramp up, stay at 1 in the middle and ramp down again.

        For sequences shorter than 2 * distance_to_seq_middle only the outer parts of both ramps are used, so the
        plateau at 1 is not reached.
        """
        mid_len = seq_length - (self.distance_to_seq_middle * 2)

        if mid_len >= 0:
            mid_weights = self.pos_increasing + [1] * mid_len + self.pos_decreasing
        else:
            left_idx = math.ceil(seq_length / 2)
            right_idx = math.floor(seq_length / 2)
            # Index from the front: a negative slice start would turn into [-0:] for right_idx == 0 and return the
            # whole ramp instead of nothing.
            tail_start = len(self.pos_decreasing) - right_idx
            mid_weights = self.pos_increasing[:left_idx] + self.pos_decreasing[tail_start:]

        return mid_weights

    def _get_imgt_start_weights(self, seq_length):
        """
        Return seq_length weights that start at 1 and decrease linearly to 0 over distance_to_seq_middle positions.

        For sequences shorter than distance_to_seq_middle + 1 the ramp is cut off before it reaches 0.
        """
        if seq_length <= 0:
            # No positions, so no weights; keeps the length equal to the one returned by _get_imgt_mid_weights.
            return []

        diff = (seq_length - self.distance_to_seq_middle) - 1
        if diff >= 0:
            start_weights = [1] + self.pos_decreasing + [0] * diff
        else:
            start_weights = [1] + self.pos_decreasing[:diff]

        return start_weights