import logging
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.stats
import torch.optim
from torch.distributions import Categorical
from torch.nn.functional import cross_entropy, one_hot
from torch.utils.data import DataLoader
from immuneML import Constants
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.bnp_util import write_yaml, read_yaml, get_sequence_field_name
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.SequenceSet import ReceptorSequence
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.pytorch_implementations.SimpleVAEGenerator import Encoder, Decoder, SimpleVAEGenerator, \
vae_cdr3_loss
from immuneML.ml_methods.util.pytorch_util import store_weights
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
from immuneML.util.StringHelper import StringHelper
class SimpleVAE(GenerativeModel):
"""
SimpleVAE is a sequence-level generative model based on a variational autoencoder. This type of model was
proposed by Davidsen et al. 2019, and this implementation is inspired by their original implementation available
at https://github.com/matsengrp/vampire.

References:

Davidsen, K., Olson, B. J., DeWitt, W. S., III, Feng, J., Harkins, E., Bradley, P., & Matsen, F. A., IV. (2019).
Deep generative models for T cell receptor protein sequences. eLife, 8, e46935. https://doi.org/10.7554/eLife.46935

**Specification arguments:**

- locus (str): which locus the sequences come from, e.g., TRB
- beta (float): VAE hyperparameter that balances the reconstruction loss and the latent dimension regularization
- latent_dim (int): latent dimension of the VAE
- linear_nodes_count (int): how many nodes to use in the linear layers
- num_epochs (int): how many epochs to use for training
- batch_size (int): how many examples to process in one batch
- j_gene_embed_dim (int): dimension of the J gene embedding
- v_gene_embed_dim (int): dimension of the V gene embedding
- cdr3_embed_dim (int): dimension of the CDR3 embedding
- pretrains (int): how many times to attempt pretraining to initialize the weights and warm up the beta hyperparameter before the main training process
- warmup_epochs (int): how many epochs to train with the beta hyperparameter linearly increased from 0 up to its maximum value; these epochs are in addition to num_epochs set above
- patience (int): number of epochs to wait before the training is stopped when the loss is not improving
- iter_count_prob_estimation (int): how many iterations to use to estimate the log probability of a generated sequence (the more iterations, the better the estimate)
- vocab (list): which letters (amino acids) are allowed - this is automatically filled for new models (no need to set)
- max_cdr3_len (int): the maximum CDR3 length - this is automatically filled for new models (no need to set)
- unique_v_genes (list): list of allowed V genes (this will be automatically filled from the dataset if not provided here manually)
- unique_j_genes (list): list of allowed J genes (this will be automatically filled from the dataset if not provided here manually)
- device (str): name of the device on which to train the model (e.g., cpu)
**YAML specification:**

.. indent with spaces
.. code-block:: yaml

    definitions:
        ml_methods:
            my_vae:
                SimpleVAE:
                    locus: beta
                    beta: 0.75
                    latent_dim: 20
                    linear_nodes_count: 75
                    num_epochs: 5000
                    batch_size: 10000
                    j_gene_embed_dim: 13
                    v_gene_embed_dim: 30
                    cdr3_embed_dim: 21
                    pretrains: 10
                    warmup_epochs: 20
                    patience: 20
                    device: cpu

"""
@classmethod
def load_model(cls, path: Path):
assert path.exists(), f"{cls.__name__}: {path} does not exist."
model_overview_file = path / 'model_overview.yaml'
state_dict_file = path / 'state_dict.yaml'
for file in [model_overview_file, state_dict_file]:
assert file.exists(), f"{cls.__name__}: {file} is not a file."
model_overview = read_yaml(model_overview_file)
vae = SimpleVAE(**{k: v for k, v in model_overview.items() if k != 'type'})
vae.model = vae.make_new_model(state_dict_file)
return vae
def __init__(self, locus, beta, latent_dim, linear_nodes_count, num_epochs, batch_size, j_gene_embed_dim, pretrains,
v_gene_embed_dim, cdr3_embed_dim, warmup_epochs, patience, iter_count_prob_estimation, device,
vocab=None, max_cdr3_len=None, unique_v_genes=None, unique_j_genes=None, name: str = None):
super().__init__(locus)
self.sequence_type = SequenceType.AMINO_ACID
self.iter_count_prob_estimation = iter_count_prob_estimation
self.num_epochs = num_epochs
self.pretrains = pretrains
self.region_type = RegionType.IMGT_CDR3 # TODO: check if they use cdr3 or junction in the original paper
self.vocab = vocab if vocab is not None else (
sorted((EnvironmentSettings.get_sequence_alphabet(self.sequence_type) + [Constants.GAP_LETTER])))
self.vocab_size = len(self.vocab)
self.beta = beta
self.warmup_epochs = warmup_epochs
self.patience = patience
self.cdr3_embed_dim = cdr3_embed_dim
self.latent_dim = latent_dim
self.j_gene_embed_dim, self.v_gene_embed_dim = j_gene_embed_dim, v_gene_embed_dim
self.linear_nodes_count = linear_nodes_count
self.batch_size = batch_size
self.device = device
self.max_cdr3_len, self.unique_v_genes, self.unique_j_genes = max_cdr3_len, unique_v_genes, unique_j_genes
self.name = name
self.model = None
# hard-coded in the original implementation
self.v_gene_loss_weight = 0.8138
self.j_gene_loss_weight = 0.1305
self.loss_path = None
def make_new_model(self, initial_values_path: Path = None):
assert self.unique_v_genes is not None and self.unique_j_genes is not None, \
f'{SimpleVAE.__name__}: cannot generate empty model since unique V and J genes are not set.'
encoder = Encoder(self.vocab_size, self.cdr3_embed_dim, len(self.unique_v_genes), self.v_gene_embed_dim,
len(self.unique_j_genes), self.j_gene_embed_dim, self.latent_dim, self.max_cdr3_len,
self.linear_nodes_count)
decoder = Decoder(self.latent_dim, self.linear_nodes_count, self.max_cdr3_len, self.vocab_size,
len(self.unique_v_genes), len(self.unique_j_genes))
vae = SimpleVAEGenerator(encoder, decoder)
if initial_values_path and initial_values_path.is_file():
state_dict = read_yaml(filename=initial_values_path)
state_dict = {k: torch.as_tensor(v, device=self.device) for k, v in state_dict.items()}
vae.load_state_dict(state_dict)
vae.to(self.device)
return vae
def fit(self, data, path: Path = None):
data_loader = self.encode_dataset(data)
# TODO: split the data to train and validation? or have external validation/test dataset
pretrained_weights_path = self._pretrain(data_loader=data_loader, path=path)
model = self.make_new_model(pretrained_weights_path)
model.train()
optimizer = torch.optim.Adam(model.parameters())
losses = []
epoch = 1
loss_decreasing = True
while epoch <= self.num_epochs and loss_decreasing:
loss = self._train_for_epoch(data_loader, model, self.beta, optimizer)
losses.append(loss.item())
print_log(f"{SimpleVAE.__name__}: epoch: {epoch}, loss: {loss}.")
if min(losses) == loss.item():
store_weights(model, path / 'state_dict.yaml')
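# early stopping: end training once the loss has been non-decreasing over the last `patience` epochs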
if epoch > self.patience and all(
x <= y for x, y in zip(losses[-self.patience:], losses[-self.patience:][1:])):
loss_decreasing = False
epoch += 1
pd.DataFrame({'epoch': list(range(1, epoch)), 'loss': losses}).to_csv(str(path / 'training_losses.csv'),
index=False)
self.loss_path = path / 'training_losses.csv'
self.model = self.make_new_model(path / 'state_dict.yaml')
def _pretrain(self, data_loader, path: Path):
pretrained_weights_path = PathBuilder.build(path) / 'pretrained_warmup_weights.yaml'
for pretrain_index in range(self.pretrains):
model = self.make_new_model()
optimizer = torch.optim.Adam(model.parameters())
beta = 0 if self.warmup_epochs > 0 else self.beta
best_val_loss = np.inf
for epoch in range(self.warmup_epochs + 1):
loss = self._train_for_epoch(data_loader, model, beta, optimizer)
val_loss = loss.item()
beta = self._update_beta_on_epoch_end(epoch)
if val_loss < best_val_loss:
best_val_loss = val_loss
store_weights(model, pretrained_weights_path)
return pretrained_weights_path
def _train_for_epoch(self, data_loader, model, beta, optimizer):
loss = None
for batch in data_loader:
cdr3_input, v_gene_input, j_gene_input = batch
cdr3_output, v_gene_output, j_gene_output, z = model(cdr3_input, v_gene_input, j_gene_input)
# total loss = beta-weighted CDR3 reconstruction/KL loss + weighted V and J gene cross-entropy terms
loss = (vae_cdr3_loss(cdr3_output, cdr3_input, self.max_cdr3_len, z[0], z[1], beta)
+ cross_entropy(v_gene_output, v_gene_input.float()) * self.v_gene_loss_weight
+ cross_entropy(j_gene_output, j_gene_input.float()) * self.j_gene_loss_weight)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss
def encode_dataset(self, dataset, batch_size=None, shuffle=True):
seq_col = get_sequence_field_name(self.region_type, self.sequence_type)
data = dataset.data.topandas()[[seq_col, 'v_call', 'j_call']]
if self.unique_v_genes is None:
self.unique_v_genes = sorted(list(set([el.split("*")[0] for el in data['v_call']])))
if self.unique_j_genes is None:
self.unique_j_genes = sorted(list(set([el.split("*")[0] for el in data['j_call']])))
if self.max_cdr3_len is None:
self.max_cdr3_len = max(len(el) for el in data[seq_col])
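# V and J calls are reduced to gene names (the allele part after '*' is dropped) and one-hot encoded;
# CDR3 sequences are padded to max_cdr3_len with the gap letter (inserted in the middle of the sequence)
# and one-hot encoded over the vocabulary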
encoded_v_genes = one_hot(
torch.as_tensor([self.unique_v_genes.index(v_gene.split("*")[0]) for v_gene in data['v_call']],
device=self.device),
num_classes=len(self.unique_v_genes))
encoded_j_genes = one_hot(
torch.as_tensor([self.unique_j_genes.index(j_gene.split("*")[0]) for j_gene in data['j_call']],
device=self.device),
num_classes=len(self.unique_j_genes))
padded_encoded_cdr3s = one_hot(torch.as_tensor([
[self.vocab.index(letter) for letter in
StringHelper.pad_sequence_in_the_middle(seq, self.max_cdr3_len, Constants.GAP_LETTER)]
for seq in data[seq_col]], device=self.device), num_classes=self.vocab_size)
pytorch_dataset = PyTorchSequenceDataset({'v_gene': encoded_v_genes, 'j_gene': encoded_j_genes,
'cdr3': padded_encoded_cdr3s})
return DataLoader(pytorch_dataset, shuffle=shuffle, batch_size=batch_size if batch_size else self.batch_size)
def is_same(self, model) -> bool:
raise RuntimeError
def _update_beta_on_epoch_end(self, epoch):
new_beta = self.beta
if self.warmup_epochs > 0 and epoch < self.warmup_epochs:
new_beta *= epoch / self.warmup_epochs
logging.info(f'{SimpleVAE.__name__}: epoch {epoch}: beta updated to {new_beta}.')
return new_beta
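# Illustration of the warm-up schedule above (not part of the original code): with beta=0.75 and warmup_epochs=20,
# the returned beta is 0.0 at epoch 0, 0.375 at epoch 10, 0.7125 at epoch 19, and the full 0.75 from epoch 20 onwards.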
def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType, compute_p_gen: bool):
torch.manual_seed(seed)
self.model.eval()
with torch.no_grad():
z_sample = torch.as_tensor(np.random.normal(0, 1, size=(count, self.latent_dim)),
device=self.device).float()
sequences, v_genes, j_genes = self.model.decode(z_sample)
seq_objs = []
for i in range(count):
seq_content = [self.vocab[Categorical(letter).sample()] for letter in sequences[i]]
sequence = ReceptorSequence(**{
self.sequence_type.value: ''.join(seq_content).replace(Constants.GAP_LETTER, ''),
'v_call': self.unique_v_genes[Categorical(v_genes[i]).sample()],
'j_call': self.unique_j_genes[Categorical(j_genes[i]).sample()],
'locus': self.locus,
'metadata': {'gen_model_name': self.name if self.name else "SimpleVAE"}
})
seq_objs.append(sequence)
# for obj in seq_objs:
# log_prob = self.compute_p_gen({self.sequence_type.value: obj.get_attribute(self.sequence_type.value),
# 'v_call': obj.metadata.v_call, 'j_call': obj.metadata.j_call},
# self.sequence_type)
# obj.metadata.custom_params = {'log_prob': log_prob}
dataset = SequenceDataset.build_from_objects(seq_objs, PathBuilder.build(path),
f'synthetic_{self.name}_dataset',
{'gen_model_name': [self.name]},
self.region_type)
return dataset
def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
pass
def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
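# Importance-sampling estimate of the sequence log probability: in each of iter_count_prob_estimation iterations,
# z is drawn from the approximate posterior q(z|x) produced by the encoder, the per-sample estimate
# log p(x|z) + log p(z) - log q(z|x) is computed from the decoder outputs, and the mean of these estimates is returned.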
with torch.no_grad():
encoded_v_genes = one_hot(
torch.as_tensor([self.unique_v_genes.index(sequence['v_call'].split("*")[0])]),
num_classes=len(self.unique_v_genes))
encoded_j_genes = one_hot(
torch.as_tensor([self.unique_j_genes.index(sequence['j_call'].split("*")[0])]),
num_classes=len(self.unique_j_genes))
padded_encoded_cdr3s = one_hot(torch.as_tensor([
[self.vocab.index(letter) for letter in
StringHelper.pad_sequence_in_the_middle(sequence[self.sequence_type.value], self.max_cdr3_len,
Constants.GAP_LETTER)]]), num_classes=self.vocab_size)
log_prob_estimates = []
for _ in range(self.iter_count_prob_estimation):
z_mean, z_log_var = self.model.encode(padded_encoded_cdr3s, encoded_v_genes, encoded_j_genes)
z_sd = (z_log_var / 2).exp()
z_sample = torch.as_tensor(np.array([scipy.stats.norm.rvs(loc=z_mean, scale=z_sd)])).float()
aa_probs, v_gene_probs, j_gene_probs = self.model.decode(z_sample)
aa_probs, v_gene_probs, j_gene_probs = aa_probs.numpy(), v_gene_probs.numpy(), j_gene_probs.numpy()
log_p_x_given_z = (np.sum(np.log(np.sum(aa_probs[0] * padded_encoded_cdr3s[0].numpy(), axis=1))) +
np.log(np.sum(v_gene_probs[0] * encoded_v_genes[0].numpy())) +
np.log(np.sum(j_gene_probs[0] * encoded_j_genes[0].numpy())))
log_p_z = np.sum(scipy.stats.norm.logpdf(z_sample[0], 0, 1))
log_q_z_given_x = np.sum(scipy.stats.norm.logpdf(z_sample[0], z_mean[0], z_sd[0]))
log_imp_weight = log_p_z - log_q_z_given_x
log_prob_estimates.append(float(log_p_x_given_z + log_imp_weight))
return sum(log_prob_estimates) / self.iter_count_prob_estimation
def can_compute_p_gens(self) -> bool:
return True
def can_generate_from_skewed_gene_models(self) -> bool:
return False
def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
raise RuntimeError
def save_model(self, path: Path) -> Path:
model_path = PathBuilder.build(path / 'model')
skip_export_keys = ['model', 'loss_path', 'j_gene_loss_weight', 'v_gene_loss_weight', 'region_type',
'sequence_type', 'vocab_size']
write_yaml(filename=model_path / 'model_overview.yaml',
yaml_dict={**{k: v for k, v in vars(self).items() if k not in skip_export_keys},
**{'type': self.__class__.__name__}}) # todo add 'dataset_type': 'SequenceDataset',
store_weights(self.model, model_path / 'state_dict.yaml')
return Path(shutil.make_archive(str(path / 'trained_model'), 'zip', str(model_path))).absolute()
class PyTorchSequenceDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data['cdr3'])
def __getitem__(self, index):
return self.data['cdr3'][index], self.data['v_gene'][index], self.data['j_gene'][index]
def get_v_genes(self):
return self.data['v_gene']
def get_j_genes(self):
return self.data['j_gene']
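# A minimal usage sketch (not part of the library): it assumes `dataset` is an existing SequenceDataset,
# `result_path` is a writable pathlib.Path, the hyperparameter values mirror the YAML example in the class
# docstring, and iter_count_prob_estimation=100 is an arbitrary choice.
#
#     vae = SimpleVAE(locus='beta', beta=0.75, latent_dim=20, linear_nodes_count=75, num_epochs=5000,
#                     batch_size=10000, j_gene_embed_dim=13, pretrains=10, v_gene_embed_dim=30,
#                     cdr3_embed_dim=21, warmup_epochs=20, patience=20, iter_count_prob_estimation=100,
#                     device='cpu', name='my_vae')
#     vae.fit(dataset, result_path)
#     generated_dataset = vae.generate_sequences(count=100, seed=1, path=result_path,
#                                                sequence_type=vae.sequence_type, compute_p_gen=False)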