Source code for immuneML.ml_methods.generative_models.KLEvaluator

import plotly.graph_objects as go
from functools import lru_cache

import numpy as np
import pandas as pd


def KL(sequences, model_1, model_2):
    """Compute the KL divergence between ``model_1`` and ``model_2`` over ``sequences``.

    The value is the mean of the per-sequence weights
    ``model_1.log_prob(s) - model_2.log_prob(s)``. This is a sample estimate of
    KL(model_1 || model_2) when ``sequences`` are drawn from ``model_1``'s
    distribution.

    Args:
        sequences: list of sequences
        model_1: model whose log-probability is the minuend
        model_2: model whose log-probability is subtracted

    Returns:
        KL divergence value
    """
    per_sequence = get_kl_weights(model_1=model_1, model_2=model_2, sequences=sequences)
    return np.mean(per_sequence)
def get_kl_weights(model_1, model_2, sequences):
    """Return the element-wise log-probability difference of two models.

    Args:
        model_1: model whose ``log_prob`` is evaluated first (the minuend)
        model_2: model whose ``log_prob`` is subtracted
        sequences: sequences passed unchanged to both ``log_prob`` calls

    Returns:
        ``model_1.log_prob(sequences) - model_2.log_prob(sequences)``
    """
    log_prob_first = model_1.log_prob(sequences)
    log_prob_second = model_2.log_prob(sequences)
    return log_prob_first - log_prob_second
def evaluate_similarities(true_sequences, simulated_sequences, estimator):
    """Fit one model per sequence set and return the KL value in both directions.

    Args:
        true_sequences: the original sequences
        simulated_sequences: the generated sequences
        estimator: callable that builds a model (with ``log_prob``) from sequences

    Returns:
        Tuple ``(KL on true_sequences of true vs simulated model,
        KL on simulated_sequences of simulated vs true model)``.
    """
    fitted_on_true = estimator(true_sequences)
    fitted_on_simulated = estimator(simulated_sequences)
    true_direction = KL(true_sequences, fitted_on_true, fitted_on_simulated)
    simulated_direction = KL(simulated_sequences, fitted_on_simulated, fitted_on_true)
    return true_direction, simulated_direction
class KLEvaluator:
    """Compare true and simulated sequences through per-sequence KL weights.

    A model is fitted to each sequence set with ``estimator``. The per-sequence
    log-probability differences between the two models (the "KL weights") are
    used to report the mean KL value in both directions. They are also used to
    list the worst-fitting sequences and to draw per-position heatmaps for them.

    Args:
        true_sequences: the original sequences. They are indexed with integer
            arrays, so presumably a numpy/bionumpy array (TODO confirm).
        simulated_sequences: the generated sequences, in the same container type.
        estimator: callable that fits a model to a set of sequences. The model
            must provide ``log_prob`` and, for the plots, ``kmer_model.log_prob``.
        n_sequences: number of worst-fitting sequences shown in each plot.
    """

    def __init__(self, true_sequences, simulated_sequences, estimator, n_sequences):
        self.true_sequences = true_sequences
        self.simulated_sequences = simulated_sequences
        self.true_model = estimator(true_sequences)
        self.simulated_model = estimator(simulated_sequences)
        self.estimator = estimator
        self._n_sequences = n_sequences
        # Per-instance memoisation of the KL weights. functools.lru_cache on a
        # method stored every evaluator, with its sequence arrays, in a
        # class-level cache. Caching on the instance releases the data together
        # with the instance.
        self._true_kl_weights = None
        self._simulated_kl_weights = None

    def true_kl_weights(self):
        """Return log p_true(s) - log p_simulated(s) for every true sequence (cached)."""
        if getattr(self, "_true_kl_weights", None) is None:
            self._true_kl_weights = get_kl_weights(self.true_model, self.simulated_model, self.true_sequences)
        return self._true_kl_weights

    def simulated_kl_weights(self):
        """Return log p_simulated(s) - log p_true(s) for every simulated sequence (cached)."""
        if getattr(self, "_simulated_kl_weights", None) is None:
            self._simulated_kl_weights = get_kl_weights(self.simulated_model, self.true_model,
                                                        self.simulated_sequences)
        return self._simulated_kl_weights

    def true_kl(self):
        """Mean KL weight over the true sequences."""
        return np.mean(self.true_kl_weights())

    def simulated_kl(self):
        """Mean KL weight over the simulated sequences."""
        return np.mean(self.simulated_kl_weights())

    @staticmethod
    def _top_indices(weights, n):
        """Return the indices of the ``n`` largest weights, largest first.

        Reversing the full argsort before slicing gives the same result as
        ``argsort(w)[-n:][::-1]`` for every ``n``, except that ``n=0`` now yields
        no indices. ``[-0:]`` would have selected all of them.
        """
        return np.argsort(weights)[::-1][:n]

    def _true_indices(self, n):
        """Indices of the ``n`` true sequences with the highest KL weight."""
        return self._top_indices(self.true_kl_weights(), n)

    def _simulated_indices(self, n):
        """Indices of the ``n`` simulated sequences with the highest KL weight."""
        return self._top_indices(self.simulated_kl_weights(), n)

    def get_worst_true_sequences(self, n=20):
        """Return a DataFrame (columns ``sequence``, ``kl``) of the ``n`` worst-fitting true sequences."""
        indices = self._true_indices(n)
        return pd.DataFrame({"sequence": self.true_sequences[indices].tolist(),
                             "kl": self.true_kl_weights()[indices]})

    def get_worst_simulated_sequences(self, n=20):
        """Return a DataFrame (columns ``sequence``, ``kl``) of the ``n`` worst-fitting simulated sequences."""
        indices = self._simulated_indices(n)
        return pd.DataFrame({"sequence": self.simulated_sequences[indices].tolist(),
                             "kl": self.simulated_kl_weights()[indices]})

    def simulated_plot(self):
        """Heatmap of k-mer KL weights for the worst-fitting simulated sequences."""
        indices = self._simulated_indices(self._n_sequences)
        return self._worst_fit_plot(
            self.simulated_sequences, indices, self.simulated_model, self.true_model,
            self.simulated_kl_weights(),
            "KL weights for the generated sequences that don't fit with the original model")

    def original_plot(self):
        """Heatmap of k-mer KL weights for the worst-fitting true sequences."""
        indices = self._true_indices(self._n_sequences)
        return self._worst_fit_plot(
            self.true_sequences, indices, self.true_model, self.simulated_model,
            self.true_kl_weights(),
            "KL weights for the original sequences that don't fit the generated model")

    def _worst_fit_plot(self, sequences, indices, own_model, other_model, kl_weights, title):
        """Shared body of the two plot methods: compute per-k-mer weights and draw them.

        Args:
            sequences: the sequence set the rows are taken from
            indices: which sequences to plot, in display order
            own_model: model fitted on ``sequences``
            other_model: model fitted on the other sequence set
            kl_weights: per-sequence KL weights of ``sequences``
            title: figure title
        """
        kmers = sequences[indices]
        weights = own_model.kmer_model.log_prob(kmers) - other_model.kmer_model.log_prob(kmers)
        scores = kl_weights[indices]
        fig = self.get_plot(indices, kmers, scores, weights)
        fig.update_layout(yaxis_title="Sequence", xaxis_title="Position", title=title)
        return fig

    def get_plot(self, indices, kmers, scores, weights):
        """Build a heatmap with one row per sequence and one column per k-mer position.

        Args:
            indices: indices of the plotted sequences; only its length is used
            kmers: per-sequence k-mers, each providing ``to_string()``
            scores: per-sequence KL weight, appended to the row label
            weights: per-position KL weights, with row lengths matching ``kmers``

        Returns:
            plotly ``Figure`` containing the heatmap
        """
        n_rows = len(indices)
        # np.max reduces weights.shape[-1] to the widest row.
        # NOTE(review): presumably weights can be ragged, in which case
        # shape[-1] holds per-row lengths; for a dense array it is a plain int.
        width = np.max(weights.shape[-1])
        z = np.full((n_rows, width), np.nan)  # NaN cells stay blank for shorter rows
        text = [['' for _ in range(width)] for _ in range(n_rows)]
        labels = []
        for row_kmers, row_weights, z_row, text_row, score in zip(kmers, weights, z, text, scores):
            kmer_strings = [kmer.to_string() for kmer in row_kmers]
            z_row[:len(row_kmers)] = row_weights  # z_row is a view, so this fills z
            text_row[:len(row_kmers)] = kmer_strings
            # Rebuild the sequence from k-mers that are assumed to overlap by all
            # but one character: the first k-mer without its last character, then
            # the last character of every k-mer. An empty row gives an empty label
            # instead of failing on kmer_strings[0].
            if kmer_strings:
                sequence = kmer_strings[0][:-1] + ''.join(s[-1] for s in kmer_strings)
            else:
                sequence = ''
            labels.append(f'{sequence} : {score:.2f}')

        fig = go.Figure(data=go.Heatmap(
            y=labels,
            z=z,
            text=text,
            texttemplate="%{text}",
            colorbar=dict(title="KL weights"),
            colorscale=["green", 'yellow', 'red']
        ))
        fig.update_yaxes(autorange="reversed")
        return fig