Source code for immuneML.ml_methods.pytorch_implementations.SimpleLSTMGenerator
import torch
from torch import nn
class SimpleLSTMGenerator(nn.Module):

    def __init__(self, input_size, embed_size, hidden_size, output_size, batch_size, num_layers=1, device: str = 'cpu'):
        super(SimpleLSTMGenerator, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.device = device

        # embedding layer: maps token indices to embed_size-dimensional vectors
        self.embed = nn.Embedding(num_embeddings=input_size, embedding_dim=embed_size)
        nn.init.normal_(self.embed.weight)

        # LSTM expects inputs of shape (seq_len, batch_size, embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

        # final projection from hidden state to output (vocabulary) logits
        self.fc = nn.Linear(hidden_size, output_size)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
    def forward(self, features, hidden_and_cell_state):
        # features shape: (batch_size, seq_len)
        features = features.transpose(0, 1)  # convert to (seq_len, batch_size)
        embedded = self.embed(features)  # (seq_len, batch_size, embed_size)
        output, hidden_and_cell_state = self.lstm(embedded, hidden_and_cell_state)
        # output shape: (seq_len, batch_size, hidden_size)
        output = self.fc(output)  # (seq_len, batch_size, output_size)
        output = output.transpose(0, 1)  # convert back to (batch_size, seq_len, output_size)
        return output, hidden_and_cell_state
    def init_zero_state(self, batch_size=None):
        # returns zero-initialized (hidden, cell) states, each of shape (num_layers, batch_size, hidden_size)
        if batch_size is None:
            batch_size = self.batch_size
        init_hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        init_cell = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        return init_hidden, init_cell