Source code for immuneML.ml_methods.classifiers.DeepRC

import hashlib
import warnings
from pathlib import Path

import numpy as np
import pkg_resources
import torch
import yaml
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from immuneML.caching.CacheHandler import CacheHandler
from immuneML.data_model.EncodedData import EncodedData
from immuneML.encodings.deeprc.DeepRCEncoder import DeepRCEncoder
from immuneML.environment.Label import Label
from immuneML.ml_methods.classifiers.MLMethod import MLMethod
from immuneML.ml_methods.util.Util import Util
from immuneML.util.FilenameHandler import FilenameHandler
from immuneML.util.PathBuilder import PathBuilder
from immuneML.util.ReflectionHandler import ReflectionHandler


class DeepRC(MLMethod):
    """
    This classifier uses the DeepRC method for repertoire classification. The DeepRC ML method should be used in combination
    with the DeepRC encoder. Also consider using the :ref:`DeepRCMotifDiscovery` report for interpretability.

    Notes:

    - DeepRC uses PyTorch functionalities that depend on GPU. Therefore, DeepRC does not work on a CPU.

    - This wrapper around DeepRC currently only supports binary classification.

    Reference:
    Michael Widrich, Bernhard Schäfl, Milena Pavlović, Geir Kjetil Sandve, Sepp Hochreiter, Victor Greiff, Günter Klambauer
    ‘DeepRC: Immune repertoire classification with attention-based deep massive multiple instance learning’.
    bioRxiv preprint doi: `https://doi.org/10.1101/2020.04.12.038158 <https://doi.org/10.1101/2020.04.12.038158>`_

    **Specification arguments:**

    - validation_part (float): the part of the data that will be used for validation, the rest will be used for training.

    - add_positional_information (bool): whether positional information should be included in the input features.

    - kernel_size (int): the size of the 1D-CNN kernels.

    - n_kernels (int): the number of 1D-CNN kernels in each layer.

    - n_additional_convs (int): Number of additional 1D-CNN layers after first layer

    - n_attention_network_layers (int): Number of attention layers to compute keys

    - n_attention_network_units (int): Number of units in each attention layer

    - n_output_network_units (int): Number of units in the output layer

    - consider_seq_counts (bool): whether the input data should be scaled by the receptor sequence counts.

    - sequence_reduction_fraction (float): Fraction of number of sequences to which to reduce the number of sequences per bag based
      on attention weights. Has to be in range [0,1].

    - reduction_mb_size (int): Reduction of sequences per bag is performed using minibatches of `reduction_mb_size` sequences to
      compute the attention weights.

    - n_updates (int): Number of updates to train for

    - n_torch_threads (int): Number of parallel threads to allow PyTorch

    - learning_rate (float): Learning rate for adam optimizer

    - l1_weight_decay (float): l1 weight decay factor. l1 weight penalty will be added to loss, scaled by `l1_weight_decay`

    - l2_weight_decay (float): l2 weight decay factor. l2 weight penalty will be added to loss, scaled by `l2_weight_decay`

    - sequence_counts_scaling_fn: it can either be `log` (logarithmic scaling of sequence counts) or None

    - evaluate_at (int): Evaluate model on training and validation set every `evaluate_at` updates. This will also check for a new
      best model for early stopping.

    - sample_n_sequences (int): Optional random sub-sampling of `sample_n_sequences` sequences per repertoire. Number of sequences
      per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or random indices have been drawn multiple
      times. If None, all sequences will be loaded for each repertoire.

    - training_batch_size (int): Number of repertoires per minibatch during training.

    - n_workers (int): Number of background processes to use for converting dataset to hdf5 container and training set data loader.

    - keep_dataset_in_ram (bool): forwarded as `keep_in_ram` to DeepRC's RepertoireDataset when the datasets are created.

    - pytorch_device_name (str): The name of the pytorch device to use. This name will be passed to
      torch.device(self.pytorch_device_name). The default value is cuda:0

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            ml_methods:
                my_deeprc_method:
                    DeepRC:
                        validation_part: 0.2
                        add_positional_information: True
                        kernel_size: 9

    """

    def __init__(self, validation_part, add_positional_information, kernel_size, n_kernels, n_additional_convs,
                 n_attention_network_layers, n_attention_network_units, n_output_network_units, consider_seq_counts,
                 sequence_reduction_fraction, reduction_mb_size, n_updates, n_torch_threads, learning_rate, l1_weight_decay,
                 l2_weight_decay, evaluate_at, sample_n_sequences, training_batch_size, n_workers, sequence_counts_scaling_fn,
                 keep_dataset_in_ram, pytorch_device_name):
        """Stores all hyperparameters; raises RuntimeError if the optional `deeprc` package is not installed."""
        super().__init__()

        if not ReflectionHandler.is_installed("deeprc"):
            raise RuntimeError(f"{DeepRC.__name__}: deeprc module is not installed. Please check the documentation at "
                               f"https://docs.immuneml.uio.no/installation/install_with_package_manager.html for "
                               f"instructions how to install it.")

        # deeprc is an optional dependency, so it is imported lazily after the installation check
        from deeprc.training import train
        self.training_function = train

        self.model = None
        self.max_seq_len = None

        self.keep_dataset_in_ram = keep_dataset_in_ram
        self.pytorch_device_name = pytorch_device_name
        self.pytorch_device = torch.device(self.pytorch_device_name)

        # ML model setting (not inherited from DeepRC code)
        self.validation_part = validation_part

        # DeepRC class settings:
        self.add_positional_information = add_positional_information
        self.n_input_features = 20
        self.kernel_size = kernel_size
        self.n_kernels = n_kernels
        self.n_additional_convs = n_additional_convs
        self.n_attention_network_layers = n_attention_network_layers
        self.n_attention_network_units = n_attention_network_units
        self.n_output_network_units = n_output_network_units
        self.consider_seq_counts = consider_seq_counts
        self.sequence_reduction_fraction = sequence_reduction_fraction
        self.reduction_mb_size = reduction_mb_size

        # any value other than the string "log" selects the no-scaling function
        from deeprc.dataset_readers import log_sequence_count_scaling, no_sequence_count_scaling
        self.sequence_counts_scaling_fn = (log_sequence_count_scaling if sequence_counts_scaling_fn == "log"
                                           else no_sequence_count_scaling)

        # train function settings:
        self.evaluate_at = evaluate_at
        self.n_updates = n_updates
        self.n_torch_threads = n_torch_threads
        self.learning_rate = learning_rate
        self.l1_weight_decay = l1_weight_decay
        self.l2_weight_decay = l2_weight_decay

        # Dataloader related settings:
        self.sample_n_sequences = sample_n_sequences
        self.training_batch_size = training_batch_size
        self.n_workers = n_workers

    def _metadata_to_hdf5(self, metadata_filepath: Path, label_name: str):
        """
        Converts the repertoire files located next to the metadata file into one HDF5 container using DeepRC's DatasetToHDF5.

        :param metadata_filepath: path to the metadata file; its parent directory is used as the repertoire data directory
        :param label_name: not used by this method
        :return: path to the HDF5 file (same directory and stem as the metadata file, with the .hdf5 extension)
        """
        from deeprc.dataset_converters import DatasetToHDF5

        hdf5_filepath = metadata_filepath.parent / f"{metadata_filepath.stem}.hdf5"
        converter = DatasetToHDF5(repertoiresdata_directory=str(metadata_filepath.parent),
                                  sequence_column=DeepRCEncoder.SEQUENCE_COLUMN,
                                  sequence_counts_column=DeepRCEncoder.COUNTS_COLUMN,
                                  column_sep=DeepRCEncoder.SEP,
                                  filename_extension=f".{DeepRCEncoder.EXTENSION}",
                                  verbose=False)
        converter.save_data_to_file(output_file=str(hdf5_filepath), n_workers=self.n_workers)

        return hdf5_filepath

    def _load_dataset_in_ram(self, hdf5_filepath: Path):
        """Reads the 'sampledata' arrays of the HDF5 file fully into memory and returns them as a dictionary."""
        import h5py

        with h5py.File(str(hdf5_filepath), 'r') as hf:
            sample_data = hf['sampledata']
            # `[:]` copies the data out of the file so that it stays available after the file is closed
            pre_loaded_hdf5_file = {'seq_lens': sample_data['seq_lens'][:],
                                    'counts_per_sequence': sample_data['counts_per_sequence'][:],
                                    'amino_acid_sequences': sample_data['amino_acid_sequences'][:]}

        return pre_loaded_hdf5_file

    def _get_train_val_indices(self, n_examples, classes):
        """
        Splits the data to training and validation and attempts to preserve the class distribution if possible.

        If the stratified split fails with a ValueError, a warning is issued and a non-stratified random split is made instead.

        :param n_examples: the number of examples to split
        :param classes: class assignment per example, used for stratification
        :return: a tuple (train_indices, val_indices)
        """
        indices = np.arange(0, n_examples)

        try:
            train_indices, val_indices = train_test_split(indices, test_size=self.validation_part, shuffle=True,
                                                          stratify=classes)
        except ValueError as error:
            warnings.warn(f"{DeepRC.__name__}: could not stratify the training/validation split by class ({error}); "
                          f"falling back to a non-stratified random split.")
            train_indices, val_indices = train_test_split(indices, test_size=self.validation_part, shuffle=True)

        return train_indices, val_indices

    def make_data_loader(self, full_dataset, indices, label_name, eval_only: bool, is_train: bool, n_workers=1):
        """
        Creates a pytorch dataloader over a subset of a DeepRC RepertoireDataset

        :param full_dataset: the DeepRC RepertoireDataset to take the subset from
        :param indices: indices of the subset of repertoires in the data that will be used for this dataset.
            If 'None', all repertoires will be used.
        :param label_name: the name of the label to be predicted (not used by this method)
        :param eval_only: whether the dataloader will only be used for evaluation (no training).
            If false, sample_n_sequences is applied and the data is shuffled; if true, the data is not subsampled and
            the order of the repertoires is preserved.
        :param is_train: whether this is a dataloader for training data. If true, self.training_batch_size is used,
            otherwise the batch size is 1.
        :param n_workers: the number of workers used in torch.utils.data.DataLoader
        :return: a Pytorch dataloader
        """
        from deeprc.dataset_readers import no_stack_collate_fn
        from deeprc.dataset_readers import RepertoireDatasetSubset as DeepRCRepDatasetSubset

        sample_n_sequences = None if eval_only else self.sample_n_sequences
        batch_size = self.training_batch_size if is_train else 1

        dataset = DeepRCRepDatasetSubset(dataset=full_dataset, indices=indices, sample_n_sequences=sample_n_sequences)

        # Only the training loader is shuffled: evaluation/prediction output has to stay in dataset order.
        # A multiprocessing context may only be passed to DataLoader when worker processes are used.
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=not eval_only, num_workers=n_workers,
                                 collate_fn=no_stack_collate_fn,
                                 multiprocessing_context='spawn' if n_workers > 0 else None)

        return data_loader

    def _prepare_caching_params(self, encoded_data: EncodedData, type: str, label_name: str):
        """
        Builds the tuple of (name, value) pairs that identifies a fitted model in the cache.

        :param encoded_data: encoded data; its metadata file path and a hash of the label values are part of the key
        :param type: a string distinguishing the kind of cached computation (e.g. "fit")
        :param label_name: the name of the label the model is fitted for
        :return: a tuple of (name, value) tuples
        """
        scaling_fn = self.sequence_counts_scaling_fn

        return (("metadata_filepath", str(encoded_data.info["metadata_filepath"])),
                ("y", hashlib.sha256(str(encoded_data.labels[label_name]).encode("utf-8")).hexdigest()),
                ("label_name", label_name),
                ("type", type),
                ("validation_part", self.validation_part),
                ("add_positional_information", self.add_positional_information),
                ("n_input_features", self.n_input_features),
                ("kernel_size", self.kernel_size),
                ("n_kernels", self.n_kernels),
                ("n_additional_convs", self.n_additional_convs),
                ("n_attention_network_layers", self.n_attention_network_layers),
                ("n_attention_network_units", self.n_attention_network_units),
                ("n_output_network_units", self.n_output_network_units),
                ("consider_seq_counts", self.consider_seq_counts),
                ("sequence_reduction_fraction", self.sequence_reduction_fraction),
                ("reduction_mb_size", self.reduction_mb_size),
                ("n_updates", self.n_updates),
                ("n_torch_threads", self.n_torch_threads),
                ("learning_rate", self.learning_rate),
                ("l1_weight_decay", self.l1_weight_decay),
                ("l2_weight_decay", self.l2_weight_decay),
                ("sample_n_sequences", self.sample_n_sequences),
                ("training_batch_size", self.training_batch_size),
                ("n_workers", self.n_workers),
                ("evaluate_at", self.evaluate_at),
                ("pytorch_device_name", self.pytorch_device_name),
                # the scaling function changes the model input, so it has to be part of the cache key;
                # the function name is used because it is a stable string
                ("sequence_counts_scaling_fn", getattr(scaling_fn, "__name__", str(scaling_fn))))

    def _fit(self, encoded_data: EncodedData, cores_for_training: int = 2):
        """Fits the model for self.label, reusing a cached model if one exists for the same data and parameters."""
        self.model = CacheHandler.memo_by_params(self._prepare_caching_params(encoded_data, "fit", self.label.name),
                                                 lambda: self._fit_model(encoded_data, self.label, cores_for_training))

    def _fit_model(self, encoded_data: EncodedData, label: Label, cores_for_training: int = 2):
        """
        Converts the encoded data to HDF5, splits it into training and validation part and trains the model.

        :return: the trained model (also stored in self.model)
        """
        hdf5_filepath, pre_loaded_hdf5_file = self._convert_dataset_to_hdf5(encoded_data, label)

        train_indices, val_indices = self._get_train_val_indices(len(encoded_data.example_ids),
                                                                 encoded_data.labels[label.name])
        self.max_seq_len = encoded_data.info["max_sequence_length"]

        self._fit_for_label(encoded_data.info["metadata_filepath"], hdf5_filepath, train_indices, val_indices, label,
                            cores_for_training)

        return self.model

    def _convert_dataset_to_hdf5(self, encoded_data, label):
        """
        Creates the HDF5 container for the encoded data.

        :return: a tuple (hdf5_filepath, pre_loaded_hdf5_file); the second element is currently always None
        """
        hdf5_filepath = self._metadata_to_hdf5(encoded_data.info["metadata_filepath"], label.name)
        pre_loaded_hdf5_file = None  # self._load_dataset_in_ram(hdf5_filepath) if self.keep_dataset_in_ram else None
        return hdf5_filepath, pre_loaded_hdf5_file

    def _make_model(self, task_definition):
        """Assembles the DeepRC network (sequence embedding CNN, attention and output network) on the configured device."""
        from deeprc.architectures import DeepRC as DeepRCInternal
        from deeprc.architectures import SequenceEmbeddingCNN
        from deeprc.architectures import AttentionNetwork
        from deeprc.architectures import OutputNetwork

        # Create sequence embedding network (for CNN, kernel_size and n_kernels are important hyperparameters)
        # add_positional_information is a bool, so 3 extra input features are added when it is True
        # NOTE(review): n_layers is fixed to 1 here, self.n_additional_convs is not used - confirm this is intended
        sequence_embedding_network = SequenceEmbeddingCNN(
            n_input_features=self.n_input_features + 3 * self.add_positional_information,
            kernel_size=self.kernel_size, n_kernels=self.n_kernels, n_layers=1)

        # Create attention network
        attention_network = AttentionNetwork(n_input_features=self.n_kernels,
                                             n_layers=self.n_attention_network_layers,
                                             n_units=self.n_attention_network_units)

        # Create output network
        output_network = OutputNetwork(n_input_features=self.n_kernels,
                                       n_output_features=task_definition.get_n_output_features(),
                                       n_layers=1, n_units=self.n_output_network_units)

        # Combine networks to DeepRC network
        # NOTE(review): consider_seq_counts is hard-coded to False, self.consider_seq_counts is not used - confirm
        return DeepRCInternal(max_seq_len=self.max_seq_len,
                              sequence_embedding_network=sequence_embedding_network,
                              attention_network=attention_network,
                              output_network=output_network,
                              consider_seq_counts=False,
                              n_input_features=self.n_input_features,
                              add_positional_information=self.add_positional_information,
                              sequence_reduction_fraction=self.sequence_reduction_fraction,
                              reduction_mb_size=self.reduction_mb_size,
                              device=self.pytorch_device).to(device=self.pytorch_device)

    def _make_task_definition(self, label):
        """Creates a DeepRC TaskDefinition with a binary target for two label values and a multiclass target otherwise."""
        from deeprc.task_definitions import TaskDefinition, BinaryTarget, MulticlassTarget

        if len(label.values) == 2:
            target = BinaryTarget(column_name=label.name, true_class_value=label.positive_class)
        else:
            target = MulticlassTarget(column_name=label.name, possible_target_values=label.values)

        return TaskDefinition(targets=[target])

    def _fit_for_label(self, metadata_file, hdf5_filepath: Path, train_indices, val_indices, label: Label,
                       cores_for_training: int):
        """
        Builds the dataset and the data loaders, creates the model and runs DeepRC's training function.

        :param metadata_file: path to the metadata file of the encoded dataset
        :param hdf5_filepath: path to the HDF5 container with the repertoire data
        :param train_indices: indices of the repertoires used for training
        :param val_indices: indices of the repertoires used for validation
        :param label: the label to fit the model for
        :param cores_for_training: not used by this method
        """
        from deeprc.dataset_readers import RepertoireDataset as DeepRCRepDataset

        task_definition = self._make_task_definition(label)

        # NOTE(review): the metadata separator is "," here but DeepRCEncoder.SEP in _predict_proba - confirm
        full_dataset = DeepRCRepDataset(metadata_filepath=metadata_file, hdf5_filepath=str(hdf5_filepath),
                                        sample_id_column=DeepRCEncoder.ID_COLUMN,
                                        metadata_file_column_sep=",", task_definition=task_definition,
                                        keep_in_ram=self.keep_dataset_in_ram, inputformat='NCL',
                                        sequence_counts_scaling_fn=self.sequence_counts_scaling_fn)

        train_dataloader = self.make_data_loader(full_dataset, train_indices, label.name, eval_only=False, is_train=True,
                                                 n_workers=self.n_workers)
        train_eval_dataloader = self.make_data_loader(full_dataset, train_indices, eval_only=True, is_train=True,
                                                      n_workers=1, label_name=label.name)
        val_dataloader = self.make_data_loader(full_dataset, val_indices, eval_only=True, is_train=False,
                                               label_name=label.name, n_workers=1)

        self.model = self._make_model(task_definition)

        self.training_function(self.model, trainingset_dataloader=train_dataloader,
                               trainingset_eval_dataloader=train_eval_dataloader,
                               validationset_eval_dataloader=val_dataloader,
                               results_directory=self.result_path / "deep_rc_log",
                               n_updates=self.n_updates, num_torch_threads=self.n_torch_threads,
                               learning_rate=self.learning_rate,
                               l1_weight_decay=self.l1_weight_decay, l2_weight_decay=self.l2_weight_decay,
                               show_progress=False, device=self.pytorch_device, evaluate_at=self.evaluate_at,
                               task_definition=task_definition, early_stopping_target_id=label.name)

    def get_params(self):
        """Returns the named parameters of the model as a dictionary of (nested) lists."""
        return {name: param.data.tolist() for name, param in self.model.named_parameters()}

    def check_is_fitted(self, label_name: str):
        """
        Checks that a model is available and that it was fitted for the given label.

        :param label_name: the name of the label predictions are requested for
        :raises NotFittedError: if there is no model, no label, or the label name differs from the fitted label
        """
        label = getattr(self, "label", None)

        if self.model is None or label is None or label_name != label.name:
            raise NotFittedError("This DeepRCs instance is not fitted yet. "
                                 "Call 'fit' with appropriate arguments before using this method.")

    def _predict(self, encoded_data: EncodedData):
        """Predicts the positive class where its probability is above 0.5 and the negative class otherwise."""
        probabilities = self._predict_proba(encoded_data)
        pos_class_probs = probabilities[self.label.name][self.label.positive_class]
        negative_class = self.label.get_binary_negative_class()

        return {self.label.name: [self.label.positive_class if probability > 0.5 else negative_class
                                  for probability in pos_class_probs]}

    def _predict_proba(self, encoded_data: EncodedData):
        """
        Computes class probabilities for all repertoires in the encoded data.

        :return: a dictionary {label name: {positive class: probabilities, negative class: 1 - probabilities}}
        :raises NotFittedError: if the model was not fitted for self.label
        """
        from deeprc.dataset_readers import RepertoireDataset as DeepRCRepDataset

        self.check_is_fitted(self.label.name)

        hdf5_filepath, _ = self._convert_dataset_to_hdf5(encoded_data, self.label)
        task_definition = self._make_task_definition(self.label)

        test_dataset = DeepRCRepDataset(metadata_filepath=encoded_data.info['metadata_filepath'],
                                        hdf5_filepath=str(hdf5_filepath),
                                        sample_id_column=DeepRCEncoder.ID_COLUMN,
                                        metadata_file_column_sep=DeepRCEncoder.SEP,
                                        task_definition=task_definition, keep_in_ram=self.keep_dataset_in_ram,
                                        inputformat='NCL', sequence_counts_scaling_fn=self.sequence_counts_scaling_fn)

        test_dataloader = self.make_data_loader(test_dataset, indices=None, label_name=self.label.name, eval_only=True,
                                                is_train=False)

        probs_pos_class = self._model_predict(self.model, test_dataloader)

        # TODO: update for multiclass
        return {self.label.name: {self.label.positive_class: probs_pos_class,
                                  self.label.get_binary_negative_class(): 1 - probs_pos_class}}

    def _model_predict(self, model, dataloader):
        """Based on the DeepRC function evaluate (deeprc.training.evaluate)"""
        from tqdm import tqdm

        with torch.no_grad():
            model.to(device=self.pytorch_device)
            scoring_predictions = []

            for scoring_data in tqdm(dataloader, total=len(dataloader), desc="Evaluating model", disable=True,
                                     position=1):
                # Get samples as lists
                labels, inputs, sequence_lengths, counts_per_sequence, sample_ids = scoring_data

                # Apply attention-based sequence reduction and create minibatch
                labels, inputs, sequence_lengths, n_sequences = model.reduce_and_stack_minibatch(
                    labels, inputs, sequence_lengths, counts_per_sequence)

                # Compute predictions from reduced sequences
                logit_outputs = model(inputs, n_sequences)
                prediction = torch.sigmoid(logit_outputs)
                scoring_predictions.append(prediction)

            # predictions
            scoring_predictions = torch.cat(scoring_predictions, dim=0).float().cpu().numpy()

        return scoring_predictions

    def load(self, path: Path):
        """
        Loads the model (.pt file) and, if present, the label, feature names and classes (.yaml file) from the given folder.

        :param path: the folder the model was stored to
        :raises FileNotFoundError: if the model file does not exist in the folder
        """
        name = FilenameHandler.get_filename(self.__class__.__name__, "pt")
        file_path = path / name

        if not file_path.is_file():
            raise FileNotFoundError(f"{self.__class__.__name__} model could not be loaded from {file_path}. "
                                    f"Check if the path to the {name} file is properly set.")

        # NOTE(review): torch.load unpickles the file; only load models from trusted sources
        self.model = torch.load(str(file_path))
        self.model.eval()

        params_path = path / FilenameHandler.get_filename(self.__class__.__name__, "yaml")

        if params_path.is_file():
            with params_path.open("r") as file:
                desc = yaml.safe_load(file)
                if "label" in desc:
                    setattr(self, "label", Label(**desc["label"]))
                for param in ["feature_names", "classes"]:
                    if param in desc:
                        setattr(self, param, desc[param])

    def store(self, path):
        """
        Stores the model (.pt file) and a description with the model parameters, feature names, classes and label
        (.yaml file) to the given folder; the folder is built with PathBuilder first.
        """
        PathBuilder.build(path)

        name = FilenameHandler.get_filename(self.__class__.__name__, "pt")
        torch.save(self.model, str(path / name))

        params_path = path / FilenameHandler.get_filename(self.__class__.__name__, "yaml")

        with params_path.open("w") as file:
            desc = {
                **(self.get_params()),
                "feature_names": self.get_feature_names(),
                "classes": self.get_classes()
            }

            if self.label is not None:
                desc["label"] = self.label.get_desc_for_storage()

            yaml.dump(desc, file)

    def get_package_info(self) -> str:
        """Returns the immuneML version followed by the version of the installed DeepRC distribution."""
        return Util.get_immuneML_version() + '; deepRC ' + pkg_resources.get_distribution('DeepRC').version

    def can_predict_proba(self) -> bool:
        """Class probabilities are available for this method."""
        return True

    def can_fit_with_example_weights(self) -> bool:
        """Example weights are not supported by this method."""
        return False

    def get_compatible_encoders(self):
        """Returns the list of encoder classes this method can be used with."""
        return [DeepRCEncoder]