Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for immuneML.workflows.steps.MLMethodTrainer
import copy
import pandas as pd
from immuneML.ml_methods.classifiers.MLMethod import MLMethod
from immuneML.util.Logger import print_log
from immuneML.workflows.steps.MLMethodTrainerParams import MLMethodTrainerParams
from immuneML.workflows.steps.Step import Step
[docs]
class MLMethodTrainer(Step):
    """Workflow step that fits an ML method and stores the model plus its training-set predictions.

    The method from the parameters is deep-copied, fitted (optionally through
    cross-validation), stored under ``input_params.result_path``, and its predictions
    on the training data are written as CSV to ``input_params.train_predictions_path``.
    """

    @staticmethod
    def run(input_params: MLMethodTrainerParams = None) -> MLMethod:
        """Fit the configured ML method, store it with its training predictions, and return it.

        Args:
            input_params: parameters of the step. The signature defaults to ``None`` only
                for backward compatibility; a value is required because every attribute
                used for fitting and storing is read from it.

        Returns:
            The fitted deep copy of ``input_params.method``.
        """
        print_log("ML model training started...", include_datetime=True)
        method = MLMethodTrainer._fit_method(input_params)
        MLMethodTrainer.store(method, input_params)
        print_log("ML model training finished.", include_datetime=True)
        return method

    @staticmethod
    def _fit_method(input_params: MLMethodTrainerParams) -> MLMethod:
        """Return a fitted deep copy of ``input_params.method``.

        Uses ``fit_by_cross_validation`` (with ``model_selection_n_folds`` splits) when
        ``input_params.model_selection_cv`` is truthy, and plain ``fit`` otherwise.
        """
        # Deep copy so fitting does not modify the method object held by input_params.
        method = copy.deepcopy(input_params.method)
        method.result_path = input_params.result_path

        # Arguments shared by both fitting strategies.
        fit_kwargs = dict(encoded_data=input_params.dataset.encoded_data,
                          label=input_params.label,
                          cores_for_training=input_params.cores_for_training,
                          optimization_metric=input_params.optimization_metric)

        if input_params.model_selection_cv:
            method.fit_by_cross_validation(number_of_splits=input_params.model_selection_n_folds, **fit_kwargs)
        else:
            method.fit(**fit_kwargs)
        return method

    @staticmethod
    def store(method: MLMethod, input_params: MLMethodTrainerParams) -> None:
        """Persist the fitted method and write its predictions on the training data to CSV.

        The CSV written to ``input_params.train_predictions_path`` (without index) has the
        columns ``example_ids``, ``<label>_predicted_class``, ``<label>_true_class`` and one
        ``<label>_<class>_proba`` column per class returned by ``method.get_classes()``.
        The probability columns hold ``None`` when ``predict_proba`` returns ``None`` or
        returns ``None`` for the label.
        """
        method.store(input_params.result_path)

        encoded_data = input_params.dataset.encoded_data
        label_name = input_params.label.name

        train_predictions = method.predict(encoded_data, input_params.label)
        train_proba_predictions = method.predict_proba(encoded_data, input_params.label)

        df = pd.DataFrame({"example_ids": encoded_data.example_ids,
                           f"{label_name}_predicted_class": train_predictions[label_name],
                           f"{label_name}_true_class": encoded_data.labels[label_name]})

        # Resolve the label's probabilities once instead of re-checking on every class.
        label_probas = train_proba_predictions[label_name] if train_proba_predictions is not None else None

        for cls in method.get_classes():
            # A scalar None yields an all-None column when no probabilities are available.
            df[f"{label_name}_{cls}_proba"] = label_probas[cls] if label_probas is not None else None

        df.to_csv(input_params.train_predictions_path, index=False)