import hashlib
import warnings
from pathlib import Path
import h5py
import numpy as np
import pkg_resources
import torch
import yaml
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.encoded_data.EncodedData import EncodedData
from immuneML.encodings.deeprc.DeepRCEncoder import DeepRCEncoder
from immuneML.ml_methods.MLMethod import MLMethod
from immuneML.ml_methods.util.Util import Util
from immuneML.util.FilenameHandler import FilenameHandler
from immuneML.util.PathBuilder import PathBuilder
from immuneML.util.ReflectionHandler import ReflectionHandler
[docs]class DeepRC(MLMethod):
"""
This classifier uses the DeepRC method for repertoire classification. The DeepRC ML method should be used in combination
with the DeepRC encoder. Also consider using the :ref:`DeepRCMotifDiscovery` report for interpretability.
Notes:
- DeepRC uses PyTorch functionalities that depend on GPU. Therefore, DeepRC does not work on a CPU.
- This wrapper around DeepRC currently only supports binary classification.
Reference:
Michael Widrich, Bernhard Schäfl, Milena Pavlović, Geir Kjetil Sandve, Sepp Hochreiter, Victor Greiff, Günter Klambauer
‘DeepRC: Immune repertoire classification with attention-based deep massive multiple instance learning’.
bioRxiv preprint doi: `https://doi.org/10.1101/2020.04.12.038158 <https://doi.org/10.1101/2020.04.12.038158>`_
Arguments:
validation_part (float): the part of the data that will be used for validation, the rest will be used for training.
add_positional_information (bool): whether positional information should be included in the input features.
kernel_size (int): the size of the 1D-CNN kernels.
n_kernels (int): the number of 1D-CNN kernels in each layer.
n_additional_convs (int): Number of additional 1D-CNN layers after first layer
n_attention_network_layers (int): Number of attention layers to compute keys
n_attention_network_units (int): Number of units in each attention layer
n_output_network_units (int): Number of units in the output layer
consider_seq_counts (bool): whether the input data should be scaled by the receptor sequence counts.
sequence_reduction_fraction (float): Fraction of number of sequences to which to reduce the number of sequences per bag based on attention weights. Has to be in range [0,1].
reduction_mb_size (int): Reduction of sequences per bag is performed using minibatches of `reduction_mb_size` sequences to compute the attention weights.
n_updates (int): Number of updates to train for
n_torch_threads (int): Number of parallel threads to allow PyTorch
learning_rate (float): Learning rate for adam optimizer
l1_weight_decay (float): l1 weight decay factor. l1 weight penalty will be added to loss, scaled by `l1_weight_decay`
l2_weight_decay (float): l2 weight decay factor. l2 weight penalty will be added to loss, scaled by `l2_weight_decay`
evaluate_at (int): Evaluate model on training and validation set every `evaluate_at` updates. This will also check for a new best model for early stopping.
sample_n_sequences (int): Optional random sub-sampling of `sample_n_sequences` sequences per repertoire. Number of sequences per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or random indices have been drawn multiple times. If None, all sequences will be loaded for each repertoire.
training_batch_size (int): Number of repertoires per minibatch during training.
n_workers (int): Number of background processes to use for converting dataset to hdf5 container and training set data loader.
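keep_dataset_in_ram (bool): whether the HDF5 dataset should be loaded into RAM as a dictionary instead of being read from the disk.
pytorch_device_name (str): The name of the pytorch device to use. This name will be passed to torch.device(self.pytorch_device_name). The default value is cuda:0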
YAML specification:
.. indent with spaces
.. code-block:: yaml
my_deeprc_method:
DeepRC:
validation_part: 0.2
add_positional_information: True
kernel_size: 9
"""
def __init__(self, validation_part, add_positional_information, kernel_size, n_kernels,
             n_additional_convs, n_attention_network_layers, n_attention_network_units, n_output_network_units,
             consider_seq_counts, sequence_reduction_fraction, reduction_mb_size, n_updates, n_torch_threads,
             learning_rate, l1_weight_decay, l2_weight_decay, evaluate_at, sample_n_sequences, training_batch_size, n_workers,
             keep_dataset_in_ram, pytorch_device_name):
    """
    Stores all hyperparameters on the instance and binds the DeepRC training function.

    The deeprc package is imported lazily here, so the module can be imported without it.

    :raises RuntimeError: if the deeprc package is not installed
    """
    super(DeepRC, self).__init__()

    if not ReflectionHandler.is_installed("deeprc"):
        install_docs_url = "https://docs.immuneml.uio.no/installation/install_with_package_manager.html"
        raise RuntimeError(f"{DeepRC.__name__}: deeprc module is not installed. Please check the documentation at "
                           f"{install_docs_url} for instructions how to install it.")

    from deeprc.deeprc_binary.training import train
    self.training_function = train

    # State that is only filled in while fitting (result_path is not assigned anywhere in this class).
    self.model = None
    self.result_path = None
    self.max_seq_len = None
    self.label = None
    self.label_classes = None
    self.label_is_bool = None
    self.feature_names = None

    # Runtime / device configuration.
    self.keep_dataset_in_ram = keep_dataset_in_ram
    self.pytorch_device_name = pytorch_device_name
    self.pytorch_device = torch.device(self.pytorch_device_name)

    # Train/validation split (a setting of this wrapper, not of the DeepRC code).
    self.validation_part = validation_part

    # Settings forwarded to the DeepRC network.
    self.add_positional_information = add_positional_information
    # presumably 20 amino-acid features plus 3 positional features when enabled -- TODO confirm
    self.n_input_features = 20 + 3 * self.add_positional_information
    self.kernel_size = kernel_size
    self.n_kernels = n_kernels
    self.n_additional_convs = n_additional_convs
    self.n_attention_network_layers = n_attention_network_layers
    self.n_attention_network_units = n_attention_network_units
    self.n_output_network_units = n_output_network_units
    self.consider_seq_counts = consider_seq_counts
    self.sequence_reduction_fraction = sequence_reduction_fraction
    self.reduction_mb_size = reduction_mb_size

    # Settings forwarded to the DeepRC training function.
    self.n_updates = n_updates
    self.evaluate_at = evaluate_at
    self.n_torch_threads = n_torch_threads
    self.learning_rate = learning_rate
    self.l1_weight_decay = l1_weight_decay
    self.l2_weight_decay = l2_weight_decay

    # Settings used when building the data loaders.
    self.sample_n_sequences = sample_n_sequences
    self.training_batch_size = training_batch_size
    self.n_workers = n_workers
def _metadata_to_hdf5(self, metadata_filepath: Path, label_name):
    """
    Converts the dataset described by the metadata file to an HDF5 container using DeepRC's DatasetToHDF5.

    :param metadata_filepath: path to the metadata file; the HDF5 file is written next to it with the same stem
    :param label_name: the single label column passed to the converter
    :return: the path of the written HDF5 file
    """
    from deeprc.deeprc_binary.dataset_converters import DatasetToHDF5

    hdf5_filepath = metadata_filepath.parent / f"{metadata_filepath.stem}.hdf5"

    converter_settings = {
        "metadata_file": str(metadata_filepath),
        "id_column": DeepRCEncoder.ID_COLUMN,
        "single_class_label_columns": (label_name,),
        "sequence_column": DeepRCEncoder.SEQUENCE_COLUMN,
        "sequence_counts_column": DeepRCEncoder.COUNTS_COLUMN,
        "column_sep": DeepRCEncoder.SEP,
        "filename_extension": f".{DeepRCEncoder.EXTENSION}",
        "verbose": False,
    }
    hdf5_converter = DatasetToHDF5(**converter_settings)
    hdf5_converter.save_data_to_file(output_file=str(hdf5_filepath), n_workers=self.n_workers)

    return hdf5_filepath
def _load_dataset_in_ram(self, hdf5_filepath: Path):
    """
    Reads the 'sampledata' group of the HDF5 file fully into memory.

    :param hdf5_filepath: path to the HDF5 file
    :return: a dictionary with the keys 'seq_lens', 'counts_per_sequence' and 'amino_acid_sequences'
    """
    dataset_names = ('seq_lens', 'counts_per_sequence', 'amino_acid_sequences')
    with h5py.File(str(hdf5_filepath), 'r') as hdf5_file:
        # [:] copies each dataset out of the file so it stays usable after the file is closed
        return {dataset_name: hdf5_file['sampledata'][dataset_name][:] for dataset_name in dataset_names}
def _get_train_val_indices(self, n_examples, classes):
    """
    Splits the example indices 0..n_examples-1 into a shuffled training and validation part.

    The split is stratified on `classes`, and `self.validation_part` is used as the validation fraction.

    :return: a tuple (train_indices, val_indices)
    """
    example_indices = np.arange(n_examples)
    split = train_test_split(example_indices, test_size=self.validation_part, shuffle=True, stratify=classes)
    train_indices, val_indices = split
    return train_indices, val_indices
def make_data_loader(self, hdf5_filepath: Path, pre_loaded_hdf5_file, indices, label, eval_only: bool, is_train: bool, n_workers=1):
    """
    Creates a pytorch dataloader using DeepRC's RepertoireDataReaderBinary

    :param hdf5_filepath: the path to the HDF5 file
    :param pre_loaded_hdf5_file: Optional: It is faster to load the hdf5 file into the RAM as dictionary instead
        of keeping it on the disk. `pre_loaded_hdf5_file` is the loaded hdf5 file as dictionary.
        If None, the hdf5 file will be read from the disk and consume less RAM.
    :param indices: indices of the subset of repertoires in the data that will be used for this dataset.
        If 'None', all repertoires will be used.
    :param label: the label to be predicted
    :param eval_only: whether the dataloader will only be used for evaluation (no training).
        If true, sample_n_sequences is ignored; if false, self.sample_n_sequences is used.
    :param is_train: whether this is a dataloader for training data. If true, self.training_batch_size is used
        and the examples are shuffled; otherwise the batch size is 1 and the original order is kept.
    :param n_workers: the number of workers used in torch.utils.data.DataLoader
    :return: a Pytorch dataloader
    """
    from deeprc.deeprc_binary.dataset_readers import RepertoireDataReaderBinary
    from deeprc.deeprc_binary.dataset_readers import no_stack_collate_fn

    sample_n_sequences = None if eval_only else self.sample_n_sequences
    training_batch_size = self.training_batch_size if is_train else 1

    dataset = RepertoireDataReaderBinary(
        hdf5_filepath=str(hdf5_filepath), set_inds=indices,
        sample_n_sequences=sample_n_sequences, target_label=label,
        true_class_label_value=self.label_classes[0],
        pre_loaded_hdf5_file=pre_loaded_hdf5_file,
        verbose=False)

    # Only shuffle training data. Validation/test loaders must keep the dataset order, because
    # predict_proba returns the predictions in iteration order without any example ids attached;
    # shuffling there would misalign the predictions with the repertoires.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=training_batch_size,
                                             shuffle=is_train,
                                             num_workers=n_workers,
                                             collate_fn=no_stack_collate_fn)

    return dataloader
def _set_label_classes(self, y):
self.label = list(y.keys())[0]
label_classes_raw = {label: set(classes) for label, classes in y.items()}
self.label_is_bool = {label: label_classes_raw[label] == {True, False} for label in y.keys()}
label_classes = {label: sorted([str(class_name) for class_name in classes]) for label, classes in label_classes_raw.items()}
for label in label_classes.keys():
n_classes = len(label_classes[label])
assert n_classes == 2, f"DeepRC: this method assumes there are 2 possible classes per label, " \
f"for label '{label}' {n_classes} classes were found: {label_classes[label]}"
# If a possible label class is False, make sure it is the second class (so True is the first class)
# to prevent error in DeepRC RepertoireDataReaderBinary.__init__()
if label_classes[label][0] in ("False", "false"):
label_classes[label] = label_classes[label][::-1]
self.label_classes = label_classes[self.label]
def _prepare_caching_params(self, encoded_data: EncodedData, type: str, label_name: str):
return (("metadata_filepath", str(encoded_data.info["metadata_filepath"])),
("y", hashlib.sha256(str(encoded_data.labels[label_name]).encode("utf-8")).hexdigest()),
("label_name", label_name),
("type", type),
("validation_part", self.validation_part),
("add_positional_information", self.add_positional_information),
("n_input_features", self.n_input_features),
("kernel_size", self.kernel_size),
("n_kernels", self.n_kernels),
("n_additional_convs", self.n_additional_convs),
("n_attention_network_layers", self.n_attention_network_layers),
("n_attention_network_units", self.n_attention_network_units),
("n_output_network_units", self.n_output_network_units),
("consider_seq_counts", self.consider_seq_counts),
("sequence_reduction_fraction", self.sequence_reduction_fraction),
("reduction_mb_size", self.reduction_mb_size),
("n_updates", self.n_updates),
("n_torch_threads", self.n_torch_threads),
("learning_rate", self.learning_rate),
("l1_weight_decay", self.l1_weight_decay),
("l2_weight_decay", self.l2_weight_decay),
("sample_n_sequences", self.sample_n_sequences),
("training_batch_size", self.training_batch_size),
("n_workers", self.n_workers),
("evaluate_at", self.evaluate_at),
("pytorch_device_name", self.pytorch_device_name))
def fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):
    """
    Fits the model for the given label; the result is memoized through CacheHandler on the caching parameters.

    :param encoded_data: data encoded with the DeepRC encoder
    :param label_name: the label to train for
    :param cores_for_training: passed through to the internal fitting method
    :raises AssertionError: if the data was not encoded by the DeepRC encoder
    """
    encoding_name = encoded_data.encoding
    assert encoding_name == "DeepRCEncoder", f"DeepRC: ML method DeepRC is only compatible with the DeepRC encoder, found {encoding_name.replace('Encoder','')} encoder"

    self.feature_names = encoded_data.feature_names
    self._set_label_classes({label_name: encoded_data.labels[label_name]})

    def train_model():
        return self._fit(encoded_data, label_name, cores_for_training)

    caching_params = self._prepare_caching_params(encoded_data, "fit", label_name)
    self.model = CacheHandler.memo_by_params(caching_params, train_model)
def _fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):
hdf5_filepath = self._metadata_to_hdf5(encoded_data.info["metadata_filepath"], label_name)
pre_loaded_hdf5_file = self._load_dataset_in_ram(hdf5_filepath) if self.keep_dataset_in_ram else None
train_indices, val_indices = self._get_train_val_indices(len(encoded_data.example_ids), encoded_data.labels[label_name])
self.max_seq_len = encoded_data.info["max_sequence_length"]
self._fit_for_label(hdf5_filepath, pre_loaded_hdf5_file, train_indices, val_indices, label_name, cores_for_training)
self.label = label_name
return self.model
def _fit_for_label(self, hdf5_filepath: Path, pre_loaded_hdf5_file, train_indices, val_indices, label: str, cores_for_training: int):
    """
    Builds the data loaders and the DeepRC network, then runs the DeepRC training function on them.

    The network is stored in self.model; logs are written to self.result_path / "deeprc_log".
    `cores_for_training` is accepted but not used in this method.
    """
    from deeprc.deeprc_binary.architectures import DeepRC as DeepRCInternal

    # One loader for the parameter updates and two evaluation loaders (training and validation part).
    train_dataloader = self.make_data_loader(hdf5_filepath, pre_loaded_hdf5_file, train_indices, label,
                                             eval_only=False, is_train=True, n_workers=self.n_workers)
    train_eval_dataloader = self.make_data_loader(hdf5_filepath, pre_loaded_hdf5_file, train_indices, label,
                                                  eval_only=True, is_train=True)
    val_eval_dataloader = self.make_data_loader(hdf5_filepath, pre_loaded_hdf5_file, val_indices, label,
                                                eval_only=True, is_train=False)

    network_settings = {
        "n_input_features": self.n_input_features,
        "n_output_features": 1,
        "max_seq_len": self.max_seq_len,
        "kernel_size": self.kernel_size,
        "consider_seq_counts": self.consider_seq_counts,
        "n_kernels": self.n_kernels,
        "n_additional_convs": self.n_additional_convs,
        "n_attention_network_layers": self.n_attention_network_layers,
        "n_attention_network_units": self.n_attention_network_units,
        "n_output_network_layers": 0,
        "n_output_network_units": self.n_output_network_units,
        "add_positional_information": self.add_positional_information,
        "sequence_reduction_fraction": self.sequence_reduction_fraction,
        "reduction_mb_size": self.reduction_mb_size,
        "device": self.pytorch_device,
    }
    self.model = DeepRCInternal(**network_settings)

    training_settings = {
        "trainingset_dataloader": train_dataloader,
        "trainingset_eval_dataloader": train_eval_dataloader,
        "validationset_eval_dataloader": val_eval_dataloader,
        "results_directory": self.result_path / "deeprc_log",
        "n_updates": self.n_updates,
        "num_torch_threads": self.n_torch_threads,
        "learning_rate": self.learning_rate,
        "l1_weight_decay": self.l1_weight_decay,
        "l2_weight_decay": self.l2_weight_decay,
        "show_progress": False,
        "device": self.pytorch_device,
        "evaluate_at": self.evaluate_at,
    }
    self.training_function(self.model, **training_settings)
def fit_by_cross_validation(self, encoded_data: EncodedData, number_of_splits: int = 5, label_name: str = None, cores_for_training: int = -1,
                            optimization_metric=None):
    """
    Cross-validation is not defined for this classifier: a warning is issued and a single model is fitted.

    Only `encoded_data` and `label_name` are forwarded to fit(); the other arguments are ignored.
    """
    not_supported_message = "DeepRC: cross-validation on this classifier is not defined: fitting one model instead..."
    warnings.warn(not_supported_message)
    self.fit(encoded_data, label_name)
def get_model(self):
    """Returns the underlying model object (None until a model has been fitted or loaded)."""
    fitted_model = self.model
    return fitted_model
def get_params(self):
    """Returns a dictionary mapping each named parameter of the model to its values as nested lists."""
    params_as_lists = {}
    for param_name, parameter in self.model.named_parameters():
        params_as_lists[param_name] = parameter.data.tolist()
    return params_as_lists
def check_is_fitted(self, label_name: str):
    """
    Checks that this instance was fitted for the given label.

    :raises NotFittedError: if `label_name` differs from the label this instance was fitted for
    """
    if label_name == self.label:
        return

    raise NotFittedError("This DeepRCs instance is not fitted yet. "
                         "Call 'fit' with appropriate arguments before using this method.")
def predict(self, encoded_data: EncodedData, label_name: str):
    """
    Predicts a class for each example by thresholding the probability of the first class at 0.5.

    :return: a dictionary {label_name: list of predicted classes}; the classes are booleans if the label
        was boolean during fitting, otherwise the class names as strings
    """
    probabilities = self.predict_proba(encoded_data, label_name)
    classes = self.get_classes()

    predicted_classes = []
    # column 0 holds the probability of the first class in `classes`
    for probability in probabilities[label_name][:, 0]:
        if probability > 0.5:
            predicted_classes.append(classes[0])
        else:
            predicted_classes.append(classes[1])

    if self.label_is_bool[label_name]:
        # class names are stored as strings; convert back to booleans
        predicted_classes = [predicted_class == "True" for predicted_class in predicted_classes]

    return {label_name: predicted_classes}
def predict_proba(self, encoded_data: EncodedData, label_name: str):
    """
    Computes the class probabilities for all examples in the encoded data.

    :return: a dictionary {label_name: array}; after transposing, the first entry along the last axis is
        the model output for the first (positive) class and the second is its complement
    :raises NotFittedError: if the model was not fitted for `label_name`
    """
    self.check_is_fitted(label_name)

    hdf5_filepath = self._metadata_to_hdf5(encoded_data.info["metadata_filepath"], label_name)
    if self.keep_dataset_in_ram:
        pre_loaded_hdf5_file = self._load_dataset_in_ram(hdf5_filepath)
    else:
        pre_loaded_hdf5_file = None

    test_dataloader = self.make_data_loader(hdf5_filepath, pre_loaded_hdf5_file, indices=None, label=label_name,
                                            eval_only=True, is_train=False)
    probs_pos_class = self._model_predict(self.model, test_dataloader)
    probs_neg_class = 1 - probs_pos_class

    return {label_name: np.vstack((probs_pos_class, probs_neg_class)).T}
def _model_predict(self, model, dataloader):
    """
    Runs the model on every minibatch of the dataloader and returns the sigmoid outputs as a numpy array.

    Based on the DeepRC function evaluate (deeprc.deeprc_binary.training.evaluate)
    """
    with torch.no_grad():
        model.to(device=self.pytorch_device)

        batch_predictions = []
        # the progress bar is disabled, tqdm is only used as a pass-through iterator here
        batches = tqdm(dataloader, total=len(dataloader), desc="Evaluating model", disable=True, position=1)
        for labels, inputs, sequence_lengths, counts_per_sequence, sample_ids in batches:
            # Apply attention-based sequence reduction and create minibatch
            reduced_batch = model.reduce_and_stack_minibatch(labels, inputs, sequence_lengths, counts_per_sequence)
            labels, inputs, sequence_lengths, n_sequences = reduced_batch

            # Compute predictions from reduced sequences
            batch_predictions.append(torch.sigmoid(model(inputs, n_sequences)))

        # concatenate the per-minibatch predictions and move them to the CPU
        all_predictions = torch.cat(batch_predictions, dim=0).float().cpu().numpy()

    return all_predictions
def load(self, path: Path):
    """
    Loads a previously stored model from the given directory and puts it in evaluation mode.

    :param path: the directory that contains the stored model file (as written by store())
    :raises FileNotFoundError: if the model file does not exist in `path`
    """
    name = FilenameHandler.get_filename(self.__class__.__name__, "pt")
    file_path = path / name

    if not file_path.is_file():
        raise FileNotFoundError(f"{self.__class__.__name__} model could not be loaded from {file_path}. "
                                f"Check if the path to the {name} file is properly set.")

    # NOTE(security): torch.load unpickles the file, only load model files from trusted sources.
    # map_location places the tensors on the configured device, so a model that was stored from
    # another device (e.g. a different GPU index) can still be loaded.
    self.model = torch.load(str(file_path), map_location=self.pytorch_device)
    self.model.eval()
def store(self, path, feature_names=None, details_path: Path = None):
    """
    Stores the model with torch.save and writes a YAML description of it.

    :param path: the directory where the model file is stored (built if necessary via PathBuilder)
    :param feature_names: stored in the YAML description under 'feature_names'
    :param details_path: the YAML file to write; if None, a default YAML file name inside `path` is used
    """
    PathBuilder.build(path)

    class_name = self.__class__.__name__
    model_filename = FilenameHandler.get_filename(class_name, "pt")
    torch.save(self.model, str(path / model_filename))

    params_path = details_path if details_path is not None else path / FilenameHandler.get_filename(class_name, "yaml")

    with params_path.open("w") as file:
        # model parameters first, so that 'feature_names' and 'classes' take precedence on a name clash
        description = dict(self.get_params())
        description["feature_names"] = feature_names
        description["classes"] = self.get_classes()
        yaml.dump(description, file)
def check_if_exists(self, path):
    """Returns True if the stored model file for this class exists in the directory `path`."""
    model_filename = FilenameHandler.get_filename(self.__class__.__name__, "pt")
    return (path / model_filename).is_file()
def get_classes(self) -> list:
    """Returns the class names of the fitted label (None if they have not been set yet)."""
    label_classes = self.label_classes
    return label_classes
def get_package_info(self) -> str:
    """Returns a string with the immuneML version and the version of the installed DeepRC distribution."""
    immuneml_version = Util.get_immuneML_version()
    deeprc_version = pkg_resources.get_distribution('DeepRC').version
    return "".join(('immuneML ', immuneml_version, '; deepRC ', deeprc_version))
def get_feature_names(self) -> list:
    """Returns the feature names taken from the encoded data in fit() (None before fitting)."""
    feature_names = self.feature_names
    return feature_names
def can_predict_proba(self) -> bool:
    """Returns True: this method always supports predicting class probabilities."""
    supports_probabilities = True
    return supports_probabilities
def get_class_mapping(self) -> dict:
    """Returns an empty dictionary: no class mapping is kept by this method."""
    return dict()
def get_label(self) -> str:
    """Returns the name of the label this instance was fitted for (None before fitting)."""
    label_name = self.label
    return label_name