import copy
import pickle
from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.data_model.datasets.ElementDataset import ElementDataset
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.environment.Label import Label
from immuneML.environment.LabelConfiguration import LabelConfiguration
from immuneML.pairwise_repertoire_comparison.ComparisonData import ComparisonData
from immuneML.util.PathBuilder import PathBuilder
class EncoderHelper:
    """Static utility methods shared by dataset encoders: training-id bookkeeping,
    label encoding, cache synchronization and repertoire comparison setup."""

    @staticmethod
    def prepare_training_ids(dataset: Dataset, params: EncoderParams):
        """Return the example ids to use for training.

        When learning a model, stores the dataset's example ids to a pickle file under
        params.result_path; otherwise loads the previously stored ids from that file.
        """
        PathBuilder.build(params.result_path)
        # single definition keeps the write (learn) and read (apply) branches in sync
        training_ids_path = params.result_path / "training_ids.pickle"
        if params.learn_model:
            training_ids = dataset.get_example_ids()
            with training_ids_path.open("wb") as file:
                pickle.dump(training_ids, file)
        else:
            with training_ids_path.open("rb") as file:
                training_ids = pickle.load(file)
        return training_ids

    @staticmethod
    def get_current_dataset(dataset, context):
        """Retrieves the full dataset (training+validation+test) if present in context, otherwise return the given
        dataset"""
        return dataset if context is None or "dataset" not in context else context["dataset"]

    @staticmethod
    def build_comparison_params(dataset, comparison_attributes) -> tuple:
        """Build a fully hashable tuple identifying a repertoire comparison (suitable as a cache key)."""
        return (("dataset_identifier", dataset.identifier),
                ("comparison_attributes", tuple(comparison_attributes)),
                ("repertoire_ids", tuple(dataset.get_repertoire_ids())))

    @staticmethod
    def build_comparison_data(dataset: RepertoireDataset, params: EncoderParams,
                              comparison_attributes, sequence_batch_size):
        """Create a ComparisonData object over the dataset's repertoires and process the dataset into it."""
        comp_data = ComparisonData(dataset.get_repertoire_ids(), comparison_attributes,
                                   sequence_batch_size, params.result_path)
        comp_data.process_dataset(dataset)
        return comp_data

    @staticmethod
    def sync_encoder_with_cache(cache_params: tuple, encoder_memo_func, encoder, param_names):
        """Fetch (or compute and cache) encoder state via CacheHandler and copy the listed
        attributes onto the given encoder; returns the updated encoder.

        The 'learn_model' entry is dropped from the cache key so training and application
        runs resolve to the same cached encoder state.
        """
        encoder_cache_params = tuple((key, val) for key, val in dict(cache_params).items() if key != 'learn_model')
        encoder_cache_params = (encoder_cache_params, "encoder")
        encoder_from_cache = CacheHandler.memo_by_params(encoder_cache_params, encoder_memo_func)
        for param in param_names:
            # deepcopy so later mutation of the encoder cannot corrupt the cached state
            setattr(encoder, param, copy.deepcopy(encoder_from_cache[param]))
        return encoder

    @staticmethod
    def check_dataset_type_available_in_mapping(dataset, class_name):
        """Raise ValueError if the encoder class defines no mapping for the dataset's concrete type."""
        if dataset.__class__.__name__ not in class_name.dataset_mapping.keys():
            raise ValueError(
                f"{class_name.__name__}: this encoder is not defined for dataset of type {dataset.__class__.__name__}. "
                f"Valid dataset types for this encoder are: {', '.join(list(class_name.dataset_mapping.keys()))}")

    @staticmethod
    def encode_element_dataset_labels(dataset: ElementDataset, label_config: LabelConfiguration):
        """Automatically generates the encoded labels for an ElementDataset (= SequenceDataset or ReceptorDataset)"""
        labels = {name: [] for name in label_config.get_labels_by_name()}
        for sequence in dataset.get_data():
            for label_name in label_config.get_labels_by_name():
                label = sequence.metadata[label_name]
                labels[label_name].append(label)
        return labels

    @staticmethod
    def encode_repertoire_dataset_labels(dataset: RepertoireDataset, label_config: LabelConfiguration):
        """Automatically generates the encoded labels for a RepertoireDataset"""
        label_names = label_config.get_labels_by_name()
        return dataset.get_metadata(label_names)

    @staticmethod
    def encode_dataset_labels(dataset: Dataset, label_config: LabelConfiguration, encode_labels: bool = True):
        """Automatically generates the encoded labels for a Dataset.
        This contains labels in the following format: {'label_name': ['label_class1', 'label_class2', 'label_class2']}
        where the inner list(s) contain the class label for each example in the dataset"""
        if not encode_labels:
            return None
        if isinstance(dataset, RepertoireDataset):
            return EncoderHelper.encode_repertoire_dataset_labels(dataset, label_config)
        else:
            return EncoderHelper.encode_element_dataset_labels(dataset, label_config)

    @staticmethod
    def check_positive_class_labels(label_config: LabelConfiguration, location: str):
        """
        Performs checks for Encoders that explicitly predict a positive class. These Encoders can only be trained for a
        single binary label at a time.
        """
        labels = label_config.get_label_objects()
        assert len(labels) == 1, (f"{location}: this encoding works only for single label, there are now "
                                  f"{len(labels)} labels specified.")
        label = labels[0]
        assert isinstance(label, Label) and label.positive_class is not None and label.positive_class != "", \
            f"{location}: positive_class parameter was not set for label {label}. It has to be set to determine the " \
            f"receptor sequences associated with the positive class. " \
            f"To use this encoder, in the label definition in the specification of the instruction, define " \
            f"the positive class for the label. See documentation for this encoder for more details."
        assert len(
            label.values) == 2, f"{location}: only binary classification (2 classes) is possible when extracting " \
                                f"relevant sequences for the label, but got these classes for label {label.name} instead: {label.values}."

    @staticmethod
    def get_example_weights_by_identifiers(dataset, example_identifiers):
        """Return the example weights ordered to match example_identifiers, or None when the
        dataset defines no weights.

        NOTE(review): the original body hard-coded ``weights = None`` and then tested
        ``if weights is not None``, so the method always returned None. Restored the
        dataset lookup — confirm the Dataset interface exposes get_example_weights().
        """
        weights = dataset.get_example_weights()
        if weights is not None:
            # map weights onto ids once, then project into the requested order
            weights_dict = dict(zip(dataset.get_example_ids(), weights))
            return [weights_dict[identifier] for identifier in example_identifiers]
        return None

    @staticmethod
    def get_single_label_name_from_config(label_config: LabelConfiguration, location="EncoderHelper"):
        """Return the name of the single label in label_config; fails when zero or multiple labels are set."""
        assert label_config.get_label_count() != 0, f"{location}: the dataset does not contain labels, please specify a label under 'instructions'."
        assert label_config.get_label_count() == 1, f"{location}: multiple labels were found: {', '.join(label_config.get_labels_by_name())}, expected a single label."
        return label_config.get_labels_by_name()[0]