import copy
import logging
from dataclasses import field, dataclass
from pathlib import Path
from typing import Dict, List
from uuid import uuid4
import numpy as np
from immuneML.IO.dataset_export.AIRRExporter import AIRRExporter
from immuneML.data_model.AIRRSequenceSet import AIRRSequenceSet
from immuneML.data_model.bnp_util import merge_dataclass_objects, bnp_write_to_file, get_type_dict_from_bnp_object, \
write_yaml
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import SequenceDataset
from immuneML.hyperparameter_optimization.config.SplitType import SplitType
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.reports.train_gen_model_reports.TrainGenModelReport import TrainGenModelReport
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.instructions.GenModelInstruction import GenModelState, GenModelInstruction
from immuneML.workflows.steps.data_splitter.DataSplitter import DataSplitter
from immuneML.workflows.steps.data_splitter.DataSplitterParams import DataSplitterParams
[docs]
@dataclass
class TrainGenModelState:
result_path: Path = None
name: str = None
gen_examples_count: int = None
model_path: Path = None
generated_dataset: Dataset = None
exported_datasets: Dict[str, Path] = field(default_factory=dict)
report_results: Dict[str, List[ReportResult]] = field(
default_factory=lambda: {'data_reports': [], 'ml_reports': [], 'instruction_reports': []})
combined_dataset: Dataset = None
train_dataset: Dataset = None
test_dataset: Dataset = None
training_percentage: float = None
[docs]
class TrainGenModelInstruction(GenModelInstruction):
"""
TrainGenModel instruction implements training generative AIRR models on receptor level. Models that can be trained
for sequence generation are listed under Generative Models section.
This instruction takes a dataset as input which will be used to train a model, the model itself, and the number of
sequences to generate to illustrate the applicability of the model. It can also produce reports of the fitted model
and reports of original and generated sequences.
To use the generative model previously trained with immuneML, see :ref:`ApplyGenModel` instruction.
**Specification arguments:**
- dataset: dataset to use for fitting the generative model; it has to be defined under definitions/datasets
- method: which model to fit (defined previously under definitions/ml_methods)
- number_of_processes (int): how many processes to use for fitting the model
- gen_examples_count (int): how many examples (sequences, repertoires) to generate from the fitted model
- reports (list): list of report ids (defined under definitions/reports) to apply after fitting a generative model
and generating gen_examples_count examples; these can be data reports (to be run on generated examples), ML
reports (to be run on the fitted model)
**YAML specification:**
.. indent with spaces
.. code-block:: yaml
instructions:
my_train_gen_model_inst: # user-defined instruction name
type: TrainGenModel
dataset: d1 # defined previously under definitions/datasets
model: model1 # defined previously under definitions/ml_methods
gen_examples_count: 100
number_of_processes: 4
training_percentage: 0.7
export_generated_dataset: True
export_combined_dataset: False
reports: [data_rep1, ml_rep2]
"""
MAX_ELEMENT_COUNT_TO_SHOW = 10
def __init__(self, dataset: Dataset = None, method: GenerativeModel = None,
number_of_processes: int = 1, gen_examples_count: int = 100, result_path: Path = None,
name: str = None, reports: list = None, export_generated_dataset: bool = True,
export_combined_dataset: bool = False, training_percentage: float = None):
super().__init__(TrainGenModelState(result_path, name, gen_examples_count), method, reports)
self.dataset = dataset
self.number_of_processes = number_of_processes
self.export_generated_dataset = export_generated_dataset
self.export_combined_dataset = export_combined_dataset
self.state.training_percentage = training_percentage
[docs]
def run(self, result_path: Path) -> TrainGenModelState:
self._set_path(result_path)
self._split_dataset()
self._fit_model()
self._save_model()
self._gen_data()
if self.export_generated_dataset:
self._export_generated_dataset()
self._make_and_export_combined_dataset()
self._run_reports()
return self.state
def _split_dataset(self):
if self.state.training_percentage != 1:
split_params = DataSplitterParams(dataset=self.dataset, split_strategy=SplitType.RANDOM, split_count=1,
training_percentage=self.state.training_percentage,
paths=[self.state.result_path])
train_datasets, test_datasets = DataSplitter.run(split_params)
self.state.train_dataset = train_datasets[0]
self.state.test_dataset = test_datasets[0]
else:
logging.info(f"{TrainGenModelInstruction.__name__}: training_percentage was set to 1 meaning that the full "
f"dataset will be used for fitting the generative model. All resulting comparison reports "
f"will then use the full original dataset as opposed to independent test dataset if the "
f"training percentage was less than 1.")
self.state.train_dataset = self.dataset
self.state.test_dataset = self.dataset
def _fit_model(self):
print_log(f"{self.state.name}: starting to fit the model", True)
self.method.fit(self.state.train_dataset, self.state.result_path)
print_log(f"{self.state.name}: fitted the model", True)
def _make_combined_dataset(self):
path = PathBuilder.build(self.state.result_path / 'combined_dataset')
gen_data = self._get_dataclass_object_from_dataset(self.generated_dataset,
np.ones(self.state.gen_examples_count),
np.zeros(self.state.gen_examples_count))
if self.state.training_percentage < 1:
org_data = self._get_dataclass_object_from_dataset(self.state.train_dataset,
np.zeros(self.state.train_dataset.get_example_count()),
np.ones(self.state.train_dataset.get_example_count()))
test_data = self._get_dataclass_object_from_dataset(self.state.test_dataset,
np.zeros(self.state.test_dataset.get_example_count()),
np.zeros(self.state.test_dataset.get_example_count()))
combined_data = merge_dataclass_objects([org_data, test_data, gen_data], fill_unmatched=True)
else:
org_data = self._get_dataclass_object_from_dataset(self.dataset, np.zeros(self.dataset.get_example_count()),
np.ones(self.dataset.get_example_count()))
combined_data = merge_dataclass_objects([org_data, gen_data], fill_unmatched=True)
bnp_write_to_file(path / f'combined_{self.state.name}_dataset.tsv', combined_data)
metadata_yaml = SequenceDataset.create_metadata_dict(dataset_class=SequenceDataset.__name__,
filename=f'combined_{self.state.name}_dataset.tsv',
type_dict=type(combined_data).get_field_type_dict(all_fields=False),
identifier=uuid4().hex,
name=f'combined_{self.state.name}_dataset',
labels={'gen_model_name': [self.method.name, ''], "from_gen_model": [True, False]})
write_yaml(path / f'combined_{self.state.name}_dataset.yaml', metadata_yaml)
self.state.combined_dataset = SequenceDataset.build(
metadata_filename=path / f'combined_{self.state.name}_dataset.yaml',
filename=path / f'combined_{self.state.name}_dataset.tsv', name=f'combined_{self.state.name}_dataset')
def _get_dataclass_object_from_dataset(self, dataset: Dataset, from_gen_model_vals: np.ndarray,
used_for_training_vals: np.ndarray):
return dataset.data.add_fields(
{'from_gen_model': from_gen_model_vals, 'used_for_training': used_for_training_vals},
{'from_gen_model': bool, 'used_for_training': bool})
def _make_and_export_combined_dataset(self):
if self.export_combined_dataset and isinstance(self.dataset, SequenceDataset):
self._make_combined_dataset()
export_path = PathBuilder.build(self.state.result_path / 'exported_combined_dataset')
try:
AIRRExporter.export(self.state.combined_dataset, export_path)
self.state.exported_datasets['combined_dataset'] = export_path
except AssertionError as e:
logging.warning(f"{TrainGenModelInstruction.__name__}: {self.state.name}: combined dataset could not "
f"be exported due to the following error: {e}.")
elif self.export_combined_dataset:
logging.warning(f"{TrainGenModelInstruction.__name__}: {self.state.name}: export_combined_dataset is only "
f"supported for sequence datasets at this point.")
def _run_reports(self):
super()._run_reports()
report_path = self._get_reports_path()
for report in self.reports:
original_dataset = self.state.train_dataset if self.state.training_percentage != 1 else self.dataset
report.result_path = report_path
if isinstance(report, TrainGenModelReport):
report.generated_dataset = self.generated_dataset
report.original_dataset = original_dataset
report.model = self.method
self.state.report_results['instruction_reports'].append(report.generate_report())
elif isinstance(report, DataReport):
rep = copy.deepcopy(report)
rep.result_path = PathBuilder.build(rep.result_path.parent / f"{rep.result_path.name}_original_dataset")
rep.dataset = original_dataset
rep.name = rep.name + " (original dataset)"
self.state.report_results['data_reports'].append(rep.generate_report())
super()._print_report_summary_log()
def _save_model(self):
self.state.model_path = self.method.save_model(self.state.result_path / 'trained_model/')