# Source code for immuneML.dsl.instruction_parsers.TrainMLModelParser

import hashlib
from inspect import signature
from pathlib import Path
from typing import Tuple

from immuneML.data_model.dataset.Dataset import Dataset
from immuneML.dsl.DefaultParamsLoader import DefaultParamsLoader
from immuneML.dsl.definition_parsers.PreprocessingParser import PreprocessingParser
from immuneML.dsl.instruction_parsers.LabelHelper import LabelHelper
from immuneML.dsl.symbol_table.SymbolTable import SymbolTable
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.LabelConfiguration import LabelConfiguration
from immuneML.hyperparameter_optimization.HPSetting import HPSetting
from immuneML.hyperparameter_optimization.config.LeaveOneOutConfig import LeaveOneOutConfig
from immuneML.hyperparameter_optimization.config.ManualSplitConfig import ManualSplitConfig
from immuneML.hyperparameter_optimization.config.ReportConfig import ReportConfig
from immuneML.hyperparameter_optimization.config.SplitConfig import SplitConfig
from immuneML.hyperparameter_optimization.config.SplitType import SplitType
from immuneML.ml_metrics.Metric import Metric
from immuneML.reports.train_ml_model_reports.TrainMLModelReport import TrainMLModelReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.ReflectionHandler import ReflectionHandler
from immuneML.workflows.instructions.TrainMLModelInstruction import TrainMLModelInstruction


class TrainMLModelParser:
    """
    Parses the specification of a TrainMLModel instruction (a dict) into a TrainMLModelInstruction object.

    The entry point is `parse`; the remaining methods each handle one part of the specification
    (hyperparameter settings, split configurations, reports, output path).
    """

    def parse(self, key: str, instruction: dict, symbol_table: SymbolTable, path: Path = None) -> TrainMLModelInstruction:
        """
        Validate the instruction specification and build the TrainMLModelInstruction.

        Args:
            key: name of the instruction; used in error messages and as the name of the resulting instruction.
            instruction: the instruction specification; the keys read here must already be present.
            symbol_table: symbol table from which the dataset, encodings, ML methods, preprocessing sequences
                and reports referenced in the specification are looked up.
            path: not used by this method; the output path is always computed by `_prepare_path`.

        Returns:
            the TrainMLModelInstruction object built from the specification.
        """
        valid_keys = ["assessment", "selection", "dataset", "strategy", "labels", "metrics", "settings", "number_of_processes", "type",
                      "reports", "optimization_metric", 'refit_optimal_model']
        ParameterValidator.assert_type_and_value(instruction['settings'], list, TrainMLModelParser.__name__, 'settings')
        ParameterValidator.assert_keys(list(instruction.keys()), valid_keys, TrainMLModelParser.__name__, "TrainMLModel")
        ParameterValidator.assert_type_and_value(instruction['refit_optimal_model'], bool, TrainMLModelParser.__name__, 'refit_optimal_model')
        ParameterValidator.assert_type_and_value(instruction['metrics'], list, TrainMLModelParser.__name__, 'metrics')
        ParameterValidator.assert_type_and_value(instruction['optimization_metric'], str, TrainMLModelParser.__name__, 'optimization_metric')
        ParameterValidator.assert_type_and_value(instruction['number_of_processes'], int, TrainMLModelParser.__name__, 'number_of_processes')
        ParameterValidator.assert_type_and_value(instruction['strategy'], str, TrainMLModelParser.__name__, 'strategy')
        if instruction["reports"] is not None:
            ParameterValidator.assert_type_and_value(instruction['reports'], list, TrainMLModelParser.__name__, 'reports')

        settings = self._parse_settings(instruction, symbol_table)
        dataset = symbol_table.get(instruction["dataset"])
        label_config = LabelHelper.create_label_config(instruction["labels"], dataset, TrainMLModelParser.__name__, key)
        assessment = self._parse_split_config(key, instruction, "assessment", symbol_table, len(settings), label_config)
        selection = self._parse_split_config(key, instruction, "selection", symbol_table, len(settings), label_config)
        assessment, selection = self._update_split_configs(assessment, selection, dataset)
        strategy = ReflectionHandler.get_class_by_name(instruction["strategy"], "hyperparameter_optimization/")
        metrics = {Metric.get_metric(metric) for metric in instruction["metrics"]}
        optimization_metric = Metric.get_metric(instruction["optimization_metric"].upper())
        metric_search_criterion = Metric.get_search_criterion(optimization_metric)
        # the `path` argument is overwritten here: the location is derived from the instruction itself
        path = self._prepare_path(instruction)
        context = self._prepare_context(instruction, symbol_table)
        reports = self._prepare_reports(instruction["reports"], symbol_table)

        hp_instruction = TrainMLModelInstruction(dataset=dataset, hp_strategy=strategy(settings, metric_search_criterion),
                                                 hp_settings=settings, assessment=assessment, selection=selection, metrics=metrics,
                                                 optimization_metric=optimization_metric,
                                                 refit_optimal_model=instruction['refit_optimal_model'],
                                                 label_configuration=label_config, path=path, context=context,
                                                 number_of_processes=instruction["number_of_processes"], reports=reports, name=key)

        return hp_instruction

    def _update_split_configs(self, assessment: SplitConfig, selection: SplitConfig, dataset: Dataset) -> Tuple[SplitConfig, SplitConfig]:
        """
        Fill in split counts that depend on the dataset size for leave-one-out cross-validation (LOOCV).

        For LOOCV assessment, the split count is set to the number of examples in the dataset. For LOOCV
        selection, the split count is set to the number of examples left for training and validation in
        one assessment split.

        Returns:
            the (assessment, selection) split configs, updated in place.
        """
        example_count = dataset.get_example_count()

        if assessment.split_strategy == SplitType.LOOCV:
            assessment.split_count = example_count
            train_val_example_count = assessment.split_count - 1
        elif assessment.split_strategy in (SplitType.K_FOLD, SplitType.STRATIFIED_K_FOLD):
            # both k-fold variants must be compared explicitly: a bare `split_strategy.STRATIFIED_K_FOLD`
            # is not a comparison and would never let the branch below be reached
            train_val_example_count = int(example_count * (assessment.split_count - 1) / assessment.split_count)
        else:
            train_val_example_count = int(example_count * assessment.training_percentage)

        if selection.split_strategy == SplitType.LOOCV:
            selection.split_count = train_val_example_count

        return assessment, selection

    def _prepare_reports(self, reports: list, symbol_table: SymbolTable) -> dict:
        """
        Look up the report objects for the given report ids.

        Returns:
            a dict mapping each report id to the object stored under it in the symbol table (each must be
            a TrainMLModelReport), or an empty dict if `reports` is None.
        """
        if reports is not None:
            ParameterValidator.assert_type_and_value(reports, list, TrainMLModelParser.__name__, "reports")
            report_objects = {report_id: symbol_table.get(report_id) for report_id in reports}
            ParameterValidator.assert_all_type_and_value(report_objects.values(), TrainMLModelReport, TrainMLModelParser.__name__, 'reports')
            return report_objects
        else:
            return {}

    def _prepare_context(self, instruction: dict, symbol_table: SymbolTable):
        """Return the context dict for the instruction, holding the dataset looked up from the symbol table."""
        return {"dataset": symbol_table.get(instruction["dataset"])}

    def _resolve_preprocessing(self, setting: dict, index: int, symbol_table: SymbolTable) -> tuple:
        """
        Return (preprocessing sequence, its name) for one settings entry, or ([], None) if none is specified.

        When no preprocessing is specified, setting["preprocessing"] is set to None in place.

        Raises:
            KeyError: if the named preprocessing sequence is not present in the symbol table.
            ValueError: if any preprocessing in the sequence reports that it does not keep the example count.
        """
        if "preprocessing" not in setting or setting["preprocessing"] is None:
            setting["preprocessing"] = None
            return [], None

        ParameterValidator.assert_type_and_value(setting["preprocessing"], str, TrainMLModelParser.__name__,
                                                 f'settings: {index+1}. '
                                                 f'element: preprocessing')

        if not symbol_table.contains(setting["preprocessing"]):
            raise KeyError(f"{TrainMLModelParser.__name__}: preprocessing was set in the TrainMLModel instruction to value "
                           f"{setting['preprocessing']}, but no such preprocessing was defined in the specification under "
                           f"definitions: {PreprocessingParser.keyword}.")

        preprocessing_sequence = symbol_table.get(setting["preprocessing"])
        preproc_name = setting["preprocessing"]
        if not all(preproc.keeps_example_count() for preproc in preprocessing_sequence):
            raise ValueError(f"{TrainMLModelParser.__name__}: preprocessing sequence {preproc_name} includes preprocessing that "
                             f"change the number of examples at runtime and as such cannot be used with this instruction. See the "
                             f"documentation for the preprocessing or alternatively use them with other instructions.")

        return preprocessing_sequence, preproc_name

    def _parse_settings(self, instruction: dict, symbol_table: SymbolTable) -> list:
        """
        Build one HPSetting per entry listed under "settings" in the instruction.

        Each entry may name a preprocessing sequence and must name an encoding and an ML method, all of
        which are looked up in the symbol table. The encoder is built for the instruction's dataset and
        checked for compatibility with the ML method.

        Returns:
            the list of HPSetting objects, in the order given in the specification.

        Raises:
            KeyError: if a required parameter is missing from a settings entry, or if the named
                preprocessing sequence was not defined.
            ValueError: if a preprocessing sequence changes the number of examples.
        """
        settings = []

        for index, setting in enumerate(instruction["settings"]):
            # resolved outside the try block below so that its own descriptive KeyError
            # is not re-wrapped into the generic "was not defined under settings" message
            preprocessing_sequence, preproc_name = self._resolve_preprocessing(setting, index, symbol_table)

            try:
                ParameterValidator.assert_keys(setting.keys(), ["preprocessing", "ml_method", "encoding"], TrainMLModelParser.__name__,
                                               f"settings, {index + 1}. entry")

                encoder_params = symbol_table.get_config(setting["encoding"])["encoder_params"]
                encoder = symbol_table.get(setting["encoding"]) \
                    .build_object(symbol_table.get(instruction["dataset"]), **encoder_params) \
                    .set_context({"dataset": symbol_table.get(instruction['dataset'])})

                ml_method = symbol_table.get(setting["ml_method"])
                ml_method.check_encoder_compatibility(encoder)

                s = HPSetting(encoder=encoder, encoder_name=setting["encoding"],
                              encoder_params=symbol_table.get_config(setting["encoding"])["encoder_params"],
                              ml_method=ml_method, ml_method_name=setting["ml_method"],
                              ml_params=symbol_table.get_config(setting["ml_method"]),
                              preproc_sequence=preprocessing_sequence, preproc_sequence_name=preproc_name)
            except KeyError as key_error:
                raise KeyError(f"{TrainMLModelParser.__name__}: parameter {key_error.args[0]} was not defined under settings in "
                               f"TrainMLModel instruction.") from key_error

            settings.append(s)

        return settings

    def _prepare_path(self, instruction: dict) -> Path:
        """
        Return the path for the instruction: the absolute form of instruction["path"] if that key is
        present, otherwise a directory under the default analysis path named by the MD5 hex digest of
        the instruction's string representation.
        """
        if "path" in instruction:
            path = Path(instruction["path"]).absolute()
        else:
            path = EnvironmentSettings.default_analysis_path / hashlib.md5(str(instruction).encode()).hexdigest()

        return path

    def _parse_split_config(self, instruction_key, instruction: dict, split_key: str, symbol_table: SymbolTable, settings_count: int,
                            label_config: LabelConfiguration) -> SplitConfig:
        """
        Build the SplitConfig for one level of the nested cross-validation.

        The values given under instruction[split_key] are merged over the default SplitConfig parameters,
        and the merged dict is stored back into instruction[split_key].

        Args:
            instruction_key: name of the instruction, used in error messages.
            instruction: the instruction specification.
            split_key: the key of the split to parse ("assessment" or "selection" as called from `parse`).
            symbol_table: symbol table used to look up the reports listed for this split.
            settings_count: number of hyperparameter settings defined for the instruction.
            label_config: label configuration of the instruction.

        Raises:
            ValueError: if a random split uses all data for training while more than one setting is defined,
                or if stratified k-fold is requested with a number of labels other than one.
            KeyError: if a required parameter is missing under the split.
        """
        try:
            default_params = DefaultParamsLoader.load("instructions/", SplitConfig.__name__)
            report_config_input = self._prepare_report_config(instruction_key, instruction, split_key, symbol_table)
            instruction[split_key] = {**default_params, **instruction[split_key]}

            split_strategy = SplitType[instruction[split_key]["split_strategy"].upper()]
            # the training percentage is only read for random splits; -1 is passed on otherwise
            training_percentage = float(instruction[split_key]["training_percentage"]) if split_strategy == SplitType.RANDOM else -1

            if split_strategy == SplitType.RANDOM and training_percentage == 1 and settings_count > 1:
                raise ValueError(f"{TrainMLModelParser.__name__}: all data under {instruction_key}/{split_key} was specified to be used for "
                                 f"training, but {settings_count} settings were specified for evaluation. Please define a test/validation set by "
                                 f"reducing the training percentage (e.g., to 0.7) or use only one hyperparameter setting to run the analysis.")

            if split_strategy == SplitType.STRATIFIED_K_FOLD and len(label_config.get_labels_by_name()) != 1:
                raise ValueError(f"{TrainMLModelParser.__name__}: Stratified k-fold cross-validation cannot be used when "
                                 f"{len(label_config.get_labels_by_name())} labels are specified. It support only one label (and multiple classes).")

            return SplitConfig(split_strategy=split_strategy,
                               split_count=int(instruction[split_key]["split_count"]),
                               training_percentage=training_percentage,
                               reports=ReportConfig(**report_config_input),
                               manual_config=ManualSplitConfig(**instruction[split_key]["manual_config"])
                               if "manual_config" in instruction[split_key] else None,
                               leave_one_out_config=LeaveOneOutConfig(**instruction[split_key]["leave_one_out_config"])
                               if "leave_one_out_config" in instruction[split_key] else None)

        except KeyError as key_error:
            raise KeyError(f"{TrainMLModelParser.__name__}: parameter {key_error.args[0]} was not defined under {split_key}.") from key_error

    def _prepare_report_config(self, instruction_key, instruction, split_key, symbol_table):
        """
        Build the keyword arguments for ReportConfig from the "reports" entry of the given split.

        The allowed report types are the parameter names of the ReportConfig constructor; each report
        type must map to a list of report ids, which are looked up in the symbol table.

        Returns:
            a dict mapping report type to a dict of {report id: report object}, or an empty dict if no
            reports are listed under the split.
        """
        if "reports" in instruction[split_key] and len(instruction[split_key]["reports"]) > 0:

            location = f"{instruction_key}/{split_key}/reports"
            report_types = list(signature(ReportConfig).parameters.keys())
            ParameterValidator.assert_all_in_valid_list(instruction[split_key]["reports"].keys(), report_types, location, "reports")

            for report_type in instruction[split_key]["reports"]:
                ParameterValidator.assert_type_and_value(instruction[split_key]["reports"][report_type], list,
                                                         f"{location}/{report_type}", report_type)

            report_config_input = {report_type: {report_id: symbol_table.get(report_id)
                                                 for report_id in instruction[split_key]["reports"][report_type]}
                                   for report_type in instruction[split_key]["reports"]}
        else:
            report_config_input = {}

        return report_config_input