Source code for immuneML.workflows.steps.data_splitter.LeaveOneOutSplitter

import numpy as np

from immuneML.data_model.dataset.ReceptorDataset import ReceptorDataset
from immuneML.workflows.steps.data_splitter.DataSplitterParams import DataSplitterParams
from immuneML.workflows.steps.data_splitter.Util import Util


class LeaveOneOutSplitter:
    """
    Leave-one-out stratified splitting of a receptor dataset.

    For every unique value of a chosen receptor metadata parameter, one train/test pair is produced:
    the test part holds the indices of all receptors whose metadata parameter equals that value and
    the train part holds the indices of all remaining receptors. The number of splits is therefore
    the number of unique parameter values, and the split count stored in the input parameters is
    overwritten accordingly.
    """

    @staticmethod
    def split_dataset(input_params: "DataSplitterParams"):
        """
        Split the dataset from input_params into lists of train and test datasets.

        Args:
            input_params: splitter parameters; the code reads ``dataset`` and
                ``split_config.leave_one_out_config`` (``parameter`` and ``min_count``) from it.

        Returns:
            a tuple (train_datasets, test_datasets) of two equally long lists, one entry per unique
            value of the stratification parameter.

        Raises:
            NotImplementedError: if the dataset is not a ReceptorDataset.
            AssertionError: if some parameter value does not occur more than ``min_count`` times.
        """
        if isinstance(input_params.dataset, ReceptorDataset):
            return LeaveOneOutSplitter._split_receptor_dataset(input_params)

        raise NotImplementedError("LeaveOneOutSplitter: leave-one-out stratification is currently implemented only for receptor dataset, "
                                  f"got {type(input_params.dataset).__name__} instead.")

    @staticmethod
    def _split_receptor_dataset(input_params: "DataSplitterParams"):
        """Run the full leave-one-out pipeline on a receptor dataset and return (train_datasets, test_datasets)."""
        dataset = input_params.dataset
        leave_one_out_config = input_params.split_config.leave_one_out_config
        param, min_count = leave_one_out_config.parameter, leave_one_out_config.min_count

        unique_values = LeaveOneOutSplitter._get_unique_param_values(dataset, param, min_count)

        # the split count must be updated before the datasets are made, since input_params is passed on to Util.make_dataset
        input_params = LeaveOneOutSplitter._update_split_count(input_params, unique_values)

        train_indices, test_indices = LeaveOneOutSplitter._get_train_test_indices(dataset, unique_values, param)
        return LeaveOneOutSplitter._make_datasets_from_indices(unique_values, dataset, train_indices, test_indices, input_params)

    @staticmethod
    def _update_split_count(input_params: "DataSplitterParams", unique_values):
        """
        Set the split count to the number of unique parameter values.

        Note that input_params (and its split_config) are modified in place; the same object is returned.
        """
        input_params.split_config.split_count = unique_values.shape[0]
        input_params.split_count = input_params.split_config.split_count
        return input_params

    @staticmethod
    def _make_datasets_from_indices(unique_values, dataset, train_indices, test_indices, input_params):
        """
        Build one train and one test dataset per unique value through Util.make_dataset.

        The position of the value in unique_values is passed to Util.make_dataset as the split index.
        """
        train_datasets, test_datasets = [], []

        for index, value in enumerate(unique_values):
            train_datasets.append(Util.make_dataset(dataset, train_indices[value], input_params, index, ReceptorDataset.TRAIN))
            test_datasets.append(Util.make_dataset(dataset, test_indices[value], input_params, index, ReceptorDataset.TEST))

        return train_datasets, test_datasets

    @staticmethod
    def _get_unique_param_values(dataset, param, min_count) -> np.ndarray:
        """
        Return the sorted unique values of the metadata parameter across all receptors in the dataset.

        Args:
            dataset: object whose get_data() yields items with a ``metadata`` mapping.
            param: key to read from each item's metadata.
            min_count: every unique value has to occur strictly more than this many times.

        Raises:
            AssertionError: if any unique value occurs min_count times or fewer.
        """
        parameter_values = [receptor.metadata[param] for receptor in dataset.get_data()]
        unique_values, counts = np.unique(parameter_values, return_counts=True)

        # Raised explicitly instead of using an assert statement: assert statements are removed when Python runs
        # with -O, which would silently disable this validation. The exception type is kept for existing callers.
        if not all(count > min_count for count in counts):
            raise AssertionError(f"DataSplitter: there are not enough examples with different values of the parameter {param} "
                                 f"to split the dataset.")

        return unique_values

    @staticmethod
    def _get_train_test_indices(dataset, unique_values, param):
        """
        Assign every receptor index to the test or train list of each unique value.

        Returns:
            a tuple (train_indices, test_indices) of dicts keyed by unique value; for a given value, the test list
            has the indices of receptors whose metadata parameter equals the value, the train list has all others.
        """
        train_indices = {value: [] for value in unique_values}
        test_indices = {value: [] for value in unique_values}

        for index, receptor in enumerate(dataset.get_data()):
            receptor_value = receptor.metadata[param]  # looked up once per receptor instead of once per unique value
            for value in unique_values:
                target = test_indices if receptor_value == value else train_indices
                target[value].append(index)

        return train_indices, test_indices