import json
import shutil
from pathlib import Path
import numpy as np
from olga import load_model
from immuneML.data_model.bnp_util import write_yaml, read_yaml
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.data_model.SequenceParams import RegionType, Chain
from immuneML.data_model.SequenceSet import ReceptorSequence
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.generative_models.OLGA import OLGA
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
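# NOTE: a stray "[docs]" link marker (artifact of the HTML source view) was removed here; it is not part of the module.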
class SoNNia(GenerativeModel):
    """
    SoNNia models the selection process of T and B cell receptor repertoires. It is based on the SoNNia Python package.
    It supports SequenceDataset as input, but not RepertoireDataset.

    Original publication:
    Isacchini, G., Walczak, A. M., Mora, T., & Nourmohammad, A. (2021). Deep generative selection models of T and B
    cell receptor repertoires with soNNia. Proceedings of the National Academy of Sciences, 118(14), e2023141118.
    https://doi.org/10.1073/pnas.2023141118

    **Specification arguments:**

    - locus (str): the locus (chain) of the receptor sequences; if it is not given, it is derived from the last
      three characters of `default_model_name`

    - batch_size (int): batch size passed to the selection model inference

    - epochs (int): number of epochs passed to the selection model inference

    - deep (bool): stored with the model and written to the model overview; the code in this class does not pass
      it on to the underlying SoNNia model

    - include_joint_genes (bool): passed to the underlying SoNNia model as `include_joint_genes`; its negation is
      passed as `include_indep_genes`

    - n_gen_seqs (int): number of sequences generated from the generation model and added to the selection model
      before inference

    - custom_model_path (str): folder of a custom generation model; if it is not set (None or empty), the default
      OLGA model given by `default_model_name` is used

    - default_model_name (str): name of a default OLGA model (a key of `OLGA.DEFAULT_MODEL_FOLDER_MAP`), used when
      `custom_model_path` is not set

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            ml_methods:
                my_sonnia_model:
                    SoNNia:
                        ...

    """

    @classmethod
    def load_model(cls, path: Path):
        """
        Restore a model previously written by `save_model`.

        Args:
            path: folder containing `model_overview.yaml` (constructor arguments) and `model.json` (model weights)

        Returns:
            an instance of this class whose underlying SoNNia model has the stored weights set

        Raises:
            AssertionError: if `path` or the model overview file does not exist
        """
        assert path.exists(), f"{cls.__name__}: {path} does not exist."

        model_overview_file = path / 'model_overview.yaml'
        assert model_overview_file.exists(), f"{cls.__name__}: {model_overview_file} is not a file."

        model_overview = read_yaml(model_overview_file)
        # 'type' only names the class in the overview file; everything else is a constructor argument
        sonnia = cls(**{k: v for k, v in model_overview.items() if k != 'type'})

        with open(path / 'model.json', 'r') as json_file:
            model_data = json.load(json_file)

        sonnia._model = sonnia._make_internal_model()
        sonnia._model.model.set_weights([np.array(w) for w in model_data['model_weights']])

        return sonnia

    def __init__(self, locus=None, batch_size: int = None, epochs: int = None, deep: bool = False, name: str = None,
                 default_model_name: str = None, n_gen_seqs: int = None, include_joint_genes: bool = True,
                 custom_model_path: str = None, region_type: RegionType = RegionType.IMGT_CDR3):
        """
        Store the training parameters and resolve the folder of the generation model to use.

        Raises:
            KeyError: if `custom_model_path` is not set and `default_model_name` is not a key of
                `OLGA.DEFAULT_MODEL_FOLDER_MAP`
        """
        if isinstance(region_type, str):
            region_type = RegionType[region_type]

        if locus is not None:
            super().__init__(locus, region_type=region_type)
        elif default_model_name is not None:
            super().__init__(locus=Chain.get_chain(default_model_name[-3:]), region_type=region_type)
        # NOTE(review): when neither locus nor default_model_name is given, the parent constructor is not
        # called (unchanged behaviour) - confirm whether that combination should be rejected instead.

        self.epochs = epochs
        # the default (None) is kept as None instead of failing in int(None)
        self.batch_size = int(batch_size) if batch_size is not None else None
        self.deep = deep
        self.include_joint_genes = include_joint_genes
        self.n_gen_seqs = n_gen_seqs
        self._model = None
        self.name = name
        self.default_model_name = default_model_name
        # kept as a public attribute so that save_model writes it to the model overview and load_model can
        # restore a model that was trained with a custom generation model
        self.custom_model_path = str(custom_model_path) if custom_model_path else None

        if custom_model_path is None or custom_model_path == '':
            self._model_path = Path(
                load_model.__file__).parent / f"default_models/{OLGA.DEFAULT_MODEL_FOLDER_MAP[self.default_model_name]}"
        else:
            self._model_path = custom_model_path

    def _uses_vj_recombination(self) -> bool:
        """Return True if the locus is one of ALPHA, KAPPA or LIGHT (passed to SoNNia as `vj`)."""
        return self.locus in [Chain.ALPHA, Chain.KAPPA, Chain.LIGHT]

    def _make_internal_model(self, **extra_kwargs):
        """Build the underlying SoNNia model from this object's settings plus any extra keyword arguments."""
        # imported inside the method, so the sonnia package is only needed once a model is actually built
        from sonnia.sonnia import SoNNia as InternalSoNNia

        return InternalSoNNia(custom_pgen_model=self._model_path,
                              vj=self._uses_vj_recombination(),
                              include_joint_genes=self.include_joint_genes,
                              include_indep_genes=not self.include_joint_genes,
                              **extra_kwargs)

    def _require_fitted_model(self, action: str):
        """Raise a RuntimeError if no underlying model is available (neither fitted nor loaded)."""
        if self._model is None:
            raise RuntimeError(f"{SoNNia.__name__}: cannot {action}: the model has not been fitted or loaded.")

    def fit(self, dataset: Dataset, path: Path = None):
        """
        Fit the selection model on the sequences of the dataset.

        Args:
            dataset: dataset whose `data` provides the columns `junction_aa`, `v_call` and `j_call`
            path: not used by this implementation
        """
        print_log(f"{SoNNia.__name__}: fitting a selection model...", True)

        data = dataset.data.topandas()[['junction_aa', 'v_call', 'j_call']]
        data_seqs = data.to_records(index=False).tolist()

        # NOTE(review): self.deep is not passed to the underlying model here - confirm whether it should be.
        self._model = self._make_internal_model(data_seqs=data_seqs, gen_seqs=[])

        self._model.add_generated_seqs(num_gen_seqs=self.n_gen_seqs, custom_model_folder=self._model_path)
        self._model.infer_selection(epochs=self.epochs, batch_size=self.batch_size, verbose=1)

        print_log(f"{SoNNia.__name__}: selection model fitted.", True)

    def is_same(self, model) -> bool:
        """Not implemented for SoNNia."""
        raise NotImplementedError

    def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType, compute_p_gen: bool):
        """
        Generate `count` sequences from the fitted selection model and return them as a SequenceDataset.

        Args:
            count: number of sequences to generate
            seed: not used by this implementation
            path: folder in which the resulting dataset is built
            sequence_type: not used by this implementation
            compute_p_gen: not used by this implementation

        Returns:
            a SequenceDataset named 'SoNNiaDataset' with the label 'gen_model_name' set to this model's name

        Raises:
            RuntimeError: if the model has not been fitted or loaded
        """
        from sonia.sequence_generation import SequenceGeneration

        self._require_fitted_model('generate sequences')

        gen_model = SequenceGeneration(self._model)
        sequences = gen_model.generate_sequences_post(count)

        # each generated item is used as (amino acid sequence, V call, J call, nucleotide sequence)
        return SequenceDataset.build_from_objects(sequences=[ReceptorSequence(sequence_aa=seq[0], sequence=seq[3],
                                                                              v_call=seq[1], j_call=seq[2],
                                                                              metadata={'gen_model_name': self.name})
                                                             for seq in sequences],
                                                  path=PathBuilder.build(path), name='SoNNiaDataset',
                                                  labels={'gen_model_name': [self.name]})

    def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
        """Not implemented for SoNNia (see `can_compute_p_gens`)."""
        raise NotImplementedError

    def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
        """Not implemented for SoNNia (see `can_compute_p_gens`)."""
        raise NotImplementedError

    def can_compute_p_gens(self) -> bool:
        """Return False: generation probabilities are not computed by this class."""
        return False

    def can_generate_from_skewed_gene_models(self) -> bool:
        """Return False: generation from skewed gene models is not supported by this class."""
        return False

    def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
                                         sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
        """Not implemented for SoNNia (see `can_generate_from_skewed_gene_models`)."""
        raise NotImplementedError

    def save_model(self, path: Path) -> Path:
        """
        Write the model to `path / 'model'` and pack that folder into `path / 'trained_model.zip'`.

        The folder contains `model_overview.yaml` (constructor arguments), the files written by the underlying
        SoNNia model and `model.json` (model configuration and weights).

        Args:
            path: folder in which the model folder and the zip archive are created

        Returns:
            absolute path of the created zip archive

        Raises:
            RuntimeError: if the model has not been fitted or loaded
        """
        # checked first, so that no partial output is written for a model without weights
        self._require_fitted_model('save the model')

        PathBuilder.build(path / 'model')

        write_yaml(path / 'model/model_overview.yaml', {'type': 'SoNNia', 'locus': self.locus.name,
                                                        'region_type': self.region_type.name,
                                                        **{k: v for k, v in vars(self).items()
                                                           if k not in ['_model', 'locus', '_model_path',
                                                                        'region_type']}})
        # todo add 'dataset_type': 'SequenceDataset',

        attributes_to_save = ['data_seqs', 'gen_seqs', 'log']
        self._model.save_model(path / 'model', attributes_to_save)

        # weights are stored as nested lists so that they can be written as JSON and restored in load_model
        model_json = self._model.model.to_json()
        model_weights = [w.tolist() for w in self._model.model.get_weights()]
        model_data = {'model_config': model_json, 'model_weights': model_weights}
        with open(path / 'model' / 'model.json', 'w') as json_file:
            json.dump(model_data, json_file)

        return Path(shutil.make_archive(str(path / 'trained_model'), "zip", str(path / 'model'))).absolute()