import numpy as np
import torch
from torch import nn
from torch.nn.functional import relu
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
class PyTorchReceptorCNN(nn.Module):
"""
This class implements the ReceptorCNN using PyTorch. This is one specific implementation of the architecture proposed in
:py:obj:`~immuneML.ml_methods.ReceptorCNN.ReceptorCNN`.
"""
def __init__(self, kernel_count: int, kernel_size, positional_channels: int, sequence_type: SequenceType, background_probabilities, chain_names):
    """
    Builds the convolutional kernels for both receptor chains and the final fully-connected layer.

    Args:
        kernel_count: number of output channels of every convolutional layer
        kernel_size: sized collection of kernel widths; one convolutional layer per width is created for each of the two chains
        positional_channels: number of input channels added on top of the sequence alphabet size
        sequence_type: sequence type used to look up the alphabet, which determines the number of content channels
        background_probabilities: stored on the model as-is (presumably consumed by get_max_information_gain, which is defined
            outside of this block -- confirm there)
        chain_names: stored on the model as-is
    """
    super().__init__()

    alphabet_size = len(EnvironmentSettings.get_sequence_alphabet(sequence_type))

    self.background_probabilities = background_probabilities
    self.chain_names = chain_names
    self.positional_channels = positional_channels
    self.in_channels = alphabet_size + positional_channels

    # constants used when rescaling kernel weights to information gain matrices
    self.threshold = 0.1
    self.pseudocount = 0.05

    # computed only after the attributes above are in place, since the helper lives outside this block and may read them
    self.max_information_gain = self.get_max_information_gain()

    # attribute names under which the convolutional layers of each chain are registered on the module
    self.conv_chain_1 = [f"chain_1_kernel_{size}" for size in kernel_size]
    self.conv_chain_2 = [f"chain_2_kernel_{size}" for size in kernel_size]

    for size, chain_1_name, chain_2_name in zip(kernel_size, self.conv_chain_1, self.conv_chain_2):
        # chain 1 first, then chain 2, so that layers are created and registered in the same order for every kernel size
        for layer_name in (chain_1_name, chain_2_name):
            conv_layer = nn.Conv1d(in_channels=self.in_channels, out_channels=kernel_count, kernel_size=size, bias=True)
            # re-draw the weights from a zero-mean normal with standard deviation sqrt(1 / number of weights)
            weight_std = np.sqrt(1 / np.prod(conv_layer.weight.shape))
            conv_layer.weight.data.normal_(0.0, weight_std)
            setattr(self, layer_name, conv_layer)

    # the receptor representation concatenates kernel_count features per kernel size for each of the two chains
    representation_size = kernel_count * len(kernel_size) * 2
    self.fully_connected = nn.Linear(in_features=representation_size, out_features=1, bias=True)
    fc_weight_std = np.sqrt(1 / np.prod(self.fully_connected.weight.shape))
    self.fully_connected.weight.data.normal_(0.0, fc_weight_std)
def forward(self, x):
"""
Implements the forward pass through the network by applying kernels to the one-hot encoded receptors, followed by ReLU activation and max
pooling. The obtained output is then concatenated to get the receptor representation. A fully-connected layer is then applied to the
representation to predict the class assignment.
Args:
x: input data consisting of one-hot encoded immune receptors with optional positional information
Returns:
predictions of class assignment
"""
# creates batch_size x kernel_count representation of chain 1 and chain 2 by applying kernels, followed by relu and global max pooling
chain_1 = torch.cat([torch.max(relu(conv_kernel(x[:, 0])), dim=2)[0] for conv_kernel in
[getattr(self, name) for name in self.conv_chain_1]], dim=1)
chain_2 = torch.cat([torch.max(relu(conv_kernel(x[:, 0])), dim=2)[0] for conv_kernel in
[getattr(self, name) for name in self.conv_chain_2]], dim=1)
# creates a single representation of the receptor from chain representations
receptor = torch.cat([chain_1, chain_2], dim=1).squeeze()
# predict the class of the receptor through the final fully-connected layer based on the inferred representations
predictions = self.fully_connected(receptor).squeeze()
return predictions
def rescale_weights_for_IGM(self):
"""Rescales the weights in the kernels to represent information gain matrices."""
for name in self.conv_chain_1 + self.conv_chain_2:
value = getattr(self, name)
value.weight = self._rescale_chain(value.weight)
def _rescale_chain(self, weight_parameter):
# dimension for the content (without positional information):
dim = self.in_channels - self.positional_channels
# enforce non-negativity constraint
weight_chain = relu(weight_parameter[:, :dim, :])
# add pseudocount for positions where the total sum is below threshold
weight_chain[torch.sum(weight_chain, dim=1, keepdim=True).expand_as(weight_chain) < self.threshold] += self.pseudocount
# rescale chain weights to represent IGM
weight_chain = weight_chain / torch.sum(weight_chain, dim=1, keepdim=True).expand_as(weight_chain) * self.max_information_gain
# append positional kernel values to rescaled weights and convert back to parameter
weight_parameter = nn.Parameter(torch.cat([weight_chain, weight_parameter[:, dim:, :]], dim=1))
return weight_parameter