Source code for immuneML.ml_methods.generative_models.SequenceTransitionDistribution

import dataclasses
from numbers import Number

import numpy as np
import scipy.special
from bionumpy.encoded_array import Encoding, as_encoded_array, EncodedRaggedArray, EncodedArray
from bionumpy.encodings import AlphabetEncoding, AminoAcidEncoding
from npstructures import RaggedArray, RaggedShape


class EncodedLookup(np.lib.mixins.NDArrayOperatorsMixin):
    """Array-like lookup table whose last axis is indexed by letters of a fixed encoding instead of raw integers."""

    def __init__(self, lookup: np.ndarray, encoding: Encoding):
        self._lookup = np.asanyarray(lookup)
        self._encoding = encoding

    def __repr__(self):
        return repr(self._lookup)[:20]

    @property
    def alphabet_size(self):
        return self._lookup.shape[-1]

    @property
    def encoding(self):
        return self._encoding

    def raw(self):
        return self._lookup

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Apply the ufunc to the underlying arrays and wrap the result in a new EncodedLookup.
        assert method == "__call__"
        if isinstance(inputs, tuple):
            inputs = tuple(i._lookup if isinstance(i, EncodedLookup) else i for i in inputs)
        else:
            inputs = inputs._lookup if isinstance(inputs, EncodedLookup) else inputs
        return self.__class__(ufunc(*inputs, **kwargs), self._encoding)

    def __getitem__(self, key):
        assert not isinstance(key, Number), key
        key, shape = self._translate_key(key)
        value = self._lookup[key]
        if shape:
            # Ragged keys are looked up flat and re-wrapped with their original ragged shape.
            value = RaggedArray(value, shape[-1])
        return value

    def __setitem__(self, key, value):
        key, shape = self._translate_key(key)
        self._lookup[key] = value

    def _translate_key(self, key):
        # Encode string/sequence keys to integer indices; flatten ragged keys and remember their shape.
        shape = None
        if isinstance(key, tuple):
            key = tuple(as_encoded_array(i, self._encoding).raw() if (i is not Ellipsis and i is not None) else i
                        for i in key)
            shape = tuple(i.shape if isinstance(i, RaggedArray) else (1,) for i in key)
            key = tuple(i.ravel() if isinstance(i, RaggedArray) else i for i in key)
        else:
            key = as_encoded_array(key, self._encoding).raw()
            if isinstance(key, RaggedArray):
                shape = key.shape
                key = key.ravel()
        return key, shape
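
The lookup therefore behaves like a NumPy array that can be addressed directly with letters. A minimal usage sketch, not part of the module, assuming bionumpy's 20-letter AminoAcidEncoding and that as_encoded_array accepts plain Python strings:

import numpy as np
from bionumpy.encodings import AminoAcidEncoding
from immuneML.ml_methods.generative_models.SequenceTransitionDistribution import EncodedLookup

scores = EncodedLookup(np.linspace(0.0, 1.0, 20), AminoAcidEncoding)  # one value per amino acid
print(scores["ACD"])      # values for the letters A, C and D
shifted = scores + 1.0    # ufuncs dispatch through __array_ufunc__ and return a new EncodedLookup
print(shifted.raw())
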
@dataclasses.dataclass
class SequenceTransitionDistribution:
    # All three lookups hold log-probabilities (see from_probabilities below).
    transition_matrix: EncodedLookup
    initial_distribution: EncodedLookup
    end_probs: EncodedLookup

    @classmethod
    def from_probabilities(cls, *args, **kwargs):
        # Convert probability-valued lookups to log space before constructing the model.
        return cls(*(np.log(arg) for arg in args),
                   **{k: np.log(v) for k, v in kwargs.items()})

    def log_prob(self, sequence):
        # log P(s) = log pi(s_1) + sum_i log T(s_i, s_{i+1}) + log end(s_L) for each sequence s.
        sequence = as_encoded_array(sequence, self.transition_matrix.encoding)
        prev = sequence[..., :-1]
        next = sequence[..., 1:]
        transition_log_probs = self.transition_matrix[(prev, next)].sum(axis=-1)
        start_log_probs = self.initial_distribution[sequence[..., 0]]
        end_log_probs = self.end_probs[sequence[..., -1]]
        return transition_log_probs + start_log_probs + end_log_probs

    def save(self, filename):
        np.savez(filename,
                 transition_matrix=self.transition_matrix._lookup,
                 initial_distribution=self.initial_distribution._lookup,
                 end_probs=self.end_probs._lookup)

    def sample(self, n_samples):
        encoding = self.transition_matrix.encoding
        n_letters = encoding.alphabet_size
        transitions = np.exp(self.transition_matrix.raw())
        initial_probs = np.exp(self.initial_distribution.raw())
        end_probs = np.exp(self.end_probs.raw())
        # Build an augmented (n_letters+1)-state chain: column n_letters is the end symbol,
        # and the extra last row holds the initial distribution (used as the start state).
        matrix = np.hstack((transitions, end_probs[:, None]))
        matrix = np.concatenate([matrix, [np.append(initial_probs, 0)]])
        state = -1  # start in the appended initial-distribution row
        choices = np.arange(n_letters + 1)
        data = []
        lens = []
        len_counter = 0
        while len(lens) < n_samples:
            assert np.allclose(matrix[state].sum(), 1), matrix[state]
            state = np.random.choice(choices, p=matrix[state])
            if state == n_letters:
                # End symbol drawn: close the current sequence and restart from the initial distribution.
                lens.append(len_counter)
                len_counter = 0
                continue
            len_counter += 1
            data.append(state)
        assert len(lens) == n_samples, (len(lens), n_samples)
        return EncodedRaggedArray(EncodedArray(np.array(data, dtype=int), encoding), lens)

    @classmethod
    def load(cls, filename):
        # The encoding is not stored in the file; the amino-acid encoding is assumed on load.
        encoding = AminoAcidEncoding
        data = np.load(filename)
        return cls(EncodedLookup(data['transition_matrix'], encoding),
                   EncodedLookup(data['initial_distribution'], encoding),
                   EncodedLookup(data['end_probs'], encoding))
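
A minimal end-to-end sketch, not part of the module: build a uniform toy model over the amino-acid alphabet with from_probabilities (the probabilities are chosen so that each transition row plus its end probability sums to 1, which sample checks internally), draw a few sequences, score them with log_prob and round-trip the model through save/load. The file name toy_model.npz is illustrative only.

import numpy as np
from bionumpy.encodings import AminoAcidEncoding
from immuneML.ml_methods.generative_models.SequenceTransitionDistribution import (
    EncodedLookup, SequenceTransitionDistribution)

n = AminoAcidEncoding.alphabet_size           # 20 letters
transition = np.full((n, n), 0.9 / n)         # each row sums to 0.9 ...
end = np.full(n, 0.1)                         # ... plus a 0.1 chance of ending the sequence
initial = np.full(n, 1.0 / n)

model = SequenceTransitionDistribution.from_probabilities(
    EncodedLookup(transition, AminoAcidEncoding),
    EncodedLookup(initial, AminoAcidEncoding),
    EncodedLookup(end, AminoAcidEncoding))

samples = model.sample(5)                     # EncodedRaggedArray with 5 sequences
print(model.log_prob(samples))                # one log-likelihood per sampled sequence

model.save("toy_model.npz")
reloaded = SequenceTransitionDistribution.load("toy_model.npz")
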

def estimate_transition_model(sequences, weights=None):
    # Estimate initial, transition and end probabilities from an encoded ragged array of
    # sequences, with optional per-sequence weights and pseudo-count smoothing.
    pseudo_count = 1
    n = len(sequences)
    encoding: AlphabetEncoding = sequences.encoding
    n_letters = encoding.alphabet_size
    transition_pseudo = pseudo_count * np.mean(sequences.lengths) / (n_letters + 1) / n_letters
    initial_distribution = np.bincount(sequences[..., 0].raw(), weights=weights,
                                       minlength=n_letters) + pseudo_count / n_letters
    end_probs = np.bincount(sequences[..., -1].raw(), weights=weights,
                            minlength=n_letters) + transition_pseudo
    # denominator = np.bincount(sequences.raw().ravel(), minlength=n_letters) + pseudo_count
    if weights is not None:
        # Broadcast each sequence weight to all of that sequence's transitions.
        weights = RaggedShape(sequences.lengths - 1).broadcast_values(weights[:, None]).ravel()
    transition_matrix = np.bincount(
        np.ravel_multi_index((sequences[..., :-1].raw().ravel(), sequences[..., 1:].raw().ravel()),
                             (n_letters, n_letters)),
        weights=weights, minlength=n_letters ** 2).reshape((n_letters, n_letters)) + transition_pseudo
    # Normalise so that each transition row plus its end probability sums to 1.
    row_sums = transition_matrix.sum(axis=1) + end_probs
    transition_matrix = transition_matrix / row_sums[..., None]
    initial_distribution = initial_distribution / initial_distribution.sum()
    end_probs = end_probs / row_sums
    prob_sum = np.sum(transition_matrix, axis=-1) + end_probs
    assert np.allclose(prob_sum, 1), (np.sum(transition_matrix, axis=-1), end_probs)
    return SequenceTransitionDistribution.from_probabilities(
        EncodedLookup(transition_matrix, encoding),
        EncodedLookup(initial_distribution, encoding),
        EncodedLookup(end_probs, encoding))
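
A minimal fitting sketch, not part of the module: encode a few amino-acid strings with bionumpy's as_encoded_array, estimate the transition model, then score and sample from the returned SequenceTransitionDistribution. The example strings are arbitrary illustrations.

from bionumpy.encoded_array import as_encoded_array
from bionumpy.encodings import AminoAcidEncoding
from immuneML.ml_methods.generative_models.SequenceTransitionDistribution import estimate_transition_model

training = as_encoded_array(["CASSLG", "CASSPG", "CASRLG", "CASSLGY"], AminoAcidEncoding)
model = estimate_transition_model(training)   # unweighted, with pseudo-count smoothing
print(model.log_prob(training))               # log-likelihood of each training sequence
print(model.sample(3))                        # three new sequences from the fitted chain

Per-sequence weights can optionally be supplied through the weights argument; they are broadcast to every transition of the corresponding sequence before the counts are accumulated.
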