import inspect
from immuneML.dsl.DefaultParamsLoader import DefaultParamsLoader
from immuneML.dsl.symbol_table.SymbolTable import SymbolTable
from immuneML.dsl.symbol_table.SymbolType import SymbolType
from immuneML.ml_methods.MLMethod import MLMethod
from immuneML.util.Logger import log
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.ReflectionHandler import ReflectionHandler
class MLParser:
    """Parses the ML methods section of a specification into ML method instances.

    Each entry maps a user-chosen ML method id either to the name of an ML method class
    (a plain string) or to a dict whose single non-model-selection key is the ML method
    class name (its value holding the constructor parameters), optionally accompanied by
    ``model_selection_cv`` and ``model_selection_n_folds``.
    """

    # Keys that may appear next to the ML method class name in a method's specification.
    _MODEL_SELECTION_KEYS = ["model_selection_cv", "model_selection_n_folds"]

    @staticmethod
    def parse(specification: dict, symbol_table: SymbolTable):
        """Parse every ML method in ``specification`` and register it in ``symbol_table``.

        Args:
            specification: mapping of ML method id -> method specification (class name
                string or dict). Updated in place: each value is replaced by the resolved
                configuration returned by ``_parse_ml_method`` (defaults merged in).
            symbol_table: table to which each parsed method is added as ``SymbolType.ML_METHOD``.

        Returns:
            tuple (symbol_table, specification) after parsing.
        """
        for ml_method_id in specification.keys():
            ml_method, config = MLParser._parse_ml_method(ml_method_id, specification[ml_method_id])
            # Only values of existing keys are replaced, so the dict size does not change while iterating.
            specification[ml_method_id] = config
            symbol_table.add(ml_method_id, SymbolType.ML_METHOD, ml_method, config)
        return symbol_table, specification

    @staticmethod
    @log
    def _parse_ml_method(ml_method_id: str, ml_specification) -> tuple:
        """Resolve a single ML method specification into an ML method instance.

        Args:
            ml_method_id: user-chosen id of the method; used as the instance's ``name`` and in
                error messages.
            ml_specification: either the ML method class name as a string, or a dict with exactly
                one ML method class name key (value: constructor parameters, may be None or empty)
                plus optional ``model_selection_cv`` / ``model_selection_n_folds`` keys.

        Returns:
            tuple (method instance, resolved specification dict with defaults merged in).

        Raises:
            AssertionError: if the specification is not a string/dict, contains unknown keys, or
                does not name exactly one ML method class (also from ParameterValidator checks).
        """
        valid_class_values = ReflectionHandler.all_nonabstract_subclass_basic_names(MLMethod, "", "ml_methods/")

        # Shorthand: a bare class name is equivalent to {class_name: {}}.
        if isinstance(ml_specification, str):
            ml_specification = {ml_specification: {}}

        assert isinstance(ml_specification, dict), \
            f"MLParser: ML method {ml_method_id} was not correctly specified. Expected the ML method name or a dictionary, " \
            f"got {type(ml_specification).__name__} instead."

        # User-provided keys take precedence over the general ML method defaults.
        ml_specification = {**DefaultParamsLoader.load("ml_methods/", "MLMethod"), **ml_specification}

        ParameterValidator.assert_all_in_valid_list(list(ml_specification.keys()), MLParser._MODEL_SELECTION_KEYS + valid_class_values,
                                                    "MLParser", ml_method_id)

        non_default_keys = [spec_key for spec_key in ml_specification if spec_key not in MLParser._MODEL_SELECTION_KEYS]
        # Count the class-name keys directly rather than deriving them from the total number of keys,
        # which silently assumed both model selection keys were always present.
        assert len(non_default_keys) == 1, f"MLParser: ML method {ml_method_id} was not correctly specified. Expected exactly 1 key " \
                                           f"(ML method name), got {len(non_default_keys)} instead: " \
                                           f"{str(non_default_keys)[1:-1]}."

        ml_method_class_name = non_default_keys[0]
        ml_method_class = ReflectionHandler.get_class_by_name(ml_method_class_name, "ml_methods/")

        # A class name given without parameters may arrive as None, which cannot be unpacked with **;
        # treat it as "no user parameters" so class-specific defaults still apply.
        user_params = ml_specification[ml_method_class_name]
        if user_params is None:
            user_params = {}
        ml_specification[ml_method_class_name] = {**DefaultParamsLoader.load("ml_methods/", ml_method_class_name, log_if_missing=False),
                                                  **user_params}

        method, params = MLParser.create_method_instance(ml_specification, ml_method_class, ml_method_id)
        ml_specification[ml_method_class_name] = params
        method.name = ml_method_id

        return method, ml_specification

    @staticmethod
    def create_method_instance(ml_specification: dict, ml_method_class, key: str) -> tuple:
        """Instantiate ``ml_method_class`` from the parameters stored in ``ml_specification``.

        Args:
            ml_specification: method specification; ``ml_specification[ml_method_class.__name__]`` holds
                the constructor parameters (None or empty means the constructor defaults are used).
                ``model_selection_cv`` and ``model_selection_n_folds`` are read only when a parameter
                grid is built.
            ml_method_class: ML method class to instantiate.
            key: ML method id, used in error messages.

        Returns:
            tuple (ML method instance, parameter dict that was used; empty if none were given).

        Raises:
            AssertionError: if list-valued parameters are given for a class accepting ``parameter_grid``
                while ``model_selection_cv`` is not True or ``model_selection_n_folds`` is not an int >= 2.
        """
        ml_params = ml_specification[ml_method_class.__name__]
        if not ml_params:
            return ml_method_class(), {}

        init_method_keys = inspect.signature(ml_method_class.__init__).parameters.keys()

        if "parameter_grid" in init_method_keys and any(isinstance(value, list) for value in ml_params.values()):
            # List-valued parameters define a hyperparameter grid, which is only searched with model selection CV.
            ParameterValidator.assert_type_and_value(ml_specification['model_selection_cv'], bool, MLParser.__name__, f'{key}: model_selection_cv')
            assert ml_specification['model_selection_cv'] is True, f"MLParser: when running ML method {key} with a list of inputs, model_selection_cv must be True! " \
                                                                   f"Set the parameters for {key} to single values (not lists) or set model_selection_cv to True and model_selection_n_folds to >= 2"
            ParameterValidator.assert_type_and_value(ml_specification['model_selection_n_folds'], int, MLParser.__name__, f'{key}: model_selection_n_folds', 2)
            # Wrap scalar values so that every grid entry is a list of candidate values.
            parameter_grid = {param_name: value if isinstance(value, list) else [value] for param_name, value in ml_params.items()}
            ml_method = ml_method_class(parameter_grid=parameter_grid)
        elif len(init_method_keys) == 3 and all(arg in init_method_keys for arg in ["parameters", "parameter_grid"]):
            # __init__ takes three arguments including `parameters` and `parameter_grid` (the third being self):
            # pass all parameters as a single dict.
            ml_method = ml_method_class(parameters=ml_params)
        else:
            ml_method = ml_method_class(**ml_params)

        return ml_method, ml_params