Source code for immuneML.workflows.instructions.GenModelInstruction
import copy
from abc import ABC
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List
from immuneML.IO.dataset_export.AIRRExporter import AIRRExporter
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.environment.SequenceType import SequenceType
from immuneML.ml_methods.generative_models import GenerativeModel
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
from immuneML.reports.ml_reports.MLReport import MLReport
from immuneML.util.Logger import print_log
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.instructions.Instruction import Instruction
[docs]
@dataclass
class GenModelState:
result_path: Path
name: str
gen_examples_count: int
model_path: Path = None
generated_dataset: Dataset = None
exported_datasets: Dict[str, Path] = field(default_factory=dict)
report_results: Dict[str, List[ReportResult]] = field(
default_factory=lambda: {'data_reports': [], 'ml_reports': [], 'instruction_reports': []})
[docs]
class GenModelInstruction(Instruction, ABC):
def __init__(self, state=None, method: GenerativeModel = None, reports: list = None):
self.generated_dataset = None
self.method = method
self.state = state
self.reports = reports
def _gen_data(self):
dataset = self.method.generate_sequences(self.state.gen_examples_count, 1,
self.state.result_path / 'generated_sequences',
SequenceType.AMINO_ACID, False)
print_log(f"{self.state.name}: generated {self.state.gen_examples_count} examples from the fitted model", True)
self.generated_dataset = dataset
self.state.generated_dataset = self.generated_dataset
def _export_generated_dataset(self):
AIRRExporter.export(self.generated_dataset, self.state.result_path / 'exported_gen_dataset')
self.state.exported_datasets['generated_dataset'] = self.state.result_path / 'exported_gen_dataset'
def _run_reports(self):
report_path = self._get_reports_path()
for report in self.reports:
report.result_path = report_path
if isinstance(report, DataReport):
rep = copy.deepcopy(report)
rep.dataset = self.generated_dataset
rep.name = rep.name + " (generated dataset)"
self.state.report_results['data_reports'].append(rep.generate_report())
elif isinstance(report, MLReport):
report.method = self.method
self.state.report_results['ml_reports'].append(report.generate_report())
def _print_report_summary_log(self):
if len(self.reports) > 0:
gen_rep_count = len(self.state.report_results['ml_reports']) + len(
self.state.report_results['data_reports'])
print_log(f"{self.state.name}: generated {gen_rep_count} reports.", True)
def _get_reports_path(self) -> Path:
return PathBuilder.build(self.state.result_path / 'reports')
def _set_path(self, result_path):
self.state.result_path = PathBuilder.build(result_path / self.state.name)