Source code for immuneML.workflows.steps.data_splitter.ManualSplitter

import os

import numpy as np
import pandas as pd

from immuneML.data_model.dataset.Dataset import Dataset
from immuneML.data_model.dataset.ElementDataset import ElementDataset
from immuneML.data_model.dataset.RepertoireDataset import RepertoireDataset
from immuneML.util.ReflectionHandler import ReflectionHandler
from immuneML.workflows.steps.data_splitter.DataSplitterParams import DataSplitterParams
from immuneML.workflows.steps.data_splitter.Util import Util


class ManualSplitter:

    @staticmethod
    def split_dataset(input_params: DataSplitterParams):
        if isinstance(input_params.dataset, RepertoireDataset):
            return ManualSplitter._split_repertoire_dataset(input_params)
        elif isinstance(input_params.dataset, ElementDataset):
            return ManualSplitter._split_element_dataset(input_params)
        else:
            raise ValueError(f"DataSplitter: dataset is unexpected class: {type(input_params.dataset).__name__}, "
                             f"expected one of {str(ReflectionHandler.all_nonabstract_subclass_basic_names(Dataset, '', 'dataset/'))[1:-1]}")

    @staticmethod
    def _split_repertoire_dataset(input_params):
        train_metadata_path = input_params.split_config.manual_config.train_metadata_path
        test_metadata_path = input_params.split_config.manual_config.test_metadata_path

        train_dataset = ManualSplitter._make_manual_dataset(input_params, train_metadata_path, Dataset.TRAIN)
        test_dataset = ManualSplitter._make_manual_dataset(input_params, test_metadata_path, Dataset.TEST)

        return [train_dataset], [test_dataset]

    @staticmethod
    def _make_manual_dataset(input_params, metadata_path, dataset_type):
        dataset = input_params.dataset
        metadata = dataset.get_metadata(["subject_id"])["subject_id"]
        unique_metadata_count = np.unique(metadata).shape[0]

        assert len(metadata) == unique_metadata_count, \
            f"DataSplitter: there are {len(metadata)} repertoires, but {unique_metadata_count} " \
            f"unique identifiers. Check the metadata for the original dataset {dataset.name}."

        metadata_df = pd.read_csv(metadata_path)

        assert "subject_id" in metadata_df, \
            f"DataSplitter: {dataset_type} metadata {os.path.basename(metadata_path)} is missing column " \
            f"'subject_id' which should be used for matching repertoires when splitting to train and test data."

        indices = [i for i in range(len(metadata)) if metadata[i] in metadata_df["subject_id"].values.tolist()]
        new_dataset = Util.make_dataset(dataset, indices, input_params, 0, dataset_type)

        return new_dataset

    @staticmethod
    def _split_element_dataset(input_params):
        raise NotImplementedError("DataSplitter: manually specifying receptors or receptor sequences for training and test set is not yet "
                                  "implemented.")
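For orientation, the index-matching step above can be reproduced in isolation: the splitter keeps the positions of those repertoires whose subject_id appears in the supplied train (or test) metadata file. The snippet below is a minimal, self-contained sketch with made-up subject ids and an in-memory CSV; it does not touch the immuneML API.

import io

import numpy as np
import pandas as pd

# Hypothetical subject ids of the full repertoire dataset, one per repertoire
# (in immuneML these come from dataset.get_metadata(["subject_id"])).
metadata = ["s1", "s2", "s3", "s4"]

# Hypothetical train metadata CSV listing the subjects assigned to the training split.
train_metadata_csv = io.StringIO("subject_id\ns1\ns3\n")

# Same checks and index selection as in ManualSplitter._make_manual_dataset:
assert len(metadata) == np.unique(metadata).shape[0], "subject ids must be unique"

metadata_df = pd.read_csv(train_metadata_csv)
assert "subject_id" in metadata_df, "train metadata must contain a 'subject_id' column"

indices = [i for i in range(len(metadata)) if metadata[i] in metadata_df["subject_id"].values.tolist()]
print(indices)  # [0, 2]: repertoires of subjects s1 and s3 go to the training split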