import logging
import shutil
from itertools import chain
from pathlib import Path
import numpy as np
import pandas as pd
from immuneML.data_model.SequenceParams import RegionType, Chain
from immuneML.data_model.bnp_util import write_yaml, read_yaml, get_sequence_field_name
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.pytorch_implementations.SimpleLSTMGenerator import SimpleLSTMGenerator
from immuneML.ml_methods.util.pytorch_util import store_weights
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
class SimpleLSTM(GenerativeModel):
"""
    This is a simple generative model for receptor sequences based on LSTM.

    Similar models have been proposed in:

    Akbar, R. et al. (2022). In silico proof of principle of machine learning-based antibody design at unconstrained scale. mAbs, 14(1), 2031482. https://doi.org/10.1080/19420862.2022.2031482

    Saka, K. et al. (2021). Antibody design using LSTM based deep generative model from phage display library for affinity maturation. Scientific Reports, 11(1), Article 1. https://doi.org/10.1038/s41598-021-85274-7

    **Specification arguments:**

    - sequence_type (str): whether the model should work at the amino_acid or nucleotide level

    - hidden_size (int): how many LSTM cells there should be per layer

    - num_layers (int): how many hidden LSTM layers there should be

    - num_epochs (int): for how many epochs to train the model

    - learning_rate (float): the learning rate to use for optimization

    - batch_size (int): how many examples (sequences) to use in one training batch

    - embed_size (int): the dimension of the sequence embedding

    - temperature (float): controls the randomness of sampling during sequence generation; logits are divided by the temperature before softmax, so higher values produce more diverse but noisier sequences

    - prime_str (str): the initial sequence used to prime the model when generating new sequences

    - seed (int): random seed for the model, or None

    - iter_to_report (int): number of epochs between training progress reports

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            ml_methods:
                my_simple_lstm:
                    sequence_type: amino_acid
                    hidden_size: 50
                    num_layers: 1
                    num_epochs: 5000
                    learning_rate: 0.001
                    batch_size: 100
                    embed_size: 100

"""
@classmethod
def load_model(cls, path: Path):
assert path.exists(), f"{cls.__name__}: {path} does not exist."
model_overview_file = path / 'model_overview.yaml'
state_dict_file = path / 'state_dict.yaml'
for file in [model_overview_file, state_dict_file]:
assert file.exists(), f"{cls.__name__}: {file} is not a file."
model_overview = read_yaml(model_overview_file)
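        # rebuild the model from the stored hyperparameters, then restore the trained weights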
lstm = SimpleLSTM(**{k: v for k, v in model_overview.items() if k != 'type'})
lstm._model = lstm.make_new_model(state_dict_file)
return lstm
def __init__(self, locus: str, sequence_type: str, hidden_size: int, learning_rate: float, num_epochs: int,
batch_size: int, num_layers: int, embed_size: int, temperature, device: str, name=None,
region_type: str = RegionType.IMGT_CDR3.name, prime_str: str = "C", window_size: int = 64,
seed: int = None, iter_to_report: int = 1):
super().__init__(Chain.get_chain(locus), region_type=RegionType.get_object(region_type), name=name, seed=seed)
self._model = None
self.sequence_type = SequenceType[sequence_type.upper()] if sequence_type else SequenceType.AMINO_ACID
self.hidden_size = hidden_size
self.learning_rate = learning_rate
self.num_layers = num_layers
self.num_epochs = num_epochs
self.batch_size = batch_size
self.embed_size = embed_size
self.temperature = temperature
self.prime_str = prime_str
self.window_size = window_size
self.device = device
self.iter_to_report = iter_to_report
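        # '*' is appended to the alphabet and used as the end-of-sequence token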
self.unique_letters = EnvironmentSettings.get_sequence_alphabet(self.sequence_type) + ["*"]
self.num_letters = len(self.unique_letters)
self.letter_to_index = {letter: i for i, letter in enumerate(self.unique_letters)}
self.index_to_letter = {i: letter for letter, i in self.letter_to_index.items()}
self.loss_summary_path = None
def make_new_model(self, state_dict_file: Path = None):
from torch import as_tensor
model = SimpleLSTMGenerator(input_size=self.num_letters, hidden_size=self.hidden_size,
embed_size=self.embed_size, output_size=self.num_letters,
batch_size=self.batch_size, device=self.device)
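        # optionally restore trained weights from a YAML state dict (as written by store_weights)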
if isinstance(state_dict_file, Path) and state_dict_file.is_file():
state_dict = read_yaml(state_dict_file)
state_dict = {k: as_tensor(v, device=self.device) for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(self.device)
return model
def _log_training_progress(self, loss_summary, epoch, loss):
if (epoch + 1) % self.iter_to_report == 0:
message = f"{SimpleLSTM.__name__}: Epoch [{epoch + 1}/{self.num_epochs}]: loss: {loss:.4f}"
print_log(message, True)
loss_summary['loss'].append(loss)
loss_summary['epoch'].append(epoch + 1)
return loss_summary
def fit(self, data, path: Path = None):
import torch
from torch import nn, optim
if self.seed is not None:
torch.manual_seed(self.seed)
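        # encode the sequences as overlapping windows of token indices for next-character prediction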
data_loader = self._encode_dataset(data)
model = self.make_new_model()
model.train()
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), self.learning_rate)
loss_summary = {"loss": [], "epoch": []}
for epoch in range(self.num_epochs):
epoch_loss = 0.
num_batches = 0
for x_batch, y_batch in data_loader:
x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
# Initialize state for each batch
state = model.init_zero_state(x_batch.size(0))
optimizer.zero_grad()
outputs, _ = model(x_batch, state)
outputs = outputs.reshape(-1, outputs.size(-1))
y_batch = y_batch.reshape(-1)
loss = criterion(outputs, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_epoch_loss = epoch_loss / num_batches
loss_summary = self._log_training_progress(loss_summary, epoch, avg_epoch_loss)
self._log_training_summary(loss_summary, path)
self._model = model
def _log_training_summary(self, loss_summary, path):
if path is not None:
try:
PathBuilder.build(path)
self.loss_summary_path = path / 'loss_summary.csv'
pd.DataFrame(loss_summary).to_csv(str(self.loss_summary_path), index=False)
except Exception as e:
logging.error(f"{SimpleLSTM.__name__}: failed to save loss summary: {e};\n{loss_summary}")
def _encode_dataset(self, dataset: SequenceDataset):
from torch import as_tensor
from torch.utils.data import DataLoader, TensorDataset
seq_col = get_sequence_field_name(self.region_type, self.sequence_type)
sequences = dataset.get_attribute(seq_col).tolist()
# Flatten all sequences into one long sequence with end tokens
sequences = list(chain.from_iterable(
[[self.letter_to_index[letter] for letter in seq] + [self.letter_to_index['*']] for seq in sequences]))
sequences = as_tensor(sequences, device=self.device).long()
# Create overlapping windows of size window_size
# stride=1 means each window overlaps with previous window by window_size-1 elements
windows = sequences.unfold(0, self.window_size, 1)
# Create input-target pairs from windows
x = windows[:, :-1] # All but last character of each window
y = windows[:, 1:] # All but first character of each window
return DataLoader(TensorDataset(x, y), shuffle=True, batch_size=self.batch_size)
def is_same(self, model) -> bool:
raise NotImplementedError
def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType, compute_p_gen: bool,
max_failed_trials: int = 1000):
import torch
torch.manual_seed(seed)
self._model.eval()
input_vector = torch.as_tensor([self.letter_to_index[letter] for letter in self.prime_str],
device=self.device).long()
predicted = self.prime_str
failed_trials = 0
total_trials = 0
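        # prime the hidden state with prime_str, then sample one character at a time until enough sequences end with '*'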
with torch.no_grad():
state = self._model.init_zero_state(batch_size=1)
# Add sequence and batch dimensions for LSTM input
for p in range(len(self.prime_str) - 1):
inp = input_vector[p].unsqueeze(0).unsqueeze(0)
_, state = self._model(inp, state)
inp = input_vector[-1].unsqueeze(0).unsqueeze(0)
gen_seq_count = 0
while gen_seq_count < count and failed_trials < max_failed_trials:
output, state = self._model(inp, state)
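                # temperature-scaled sampling: divide logits by the temperature and clamp them to avoid overflow in softmax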
scaled_logits = output[0, 0] / self.temperature
scaled_logits = torch.clamp(scaled_logits, min=-100, max=100)
output_dist = torch.nn.functional.softmax(scaled_logits, dim=0)
try:
output_dist = output_dist + 1e-10
output_dist = output_dist / output_dist.sum()
top_i = torch.multinomial(output_dist, 1)[0].item()
except RuntimeError as e:
logging.warning(f"{SimpleLSTM.__name__}: Error sampling from distribution: {e}; "
f"Using argmax instead.")
logging.debug(f"{SimpleLSTM.__name__}: Distribution: {output_dist}\n"
f"Sum: {output_dist.sum()}, Min: {output_dist.min()}, Max: {output_dist.max()}")
top_i = scaled_logits.argmax().item()
predicted_char = self.index_to_letter[top_i]
predicted += predicted_char
inp = torch.as_tensor(self.letter_to_index[predicted_char], device=self.device).long()
inp = inp.unsqueeze(0).unsqueeze(0)
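                # '*' marks the end of a sequence: count it if non-empty, otherwise record a failed trial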
if predicted_char == "*":
last_seq = predicted.split('*')[-2]
if len(last_seq) > 0:
gen_seq_count += 1
print_log(f"Generated valid sequence {gen_seq_count}/{count}", True)
else:
failed_trials += 1
if failed_trials % 10 == 0:
print_log(f"Warning: LSTM model generated {failed_trials} empty sequences", True)
total_trials += 1
print_log(
f"{SimpleLSTM.__name__} {self.name}: generated {gen_seq_count} sequences with {failed_trials} failed attempts.",
True)
dataset = self._export_dataset(predicted, path)
return dataset
def _export_dataset(self, predicted, path):
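        # 'predicted' is one long character stream; split it on the '*' end token and discard empty sequences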
sequences = [seq for seq in predicted.split('*') if len(seq) > 0]
count = len(sequences)
df = pd.DataFrame({get_sequence_field_name(self.region_type, self.sequence_type): sequences,
'locus': [self.locus.to_string() for _ in range(count)],
'gen_model_name': [self.name for _ in range(count)]})
return SequenceDataset.build_from_partial_df(df, PathBuilder.build(path), 'synthetic_lstm_dataset',
{'gen_model_name': [self.name]}, {'gen_model_name': str})
def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
        raise RuntimeError(f"{SimpleLSTM.__name__} cannot compute sequence generation probabilities.")
def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
        raise RuntimeError(f"{SimpleLSTM.__name__} cannot compute sequence generation probabilities.")
def can_compute_p_gens(self) -> bool:
return False
def can_generate_from_skewed_gene_models(self) -> bool:
return False
def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
        raise RuntimeError(f"{SimpleLSTM.__name__} cannot generate sequences from skewed gene models.")
def save_model(self, path: Path) -> Path:
model_path = PathBuilder.build(path / 'model')
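        # exclude derived attributes and the torch model itself from the overview; trained weights are written separately below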
skip_keys_for_export = ['_model', 'loss_summary_path', 'index_to_letter', 'letter_to_index',
'unique_letters', 'num_letters']
write_yaml(filename=model_path / 'model_overview.yaml',
yaml_dict={**{k: v for k, v in vars(self).items() if k not in skip_keys_for_export},
**{'type': self.__class__.__name__, 'region_type': self.region_type.name,
'sequence_type': self.sequence_type.name, 'locus': self.locus.name}})
store_weights(self._model, model_path / 'state_dict.yaml')
return Path(shutil.make_archive(str(path / 'trained_model'), 'zip', str(model_path))).absolute()