import shutil
from itertools import chain
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from immuneML.data_model.bnp_util import write_yaml, read_yaml, get_sequence_field_name
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.data_model.SequenceParams import RegionType, Chain
from immuneML.data_model.SequenceSet import ReceptorSequence
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.pytorch_implementations.SimpleLSTMGenerator import SimpleLSTMGenerator
from immuneML.ml_methods.util.pytorch_util import store_weights
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
class SimpleLSTM(GenerativeModel):
"""
This is a simple generative model for receptor sequences based on LSTM.
Similar models have been proposed in:
Akbar, R. et al. (2022). In silico proof of principle of machine learning-based antibody design at unconstrained scale. mAbs, 14(1), 2031482. https://doi.org/10.1080/19420862.2022.2031482
Saka, K. et al. (2021). Antibody design using LSTM based deep generative model from phage display library for affinity maturation. Scientific Reports, 11(1), Article 1. https://doi.org/10.1038/s41598-021-85274-7
    **Specification arguments:**

    - sequence_type (str): whether the model should work on the amino_acid or nucleotide level
    - hidden_size (int): how many LSTM cells should exist per layer
    - num_layers (int): how many hidden LSTM layers there should be
    - num_epochs (int): for how many epochs to train the model
    - learning_rate (float): what learning rate to use for optimization
    - batch_size (int): how many examples (sequences) to use in one training batch
    - embed_size (int): the dimension of the sequence embedding
    - temperature (float): sampling temperature used when generating new sequences; higher values yield more diverse but noisier sequences
    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            ml_methods:
                my_simple_lstm:
                    sequence_type: amino_acid
                    hidden_size: 50
                    num_layers: 1
                    num_epochs: 5000
                    learning_rate: 0.001
                    batch_size: 100
                    embed_size: 100
"""
@classmethod
def load_model(cls, path: Path):
assert path.exists(), f"{cls.__name__}: {path} does not exist."
model_overview_file = path / 'model_overview.yaml'
state_dict_file = path / 'state_dict.yaml'
for file in [model_overview_file, state_dict_file]:
assert file.exists(), f"{cls.__name__}: {file} is not a file."
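        # model_overview.yaml stores the constructor arguments, state_dict.yaml the trained weights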
model_overview = read_yaml(model_overview_file)
lstm = SimpleLSTM(**{k: v for k, v in model_overview.items() if k != 'type'})
lstm._model = lstm.make_new_model(state_dict_file)
return lstm
ITER_TO_REPORT = 100
def __init__(self, locus: str, sequence_type: str, hidden_size: int, learning_rate: float, num_epochs: int,
batch_size: int, num_layers: int, embed_size: int, temperature, device: str, name=None,
                 region_type: str = RegionType.IMGT_CDR3.name):
super().__init__(Chain.get_chain(locus))
self._model = None
self.region_type = RegionType[region_type.upper()] if region_type else None
self.sequence_type = SequenceType[sequence_type.upper()]
self.hidden_size = hidden_size
self.learning_rate = learning_rate
self.num_layers = num_layers
self.num_epochs = num_epochs
self.batch_size = batch_size
self.embed_size = embed_size
self.temperature = temperature
self.name = name
self.device = device
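        # '*' is added to the alphabet as the end-of-sequence token used during encoding and generation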
self.unique_letters = EnvironmentSettings.get_sequence_alphabet(self.sequence_type) + ["*"]
self.num_letters = len(self.unique_letters)
self.letter_to_index = {letter: i for i, letter in enumerate(self.unique_letters)}
self.index_to_letter = {i: letter for letter, i in self.letter_to_index.items()}
self.loss_summary_path = None
def make_new_model(self, state_dict_file: Path = None):
model = SimpleLSTMGenerator(input_size=self.num_letters, hidden_size=self.hidden_size,
embed_size=self.embed_size, output_size=self.num_letters,
batch_size=self.batch_size)
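        # optionally restore trained weights from a yaml-serialized state dict before moving the model to the device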
if isinstance(state_dict_file, Path) and state_dict_file.is_file():
state_dict = read_yaml(state_dict_file)
state_dict = {k: torch.as_tensor(v, device=self.device) for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(self.device)
return model
def fit(self, data, path: Path = None):
data_loader = self._encode_dataset(data)
model = self.make_new_model()
model.train()
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), self.learning_rate)
loss_summary = {"loss": [], "epoch": []}
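        # the sum-reduced cross-entropy loss is accumulated over all batches, normalized by the dataset size,
        # and a single optimizer step is taken per epoch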
with torch.autograd.set_detect_anomaly(True):
for epoch in range(self.num_epochs):
loss = 0.
state = model.init_zero_state()
optimizer.zero_grad()
for x_batch, y_batch in data_loader:
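                    # trim the recurrent state to the actual batch size (the final batch can be smaller than batch_size)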
state = state[0][:, :x_batch.size(0), :], state[1][:, :x_batch.size(0), :]
outputs, state = model(x_batch, state)
loss = loss + criterion(outputs, y_batch)
loss = loss / len(data_loader.dataset)
loss.backward()
optimizer.step()
loss_summary = self._log_training_progress(loss_summary, epoch, loss)
self._log_training_summary(loss_summary, path)
self._model = model
def _log_training_summary(self, loss_summary, path):
if path is not None:
PathBuilder.build(path)
self.loss_summary_path = path / 'loss_summary.csv'
pd.DataFrame(loss_summary).to_csv(str(self.loss_summary_path), index=False)
def _log_training_progress(self, loss_summary, epoch, loss):
if (epoch + 1) % SimpleLSTM.ITER_TO_REPORT == 0:
print_log(f"{SimpleLSTM.__name__}: Epoch [{epoch + 1}/{self.num_epochs}]: loss: {loss.item():.4f}",
True)
loss_summary['loss'].append(loss.item())
loss_summary['epoch'].append(epoch + 1)
return loss_summary
def _encode_dataset(self, dataset: SequenceDataset):
seq_col = get_sequence_field_name(self.region_type, self.sequence_type)
sequences = dataset.get_attribute(seq_col).tolist()
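        # concatenate all sequences into one index stream, terminating each with the '*' token;
        # inputs and targets are the same stream shifted by one position (next-letter prediction)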
sequences = list(chain.from_iterable(
[[self.letter_to_index[letter] for letter in seq] + [self.letter_to_index['*']] for seq in sequences]))
sequences = torch.as_tensor(sequences, device=self.device).long()
return DataLoader(TensorDataset(sequences[:-1], sequences[1:]), shuffle=True, batch_size=self.batch_size)
def is_same(self, model) -> bool:
raise NotImplementedError
def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType, compute_p_gen: bool):
torch.manual_seed(seed)
self._model.eval()
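        # generation is primed with a fixed prefix (a typical CDR3 start); the primed fragment is later
        # discarded when the generated text is split on the '*' end-of-sequence token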
prime_str = "CAS"
input_vector = torch.as_tensor([self.letter_to_index[letter] for letter in prime_str], device=self.device).long()
predicted = prime_str
with torch.no_grad():
state = self._model.init_zero_state(batch_size=1)
for p in range(len(prime_str) - 1):
_, state = self._model(input_vector[p], state)
inp = input_vector[-1]
gen_seq_count = 0
while gen_seq_count <= count:
output, state = self._model(inp, state)
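                # divide the output scores by the temperature and exponentiate; torch.multinomial then samples
                # the next letter from these unnormalized weights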
output_dist = output.data.view(-1).div(self.temperature).exp()
top_i = torch.multinomial(output_dist, 1)[0].item()
predicted_char = self.index_to_letter[top_i]
predicted += predicted_char
inp = torch.as_tensor(self.letter_to_index[predicted_char], device=self.device).long()
if predicted_char == "*":
gen_seq_count += 1
print_log(f"{SimpleLSTM.__name__} {self.name}: generated {count} sequences.", True)
sequences = predicted.split('*')[1:-1]
return self._export_dataset(sequences, count, path)
def _export_dataset(self, sequences, count, path):
sequence_objs = [ReceptorSequence(**{
self.sequence_type.value: sequence,
'locus': self.locus.name, 'metadata': {'gen_model_name': self.name}
        }) for sequence in sequences]
return SequenceDataset.build_from_objects(sequences=sequence_objs, path=PathBuilder.build(path),
name='synthetic_lstm_dataset', labels={'gen_model_name': [self.name]})
def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
raise RuntimeError
def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
raise RuntimeError
def can_compute_p_gens(self) -> bool:
return False
def can_generate_from_skewed_gene_models(self) -> bool:
return False
def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
raise RuntimeError
def save_model(self, path: Path) -> Path:
model_path = PathBuilder.build(path / 'model')
skip_keys_for_export = ['_model', 'loss_summary_path', 'index_to_letter', 'letter_to_index',
'unique_letters', 'num_letters']
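        # derived attributes and the torch module itself are excluded from the overview; they are rebuilt
        # from the stored arguments and the state dict when the model is loaded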
write_yaml(filename=model_path / 'model_overview.yaml',
yaml_dict={**{k: v for k, v in vars(self).items() if k not in skip_keys_for_export},
**{'type': self.__class__.__name__, 'region_type': self.region_type.name,
'sequence_type': self.sequence_type.name, 'locus': self.locus.name}}) # todo add 'dataset_type': 'SequenceDataset',
store_weights(self._model, model_path / 'state_dict.yaml')
return Path(shutil.make_archive(str(path / 'trained_model'), 'zip', str(model_path))).absolute()