Source code for immuneML.ml_methods.pytorch_implementations.SimpleLSTMGenerator
import torch
from torch import nn
[docs]
class SimpleLSTMGenerator(nn.Module):
    """Single-step LSTM generator: embedding -> LSTM -> linear projection.

    Each call to ``forward`` advances the LSTM by exactly one time step for a
    batch of token indices. Longer sequences are produced by calling it
    repeatedly and feeding back the returned hidden/cell state.

    Args:
        input_size: number of rows in the embedding table; token indices
            passed to ``forward`` must lie in ``[0, input_size)``.
        embed_size: embedding dimension, also used as the LSTM input size.
        hidden_size: size of the LSTM hidden and cell states.
        output_size: output size of the final linear layer (presumably one
            score per symbol — confirm with the caller).
        batch_size: default batch size used by ``init_zero_state``.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, embed_size, hidden_size, output_size, batch_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.embed = nn.Embedding(num_embeddings=input_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=self.num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, features, hidden_and_cell_state):
        """Run the network for a single time step.

        Args:
            features: integer tensor of token indices. It is flattened, so
                every element becomes one batch entry of a length-1 sequence.
            hidden_and_cell_state: ``(h, c)`` tuple, each of shape
                ``(num_layers, batch, hidden_size)``, where ``batch`` equals
                ``features.numel()`` (e.g. from ``init_zero_state``).

        Returns:
            ``(output, (h, c))``: ``output`` has shape ``(batch, output_size)``
            and ``(h, c)`` is the updated LSTM state.
        """
        # reshape (not view) so non-contiguous inputs such as slices or
        # transposes are accepted. The LSTM is not batch_first, so this is
        # (seq_len=1, batch).
        features = features.reshape(1, -1)
        embedded = self.embed(features)  # (1, batch, embed_size)
        output, hidden_and_cell_state = self.lstm(embedded, hidden_and_cell_state)
        # Drop the length-1 sequence axis -> (batch, hidden_size).
        output = output.squeeze(0)
        output = self.fc(output)
        return output, hidden_and_cell_state

    def init_zero_state(self, batch_size=None):
        """Return an all-zero ``(hidden, cell)`` state for ``forward``.

        Args:
            batch_size: batch dimension of the state. When ``None``, the
                ``batch_size`` given to the constructor is used.

        Returns:
            Tuple of two zero tensors of shape
            ``(num_layers, batch_size, hidden_size)``. They are placed on the
            same device and use the same dtype as the module's parameters.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Follow the module's parameters so the state stays compatible with
        # forward() after model.to(device) / model.cuda() / model.double().
        reference = self.fc.weight
        shape = (self.num_layers, batch_size, self.hidden_size)
        init_hidden = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        init_cell = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        return init_hidden, init_cell