# Source code for immuneML.reports.ml_reports.DeepRCMotifDiscovery

import warnings
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from immuneML.data_model.dataset.Dataset import Dataset
from immuneML.hyperparameter_optimization.HPSetting import HPSetting
from immuneML.ml_methods.DeepRC import DeepRC
from immuneML.ml_methods.MLMethod import MLMethod
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.ml_reports.MLReport import MLReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder


class DeepRCMotifDiscovery(MLReport):
    """
    This report plots the contributions of (i) input sequences and (ii) kernels to trained :ref:`DeepRC` model with
    respect to the test dataset. Contributions are computed using integrated gradients (IG).
    This report produces two figures:

    - inputs_integrated_gradients: Shows the contributions of the characters within the input sequences (test dataset)
      that was most important for immune status prediction of the repertoire. IG is only applied to sequences of
      positive class repertoires.

    - kernel_integrated_gradients: Shows the 1D CNN kernels with the highest contribution over all positions and
      amino acids.

    For both inputs and kernels: Larger characters in the extracted motifs indicate higher contribution, with blue
    indicating positive contribution and red indicating negative contribution towards the prediction of the immune
    status. For kernels only: contributions to positional encoding are indicated by < (beginning of sequence),
    ∧ (center of sequence), and > (end of sequence).

    See :ref:`DeepRCMotifDiscovery for repertoire classification` for a usage example.

    Reference:
    Michael Widrich, Bernhard Schäfl, Milena Pavlović, Geir Kjetil Sandve, Sepp Hochreiter, Victor Greiff,
    Günter Klambauer
    ‘DeepRC: Immune repertoire classification with attention-based deep massive multiple instance learning’.
    bioRxiv preprint doi: `https://doi.org/10.1101/2020.04.12.038158 <https://doi.org/10.1101/2020.04.12.038158>`_

    Arguments:

        n_steps (int): Number of IG steps (more steps -> better path integral -> finer contribution values).
        50 is usually good enough.

        threshold (float): Only applies to the plotting of kernels. Contributions are normalized to range [0, 1],
        and only kernels with normalized contributions above threshold are plotted.

    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        my_deeprc_report:
            DeepRCMotifDiscovery:
                threshold: 0.5
                n_steps: 50

    """

    def __init__(self, n_steps: int, threshold: float, train_dataset: Dataset = None, test_dataset: Dataset = None,
                 method: MLMethod = None, result_path: Path = None, name: str = None, hp_setting: HPSetting = None,
                 label=None, number_of_processes: int = 1):
        """Store the IG settings and the fixed output file names; everything else is forwarded to MLReport."""
        super().__init__(train_dataset=train_dataset, test_dataset=test_dataset, method=method,
                         result_path=result_path, name=name, hp_setting=hp_setting, label=label,
                         number_of_processes=number_of_processes)
        self.n_steps = n_steps
        self.threshold = threshold
        # File names are relative; they are joined with result_path when the report is generated.
        self.filename_inputs = Path("inputs_integrated_gradients.pdf")
        self.filename_kernels = Path("kernel_integrated_gradients.pdf")

    @classmethod
    def build_object(cls, **kwargs):
        """
        Validate the user-supplied parameters and build the report.

        Expects 'n_steps' (int >= 1) and 'threshold' (float in [0, 1]) in kwargs; 'name' is optional.
        A missing 'n_steps' or 'threshold' raises KeyError.
        """
        location = "DeepRCMotifDiscovery"
        ParameterValidator.assert_type_and_value(kwargs["n_steps"], int, location, "n_steps", min_inclusive=1)
        ParameterValidator.assert_type_and_value(kwargs["threshold"], float, location, "threshold",
                                                 min_inclusive=0, max_inclusive=1)

        return cls(n_steps=kwargs["n_steps"], threshold=kwargs["threshold"], name=kwargs.get("name"))

    def _generate(self) -> ReportResult:
        """
        Build a data loader over the encoded test dataset, compute the contributions and collect the figures.

        Returns:
            ReportResult listing the figure files that were actually written to result_path.
        """
        PathBuilder.build(self.result_path)

        label_name = self.label.name
        test_metadata_filepath = self.test_dataset.encoded_data.info['metadata_filepath']
        hdf5_filepath = self.method._metadata_to_hdf5(metadata_filepath=test_metadata_filepath,
                                                      label_name=label_name)

        # Use every example of the test dataset.
        n_examples_test = len(self.test_dataset.encoded_data.example_ids)
        indices = np.arange(n_examples_test)

        dataloader = self.method.make_data_loader(hdf5_filepath, pre_loaded_hdf5_file=None, indices=indices,
                                                  label_name=label_name, eval_only=True, is_train=False)

        path_inputs = self.result_path / self.filename_inputs
        path_kernels = self.result_path / self.filename_kernels

        self.compute_contributions(intgrds_set_loader=dataloader, deeprc_model=self.method.model,
                                   n_steps=self.n_steps, threshold=self.threshold,
                                   path_inputs=path_inputs, path_kernels=path_kernels)

        # A plot is skipped when there is nothing to draw (e.g. no kernel exceeds the threshold),
        # so only report the files that exist instead of pointing at missing figures.
        candidate_figures = ((path_inputs, "Integrated Gradients over the inputs to DeepRC"),
                             (path_kernels, "Integrated Gradients over the kernels of DeepRC"))
        output_figures = [ReportOutput(path, description) for path, description in candidate_figures
                          if path.is_file()]

        return ReportResult(self.name,
                            info="Plots the contributions of (i) input sequences and (ii) kernels to trained `DeepRC` model with respect to the test dataset. Contributions are computed using integrated gradients.",
                            output_figures=output_figures)

    def compute_contributions(self, intgrds_set_loader: torch.utils.data.DataLoader, deeprc_model,
                              n_steps: int = 50, threshold: float = 0.5,
                              path_inputs: Path = Path("inputs_integrated_gradients.pdf"),
                              path_kernels: Path = Path("kernel_integrated_gradients.pdf")):
        """
        Compute and plot contributions of sequences and motifs to trained DeepRC model, given a dataset.
        Contribution is computed using integrated gradients (IG).

        Author -- Michael Widrich
        Created on -- 2020-07-20
        Contact -- michael.widrich@jku.at

        Parameters
        ----------
        intgrds_set_loader : torch.utils.data.DataLoader
            The dataset to compute IG for in form of a PyTorch DataLoader following the DeepRC format.
            E.g. one of the dataloaders returned by deeprc.deeprc_binary.predefined_datasets.cmv_dataset().
        deeprc_model : deeprc.deeprc_binary.architectures.DeepRC
            DeepRC model to compute IG for. Weights of first CNN layer are accessed via
            deeprc_model.sequence_embedding_16bit.conv_aas.weight .
        n_steps : int
            Number of IG steps (more steps -> better path integral -> finer contribution values).
            50 is usually good enough.
        threshold : float
            Threshold for plotting of kernels (=motifs). Contributions are normalized to range [0, 1] and then
            threshold is applied. 0.5 -> only kernels with normalized contributions above 0.5 are plotted.
        path_inputs : Path
            path for inputs integrated gradients plot
        path_kernels : Path
            path for kernels integrated gradients plot

        Notes
        -----
        The kernel weights of the model are temporarily overwritten while integrating and are restored to their
        original values afterwards, also when an error interrupts the integration.
        If no batch is analysed, a warning is issued and the inputs plot is not written.
        """
        intgrds_set = intgrds_set_loader.dataset
        n_batches = len(intgrds_set_loader)

        #
        # Integrated gradients w.r.t. kernels
        #
        active_kernels = deeprc_model.sequence_embedding_16bit.conv_aas.weight
        original_kernel_values = active_kernels.cpu().clone().data.detach().to(dtype=torch.float32)

        # Accumulator for the kernel IG values. np.float64 is what the removed `np.float` alias stood for.
        int_grd_kernels = np.zeros(original_kernel_values.shape, dtype=np.float64)

        # Interpolation path from the trained values (factor 1) down to an all-zero baseline (factor 0).
        interp_factors = torch.linspace(1, 0, n_steps)
        interpolated_kernels = [original_kernel_values * factor for factor in interp_factors]

        most_important_inputs = []
        most_important_inputs_intgrds = []
        most_important_inputs_lens = []
        for data in tqdm(intgrds_set_loader, total=n_batches, desc="calc intgrds", ncols=10, position=0):
            # Get 1 sample as lists
            labels, inputs, sequence_lengths, duplicates_per_sequence, sample_ids = data

            # Only consider positive samples
            # NOTE(review): this skips batches whose first label entry is truthy; confirm that the label
            # encoding produced by the data loader makes this select the positive class as intended.
            if labels[0][0]:
                continue

            # Perform attention pooling
            with torch.no_grad():
                labels, inputs, sequence_lengths, n_sequences = deeprc_model.reduce_and_stack_minibatch(
                    labels, inputs, sequence_lengths, duplicates_per_sequence)

            self._accumulate_kernel_gradients(deeprc_model, active_kernels, original_kernel_values,
                                              interpolated_kernels, inputs, n_sequences, n_steps, n_batches,
                                              int_grd_kernels)

            inputs_cpu = inputs.to(device='cpu', dtype=torch.float32)
            int_grd_inputs = self._integrate_input_gradients(deeprc_model, inputs, inputs_cpu, n_sequences,
                                                             interp_factors, n_steps)

            # Keep only the input with the largest summed contribution (summed over axes 1 and 2).
            most_important_input_ind = np.argmax(np.sum(np.sum(int_grd_inputs, 1), 1))
            most_important_inputs_intgrds.append(int_grd_inputs[most_important_input_ind, :, :].sum(axis=0))
            most_important_inputs.append(inputs_cpu[most_important_input_ind].detach().cpu().numpy())
            most_important_inputs_lens.append(sequence_lengths[most_important_input_ind].detach().cpu().numpy())

        # Plot inputs
        if most_important_inputs_intgrds:
            abs_max = np.max([np.abs(contrib).max() for contrib in most_important_inputs_intgrds])
            for index, contrib in enumerate(most_important_inputs_intgrds):
                if abs_max > 0:  # all-zero contributions cannot be normalized; leave them at zero
                    contrib[:] /= abs_max
                # The last 3 rows are dropped before taking the argmax
                # (presumably the positional-encoding features -- TODO confirm).
                most_important_inputs[index] = intgrds_set.inds_to_aa(
                    np.argmax(most_important_inputs[index][:-3], axis=0))

            self.plot_inputs_text(chars=most_important_inputs, colorgrad=most_important_inputs_intgrds,
                                  seq_lens=most_important_inputs_lens, file_path=path_inputs)
        else:
            warnings.warn(f"{self.__class__.__name__}: no batch was selected for integrated gradients, "
                          f"the inputs plot will not be created.")

        # Get/Plot kernels with highest contribution over all positions and AAs
        kernel_contrib = int_grd_kernels.sum(axis=2).sum(axis=1)
        max_contrib = kernel_contrib.max()
        if max_contrib == 0:
            # Dividing by zero would only yield nan/inf values that never exceed the threshold.
            top_kernels = int_grd_kernels[:0]
        else:
            normed_kernel_contrib = kernel_contrib / max_contrib
            n_top_kernels = np.sum(normed_kernel_contrib > threshold)
            top_kernel_inds = normed_kernel_contrib.argsort()[::-1][:n_top_kernels]
            top_kernels = int_grd_kernels[top_kernel_inds] / max_contrib
        self.plot_kernels_text(kernels=top_kernels, charset=intgrds_set.aas, file_path=path_kernels)

    @staticmethod
    def _accumulate_kernel_gradients(deeprc_model, active_kernels, original_kernel_values, interpolated_kernels,
                                     inputs, n_sequences, n_steps, n_batches, int_grd_kernels):
        """Add the kernel IG of one batch to int_grd_kernels in place, then restore the original kernel weights."""
        try:
            for step in tqdm(range(n_steps), total=n_steps, desc="kernels", ncols=10, position=1):
                deeprc_model.zero_grad()
                with torch.no_grad():
                    active_kernels.data[:] = interpolated_kernels[step].data.to(dtype=torch.float16)
                logit_outputs = deeprc_model(inputs, n_sequences)
                logit_outputs.backward()
                int_grd_kernels[:] += (
                        original_kernel_values * active_kernels.grad.to(device='cpu', dtype=torch.float32)
                        / n_steps / n_batches).detach().data.cpu().numpy()
        finally:
            # The last interpolation step scales the kernels to zero, so the trained weights must be
            # written back even if the loop is interrupted.
            with torch.no_grad():
                active_kernels.data[:] = original_kernel_values.data.to(dtype=torch.float16)

    @staticmethod
    def _integrate_input_gradients(deeprc_model, inputs, inputs_cpu, n_sequences, interp_factors, n_steps):
        """Return the IG of the model output w.r.t. the inputs of one batch as an array shaped like inputs."""
        int_grd_inputs = np.zeros(inputs.shape, dtype=np.float64)
        for step in tqdm(range(n_steps), total=n_steps, desc="inputs", ncols=10, position=1):
            deeprc_model.zero_grad()
            # The product is a new tensor, so it can be turned into a gradient-tracking leaf directly;
            # computing it per step avoids keeping n_steps scaled copies of the batch in memory.
            scaled_inputs = (inputs * interp_factors[step]).detach().requires_grad_(True)
            logit_outputs = deeprc_model(scaled_inputs, n_sequences)
            gradient = torch.autograd.grad(logit_outputs, scaled_inputs, retain_graph=True)[0]
            int_grd_inputs[:] += (inputs_cpu * gradient.to(device='cpu', dtype=torch.float32)
                                  / n_steps).detach().data.cpu().numpy()
        return int_grd_inputs

    def plot_inputs_text(self, chars, colorgrad, seq_lens, file_path):
        """
        Draw one row of characters per sequence and save the figure to file_path.

        Character size grows with the absolute value of the matching colorgrad entry; characters are blue for
        values above zero and red otherwise. Nothing is written when chars is empty.

        Author -- Michael Widrich
        Created on -- 2020-07-20
        Contact -- michael.widrich@jku.at
        """
        n_seqs = len(chars)
        if not n_seqs:
            return None

        char_scale = 10
        char_offset = 1
        max_seq_len = max(seq_lens) + 1

        # Rounding can yield a width of 0 for short sequences, which matplotlib rejects.
        fig_width = max(int(np.round(max_seq_len / 15)), 1)
        fig_height = max(int(np.round(n_seqs / 5)), 5)
        fig = plt.figure(figsize=(fig_width, fig_height))
        try:
            for seq_i in range(n_seqs):
                y_pos = 0 + (1 - (seq_i / n_seqs)) - 0.05
                for char_i in range(int(seq_lens[seq_i])):
                    contribution = colorgrad[seq_i][char_i]
                    fig.text((char_i + 1) / max_seq_len, y_pos, chars[seq_i][char_i].decode("utf-8"),
                             size=char_offset + abs(contribution) * char_scale, ha='center', va='center',
                             color='blue' if contribution > 0 else 'red')
            fig.savefig(str(file_path))
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

    def plot_kernels_text(self, kernels, charset, file_path):
        """
        Draw one column of stacked characters per kernel position, one subplot per kernel, and save to file_path.

        At most 15 kernels and 15 entries along the last axis are drawn. Within a position, characters are ordered
        by size (offset + |contribution| * scale); blue marks values above zero, red otherwise.
        Nothing is written and None is returned when there are no kernels.

        Author -- Michael Widrich
        Created on -- 2020-07-20
        Contact -- michael.widrich@jku.at
        """
        char_scale = 100
        char_offset = 1
        max_n_kernels = 15
        max_n_kernel_features = 15

        kernels = kernels[:max_n_kernels, ..., :max_n_kernel_features]
        n_kernels = kernels.shape[0]
        if not n_kernels:
            return None

        n_positions = kernels.shape[-1]
        n_chars = len(charset)

        # Rounding can yield 0 for a single small kernel or a tiny charset, which matplotlib rejects.
        fig_width = max(int(np.round(n_positions * n_kernels / 3)), 1)
        fig_height = max(int(np.round(n_chars / 3)), 1)
        # squeeze=False always returns a 2D array of axes, also for a single kernel.
        fig, axes = plt.subplots(1, n_kernels, squeeze=False, figsize=(fig_width, fig_height))
        try:
            for kernel_i, ax in enumerate(axes[0]):
                ax.axis('off')
                for pos_i in range(n_positions):
                    entries = [(char, char_offset + abs(contrib) * char_scale, 'blue' if contrib > 0 else 'red')
                               for char, contrib in zip(charset, kernels[kernel_i, :, pos_i])]
                    entries.sort(key=lambda entry: entry[1])
                    for rank, (char, char_size, color) in enumerate(entries):
                        ax.text(pos_i / n_positions, 0 + (1 - (rank / n_chars)), char, size=char_size,
                                ha='center', va='center', color=color)
            fig.savefig(str(file_path))
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

    def check_prerequisites(self):
        """
        Check whether the report can be run and warn about every unmet requirement.

        Returns:
            True if a result path is set, the method is a DeepRC instance and the test dataset is encoded;
            False otherwise.
        """
        run_report = True

        if not hasattr(self, "result_path") or self.result_path is None:
            warnings.warn(f"{self.__class__.__name__} requires an output 'path' to be set. {self.__class__.__name__} report will not be created.")
            run_report = False

        if not isinstance(self.method, DeepRC):
            warnings.warn(
                f"{self.__class__.__name__} can only be used in combination with the DeepRC ML method. {self.__class__.__name__} report will not be created.")
            run_report = False

        # test_dataset defaults to None in the constructor, so guard before reading encoded_data.
        if self.test_dataset is None or self.test_dataset.encoded_data is None:
            warnings.warn(
                f"{self.__class__.__name__}: test dataset is not encoded and can not be run. "
                f"{self.__class__.__name__} report will not be created.")
            run_report = False

        return run_report