# Source code for immuneML.ml_methods.AtchleyKmerMILClassifier

import copy
import logging
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml

from immuneML.data_model.encoded_data.EncodedData import EncodedData
from immuneML.environment.Label import Label
from immuneML.ml_methods.MLMethod import MLMethod
from immuneML.ml_methods.pytorch_implementations.PyTorchLogisticRegression import PyTorchLogisticRegression
from immuneML.ml_methods.util.Util import Util
from immuneML.util.PathBuilder import PathBuilder


class AtchleyKmerMILClassifier(MLMethod):
    """
    A binary Repertoire classifier which uses the data encoded by :ref:`AtchleyKmer` encoder to predict the repertoire label.

    The original publication:
    Ostmeyer J, Christley S, Toby IT, Cowell LG. Biophysicochemical motifs in T cell receptor sequences distinguish repertoires from
    tumor-infiltrating lymphocytes and adjacent healthy tissue. Cancer Res. Published online January 1, 2019:canres.2292.2018.
    `doi:10.1158/0008-5472.CAN-18-2292 <https://cancerres.aacrjournals.org/content/79/7/1671>`_ .

    Arguments:

        iteration_count (int): max number of training iterations

        threshold (float): loss threshold at which to stop training if reached

        evaluate_at (int): log model performance every 'evaluate_at' iterations and store the model every 'evaluate_at' iterations if early
        stopping is used

        use_early_stopping (bool): whether to use early stopping

        learning_rate (float): learning rate for stochastic gradient descent

        random_seed (int): random seed used to derive the seed of each initialization

        zero_abundance_weight_init (bool): whether to use 0 as initial weight for abundance term (if not, a random value is sampled from normal
        distribution with mean 0 and variance 1 / total_number_of_features)

        number_of_threads (int): number of threads to be used for training

        initialization_count (int): how many times to repeat the fitting procedure from the beginning before choosing the optimal model (trains
        the model with multiple random initializations)

        pytorch_device_name (str): The name of the pytorch device to use. This name will be passed to torch.device(pytorch_device_name).

    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        my_kmer_mil_classifier:
            AtchleyKmerMILClassifier:
                iteration_count: 100
                evaluate_at: 15
                use_early_stopping: False
                learning_rate: 0.01
                random_seed: 100
                zero_abundance_weight_init: True
                number_of_threads: 8
                threshold: 0.00001
                initialization_count: 4

    """

    # bounds for the per-initialization seeds drawn from random_seed
    MIN_SEED_VALUE = 1
    MAX_SEED_VALUE = 100000

    def __init__(self, iteration_count: int = None, threshold: float = None, evaluate_at: int = None, use_early_stopping: bool = None,
                 random_seed: int = None, learning_rate: float = None, zero_abundance_weight_init: bool = None,
                 number_of_threads: int = None, result_path: Path = None, initialization_count: int = None,
                 pytorch_device_name: str = None):
        super().__init__()
        self.logistic_regression = None
        # initialized here (not only in fit) so that load() restores it on a fresh instance and store() can always rely on it
        self.label = None
        self.random_seed = random_seed
        self.iteration_count = iteration_count
        self.threshold = threshold
        self.evaluate_at = evaluate_at
        self.use_early_stopping = use_early_stopping
        self.learning_rate = learning_rate
        self.zero_abundance_weight_init = zero_abundance_weight_init
        self.number_of_threads = number_of_threads
        self.class_mapping = None
        self.input_size = 0
        self.result_path = result_path
        self.feature_names = None
        self.initialization_count = initialization_count
        self.pytorch_device_name = pytorch_device_name

    def _make_log_reg(self):
        """Create a fresh PyTorchLogisticRegression sized to the current ``input_size``."""
        return PyTorchLogisticRegression(in_features=self.input_size, zero_abundance_weight_init=self.zero_abundance_weight_init)

    def fit(self, encoded_data: EncodedData, label: Label, cores_for_training: int = 2):
        """
        Fit the classifier ``initialization_count`` times from different random initializations and keep the model with the lowest
        final training loss.

        Arguments:
            encoded_data: data encoded by the AtchleyKmer encoder; ``examples`` is indexed as (example, feature, k-mer) below
            label: label to predict; its ``positive_class`` is used to build the binary class mapping
            cores_for_training: not used by this method (threading is controlled by ``number_of_threads``)
        """
        self.feature_names = encoded_data.feature_names
        self.label = label
        self.class_mapping = Util.make_binary_class_mapping(encoded_data.labels[self.label.name], self.label.positive_class)
        mapped_y = Util.map_to_new_class_values(encoded_data.labels[self.label.name], self.class_mapping)

        examples_np = encoded_data.examples
        self.input_size = examples_np.shape[1]

        # loop-invariant tensors: built once instead of at every training iteration
        examples = torch.from_numpy(examples_np).float()
        example_indices = torch.arange(examples_np.shape[0]).long()
        targets = torch.tensor(mapped_y).float()

        # draw one seed per initialization from a single seeded generator; re-seeding before every draw would give every
        # initialization the same seed and therefore identical models
        seed_generator = random.Random(self.random_seed)
        seeds = [seed_generator.randint(AtchleyKmerMILClassifier.MIN_SEED_VALUE, AtchleyKmerMILClassifier.MAX_SEED_VALUE)
                 for _ in range(self.initialization_count)]

        self.logistic_regression = None
        min_loss = np.inf

        for seed in seeds:
            Util.setup_pytorch(self.number_of_threads, seed, self.pytorch_device_name)
            log_reg = self._make_log_reg()
            final_loss = self._train_single_model(log_reg, examples_np, examples, example_indices, targets)

            if final_loss < min_loss:
                self.logistic_regression = log_reg
                min_loss = final_loss

    def _train_single_model(self, log_reg, examples_np: np.ndarray, examples: torch.Tensor, example_indices: torch.Tensor,
                            targets: torch.Tensor) -> float:
        """
        Train one logistic regression model in place with SGD and return its final training loss as a float.

        If early stopping is enabled and a checkpoint taken at an evaluation iteration had a lower loss than the last iteration,
        that checkpoint is loaded back into ``log_reg`` and its loss is returned instead.
        """
        loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
        optimizer = torch.optim.SGD(log_reg.parameters(), lr=self.learning_rate)
        best_state = {"loss": np.inf, "model": None}
        loss_value = np.inf

        for iteration in range(self.iteration_count):
            # reset gradients
            optimizer.zero_grad()

            # compute predictions only for the k-mer with the max score in each example
            max_logit_indices = self._get_max_logits_indices(examples_np, log_reg)
            logits = log_reg(examples[example_indices, :, max_logit_indices])

            # compute the loss; .item() yields a plain float detached from the autograd graph
            loss = loss_func(logits, targets)
            loss_value = loss.item()

            # log current score and keep model for early stopping if specified; the snapshot is taken before the update so the
            # stored parameters are the ones that produced loss_value
            if iteration % self.evaluate_at == 0 or iteration == self.iteration_count - 1:
                logging.info(f"AtchleyKmerMILClassifier: log loss at iteration {iteration + 1}/{self.iteration_count}: {loss_value}.")
                if self.use_early_stopping and loss_value < best_state["loss"]:
                    best_state = {"loss": loss_value, "model": copy.deepcopy(log_reg.state_dict())}

            # perform update
            loss.backward()
            optimizer.step()

            if loss_value < self.threshold:
                break
        else:
            # only reached when the threshold was never hit
            logging.warning(f"AtchleyKmerMILClassifier: the logistic regression model did not converge.")

        if self.use_early_stopping and best_state["model"] is not None and best_state["loss"] < loss_value:
            log_reg.load_state_dict(best_state["model"])
            return best_state["loss"]

        return loss_value

    def _get_max_logits_indices(self, data, log_reg=None):
        """
        For each example, return the index of the k-mer with the highest logit.

        Arguments:
            data: numpy array indexed as (example, feature, k-mer), as implied by the reshape below
            log_reg: model to score with; defaults to the fitted ``self.logistic_regression``

        Returns:
            a LongTensor with one k-mer index per example
        """
        model = log_reg if log_reg is not None else self.logistic_regression
        example_count, kmer_count = data.shape[0], data.shape[2]
        with torch.no_grad():
            # one row per (example, k-mer) pair so every k-mer is scored in a single forward pass
            flat = np.swapaxes(data, 1, 2).reshape(example_count * kmer_count, -1)
            logits = model(torch.from_numpy(flat).float())
            logits = torch.reshape(logits, (example_count, kmer_count))
            return torch.argmax(logits, dim=1).long()

    def predict(self, encoded_data: EncodedData, label: Label):
        """Predict class labels by thresholding the positive-class probability at 0.5."""
        predictions_proba = self.predict_proba(encoded_data, label)
        return {label.name: [self.class_mapping[val] for val in (predictions_proba[label.name][label.positive_class] > 0.5).tolist()]}

    def fit_by_cross_validation(self, encoded_data: EncodedData, number_of_splits: int = 5, label: Label = None, cores_for_training: int = -1,
                                optimization_metric=None):
        """Cross-validation is not supported internally; this logs a warning and falls back to a plain fit."""
        logging.warning(f"AtchleyKmerMILClassifier: fitting by cross validation is not implemented internally for the model, fitting without "
                        f"cross-validation instead.")
        self.fit(encoded_data, label)

    def store(self, path: Path, feature_names=None, details_path: Path = None):
        """
        Write the model weights (log_reg.pt), a coefficients table (coefficients.csv) and the remaining attributes
        (custom_params.yaml) to ``path``.
        """
        path = Path(path)
        PathBuilder.build(path)
        torch.save(self.logistic_regression.state_dict(), str(path / "log_reg.pt"))

        coefficients_df = pd.DataFrame(self.logistic_regression.linear.weight.detach().cpu().numpy(), columns=feature_names)
        coefficients_df["bias"] = self.logistic_regression.linear.bias.detach().cpu().numpy()
        coefficients_df.to_csv(path / "coefficients.csv", index=False)

        # the model is stored separately above and result_path is not persisted; label is re-added in its storage form
        custom_vars = {key: copy.deepcopy(value) for key, value in vars(self).items()
                       if key not in ("result_path", "logistic_regression", "label")}
        if self.label:
            custom_vars["label"] = self.label.get_desc_for_storage()

        params_path = path / "custom_params.yaml"
        with params_path.open('w') as file:
            yaml.dump(custom_vars, file)

    def load(self, path):
        """Restore attributes from custom_params.yaml and the model weights from log_reg.pt in ``path``."""
        path = Path(path)
        params_path = path / "custom_params.yaml"
        with params_path.open("r") as file:
            custom_params = yaml.load(file, Loader=yaml.SafeLoader)
        for param, value in custom_params.items():
            if hasattr(self, param):
                if param == "label":
                    setattr(self, "label", Label(**value))
                else:
                    setattr(self, param, value)

        self.logistic_regression = self._make_log_reg()
        self.logistic_regression.load_state_dict(torch.load(str(path / "log_reg.pt")))

    def check_if_exists(self, path) -> bool:
        """Return True if a model has been fitted or loaded (``path`` is not inspected)."""
        return self.logistic_regression is not None

    def get_params(self):
        """Return a deep copy of the attributes, with the model replaced by its state dict when it exists."""
        params = copy.deepcopy(vars(self))
        if self.logistic_regression is not None:
            params["logistic_regression"] = copy.deepcopy(self.logistic_regression.state_dict())
        return params

    def predict_proba(self, encoded_data: EncodedData, label: Label):
        """Return positive/negative class probabilities, scoring each example by its highest-scoring k-mer."""
        self.logistic_regression.eval()
        examples_np = encoded_data.examples
        example_count = examples_np.shape[0]
        max_logit_indices = self._get_max_logits_indices(examples_np)
        with torch.no_grad():
            data = torch.from_numpy(examples_np).float()[torch.arange(example_count).long(), :, max_logit_indices]
            predictions = torch.sigmoid(self.logistic_regression(data)).numpy()
        return {label.name: {label.positive_class: predictions, label.get_binary_negative_class(): 1 - predictions}}

    def get_label_name(self):
        return self.label.name

    def get_package_info(self) -> str:
        return Util.get_immuneML_version()

    def get_feature_names(self) -> list:
        return self.feature_names

    def can_predict_proba(self) -> bool:
        return True

    def get_class_mapping(self) -> dict:
        return self.class_mapping

    def get_compatible_encoders(self):
        from immuneML.encodings.atchley_kmer_encoding.AtchleyKmerEncoder import AtchleyKmerEncoder
        return [AtchleyKmerEncoder]