import shutil
from pathlib import Path
from zipfile import ZipFile, ZIP_STORED
import numpy as np
import pandas as pd
from immuneML.data_model.SequenceParams import RegionType
from immuneML.data_model.bnp_util import get_sequence_field_name, write_yaml, read_yaml
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.generative_models.progen.ProGenConfig import ProGenConfig
from immuneML.ml_methods.generative_models.progen.ProGenForCausalLM import ProGenForCausalLM
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
class ProGen(GenerativeModel):
"""
ProGen is a transformer-based language model for protein sequences. This class allows fine-tuning of a pre-trained
ProGen model on immune receptor sequences and generating new sequences. It is based on the ProGen2 implementation
available at https://github.com/salesforce/progen. It uses the sequences as given in the "junction_aa" field of the
input dataset.

References:

Nijkamp, E., Ruffolo, J. A., Weinstein, E. N., Naik, N., & Madani, A. (2023).
Exploring the boundaries of protein language models. Cell Systems, 14(11), 968–978.e3.
https://doi.org/10.1016/j.cels.2023.10.002

**Specification arguments:**

- locus (str): which locus the sequences come from, e.g., TRB
- tokenizer_path (Path): path to the ProGen tokenizer file (tokenizer.json)
- trained_model_path (Path): path to the pre-trained ProGen model directory
- num_frozen_layers (int): number of transformer layers to freeze during fine-tuning
- num_epochs (int): number of epochs for fine-tuning
- learning_rate (float): learning rate for fine-tuning
- device (str): device to use for training and inference ("cpu" or "cuda")
- fp16 (bool): whether to use mixed precision training
- prefix_text (str): text to prepend to each sequence during fine-tuning
- suffix_text (str): text to append to each sequence during fine-tuning
- max_new_tokens (int): maximum number of new tokens to generate
- temperature (float): sampling temperature for sequence generation
- top_p (float): nucleus sampling parameter for sequence generation
- prompt (str): prompt text to start the generation
- num_gen_batches (int): number of batches to split generation into
- per_device_train_batch_size (int): batch size per device during fine-tuning
- remove_affixes (bool): whether to remove prefix and suffix from generated sequences
- seed (int): random seed for reproducibility

**YAML specification:**

.. indent with spaces
.. code-block:: yaml

    definitions:
        ml_methods:
            progen_model:
                ProGen:
                    locus: 'beta'
                    tokenizer_path: '/path/to/tokenizer.json'
                    trained_model_path: '/path/to/pretrained/progen/model'
                    num_frozen_layers: 27
                    num_epochs: 3
                    learning_rate: 0.00004
                    device: 'cuda'
                    fp16: False
                    prefix_text: '<|bos|>1'
                    suffix_text: '2<|eos|>'
                    max_new_tokens: 1024
                    temperature: 1.0
                    top_p: 0.9
                    prompt: '1'
                    num_gen_batches: 1
                    per_device_train_batch_size: 2
                    remove_affixes: True
                    name: 'progen_finetuned_model'
                    region_type: 'IMGT_JUNCTION'
                    seed: 42
"""
@classmethod
def load_model(cls, path: Path):
import torch
assert path.exists(), f"{cls.__name__}: {path} does not exist."
model_overview_file = path / 'model_overview.yaml'
assert model_overview_file.exists(), f"{cls.__name__}: {model_overview_file} is not a file."
# Uses ProGen weights/tokenizer paths from training. Override in model_overview.yaml if needed.
model_overview = read_yaml(model_overview_file)
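# Rebuild the ProGen wrapper from the stored hyperparameters; the 'type' key is export metadata, not a constructor argument.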
progen = ProGen(**{k: v for k, v in model_overview.items() if k != 'type'})
config = ProGenConfig.from_pretrained(path)
model = ProGenForCausalLM.from_pretrained(path,
config=config,
dtype=torch.float32 if not progen.fp16 else torch.float16)
progen.model = model.to(progen.device).eval()
return progen
def __init__(self, locus, tokenizer_path: Path, trained_model_path: Path, num_frozen_layers: int, num_epochs: int,
learning_rate: float, device: str, fp16: bool = False, prefix_text: str = "", suffix_text: str = "",
max_new_tokens: int = 1024, temperature: float = 1.0, top_p: float = 0.9, prompt: str = "1",
num_gen_batches: int = 1, per_device_train_batch_size: int = 2, remove_affixes: bool = True,
name: str = None, region_type: str = RegionType.IMGT_JUNCTION.name, seed: int = None):
super().__init__(locus, seed=seed, name=name, region_type=RegionType.get_object(region_type))
self.sequence_type = SequenceType.AMINO_ACID
self.tokenizer_path = tokenizer_path
self.trained_model_path = trained_model_path
self.num_frozen_layers = num_frozen_layers
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.device = device # "cpu" or "cuda"
self.fp16 = fp16
self.prefix_text = prefix_text
self.suffix_text = suffix_text
self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.top_p = top_p
self.prompt = prompt
self.num_gen_batches = num_gen_batches
self.per_device_train_batch_size = per_device_train_batch_size
self.remove_affixes = remove_affixes
self.model = None
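# Wrap the raw ProGen tokenizer file in a Hugging Face fast tokenizer and register the pad/eos/bos control tokens.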
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
tokenizer = Tokenizer.from_file(str(self.tokenizer_path))
self.hf_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
self.hf_tokenizer.pad_token = "<|pad|>"
self.hf_tokenizer.eos_token = "<|eos|>"
self.hf_tokenizer.bos_token = "<|bos|>"
def fit(self, data, path: Path = None):
assert path is not None, "ProGen.fit requires a target directory path for training outputs."
logs_dir, output_dir = self._prepare_training_paths(path)
tokenized_dataset = self._preprocess_dataset(data)
from transformers import DataCollatorForLanguageModeling
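# mlm=False yields a causal-LM collator: labels are copies of the input ids (padding masked out), with the shift applied inside the model.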
data_collator = DataCollatorForLanguageModeling(tokenizer=self.hf_tokenizer, mlm=False)
config = ProGenConfig.from_pretrained(self.trained_model_path)
model = ProGenForCausalLM.from_pretrained(self.trained_model_path, config=config)
self._freeze_model_layers(model)
from transformers import TrainingArguments
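# Checkpointing is disabled (save_strategy="no"); the final weights are exported separately via save_model.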
training_args = TrainingArguments(
output_dir=str(output_dir),
per_device_train_batch_size=self.per_device_train_batch_size,
num_train_epochs=self.num_epochs,
learning_rate=self.learning_rate,
fp16=self.fp16,
use_cpu=(self.device == "cpu"),
save_safetensors=False,
logging_dir=str(logs_dir),
save_total_limit=1,
save_strategy="no"
)
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
tokenizer=self.hf_tokenizer,
data_collator=data_collator
)
print_log(f"{self.name or ProGen.__name__}: starting ProGen fine-tuning.", True)
trainer.train()
print_log(f"{self.name or ProGen.__name__}: finished ProGen fine-tuning.", True)
self.model = trainer.model.to(self.device).eval()
def _freeze_model_layers(self, model):
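# Freeze the first num_frozen_layers transformer blocks; the remaining blocks, the final layer norm,
# and the LM head are explicitly kept trainable for fine-tuning.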
for layer in model.transformer.h[:self.num_frozen_layers]:
for param in layer.parameters():
param.requires_grad = False
for param in model.transformer.ln_f.parameters():
param.requires_grad = True
for param in model.lm_head.parameters():
param.requires_grad = True
def _preprocess_dataset(self, data):
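# Wrap each junction_aa sequence in the configured prefix/suffix control text, then tokenize
# the wrapped strings into input_ids for causal LM fine-tuning.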
from datasets import Dataset as HFDataset
data_df = data.data.topandas()
data_df["junction_aa"] = self.prefix_text + data_df["junction_aa"].astype(str).fillna("") + self.suffix_text
hf_dataset = HFDataset.from_pandas(data_df[["junction_aa"]], preserve_index=False)
tokenized_dataset = hf_dataset.map(
self.hf_tokenizer,
batched=True,
input_columns="junction_aa",
fn_kwargs={"truncation": True},
remove_columns=["junction_aa"],
)
return tokenized_dataset
def _prepare_training_paths(self, path):
base_path = Path(path)
base_path.mkdir(parents=True, exist_ok=True)
output_dir = base_path / "model"
output_dir.mkdir(parents=True, exist_ok=True)
logs_dir = base_path / "logs"
logs_dir.mkdir(parents=True, exist_ok=True)
return logs_dir, output_dir
def save_model(self, path: Path) -> Path:
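# Export the fine-tuned weights, the tokenizer file, and a model_overview.yaml with the hyperparameters,
# then zip everything into a single archive for portability.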
model_path = PathBuilder.build(path / 'model')
self.model.save_pretrained(model_path, safe_serialization=False)
shutil.copy(self.tokenizer_path, model_path / Path(self.tokenizer_path).name)
skip_export_keys = {"model", "tokenizer", "hf_tokenizer", 'region_type', 'sequence_type'}
write_yaml(filename=model_path / 'model_overview.yaml',
yaml_dict={**{k: v for k, v in vars(self).items() if k not in skip_export_keys},
**{'type': self.__class__.__name__,
'locus': self.locus.name}})
archive_path = path / f"trained_model_{self.name}.zip"
with ZipFile(archive_path, "w", compression=ZIP_STORED) as archive:
for file_path in (fp for fp in model_path.rglob("*") if fp.is_file()):
archive.write(file_path, file_path.relative_to(model_path))
return archive_path.resolve()
def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType,
compute_p_gen: bool) -> Dataset:
import torch
prompt_encoding = self.hf_tokenizer(self.prompt, return_tensors="pt")
prompt_input_ids = prompt_encoding.input_ids.to(self.device)
prompt_attention_mask = prompt_encoding.attention_mask.to(self.device)
gen_sequences = []
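# Split generation into num_gen_batches roughly equal batches; the last batch absorbs the remainder
# (e.g., count=10 with num_gen_batches=3 yields batches of 3, 3 and 4).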
num_sequences_per_batch = count // self.num_gen_batches
for i in range(self.num_gen_batches):
num_current_sequences = num_sequences_per_batch if i < self.num_gen_batches - 1 else count - (
num_sequences_per_batch * (self.num_gen_batches - 1))
with torch.inference_mode():
output = self.model.generate(
input_ids=prompt_input_ids,
attention_mask=prompt_attention_mask,
max_new_tokens=self.max_new_tokens,
do_sample=True,
top_p=self.top_p,
temperature=self.temperature,
num_return_sequences=num_current_sequences,
pad_token_id=self.hf_tokenizer.pad_token_id,
return_dict_in_generate=False
)
gen_sequences.extend(self.hf_tokenizer.batch_decode(output, skip_special_tokens=True))
print_log(f"{self.name or ProGen.__name__}: {(i + 1) * num_current_sequences} sequences generated.", True)
if self.remove_affixes:
gen_sequences = self._remove_affixes(gen_sequences)
gen_sequences_df = pd.DataFrame({get_sequence_field_name(self.region_type, self.sequence_type): gen_sequences,
'locus': [self.locus.to_string() for _ in range(count)],
'gen_model_name': [self.name for _ in range(count)]})
return SequenceDataset.build_from_partial_df(gen_sequences_df, PathBuilder.build(path),
'synthetic_dataset', {'gen_model_name': [self.name]},
{'gen_model_name': str})
def _remove_affixes(self, gen_sequences):
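# Round-trip the affixes through the tokenizer so the strings removed here match the decoded form of the
# generated output (special tokens such as <|bos|>/<|eos|> are already stripped during decoding).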
prefix_text = self.hf_tokenizer.decode(self.hf_tokenizer(self.prefix_text).input_ids,
skip_special_tokens=True)
suffix_text = self.hf_tokenizer.decode(self.hf_tokenizer(self.suffix_text).input_ids,
skip_special_tokens=True)
gen_sequences = [seq.replace(prefix_text, '').replace(suffix_text, '') for seq in
gen_sequences]
return gen_sequences
def is_same(self, model) -> bool:
raise NotImplementedError
def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
raise RuntimeError(f"{ProGen.__name__} cannot compute generation probabilities.")
def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
raise RuntimeError(f"{ProGen.__name__} cannot compute generation probabilities.")
def can_compute_p_gens(self) -> bool:
return False
def can_generate_from_skewed_gene_models(self) -> bool:
return False
def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
raise RuntimeError(f"{ProGen.__name__} cannot generate from skewed gene models.")