from functools import partial
from pathlib import Path
import bionumpy as bnp
from .TrainGenModelReport import TrainGenModelReport
from ..ReportOutput import ReportOutput
from ..ReportResult import ReportResult
from ...data_model import bnp_util
from ...data_model.SequenceParams import RegionType
from ...data_model.datasets.Dataset import Dataset
from ...environment.SequenceType import SequenceType
from ...ml_methods.generative_models.GenerativeModel import GenerativeModel
from ...ml_methods.generative_models.KLEvaluator import evaluate_similarities, KLEvaluator
from ...ml_methods.generative_models.MultinomialKmerModel import estimate_kmer_model
from ...util.PathBuilder import PathBuilder
class KLKmerComparison(TrainGenModelReport):
    """
    Estimates the KL divergence between the kmer distributions of the original and generated dataset, and makes plots
    that show which sequences (and which kmers) contribute the most to the divergence.

    **Specification arguments:**

    - k (int): The kmer length to use for the KL divergence estimation. By default, k is set to 3.

    - n_sequences (int): The number of sequences to make the plot from (the sequences that contribute the most to the
      KL divergence). By default, n_sequences is set to 50.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        my_kl_report:
            KLKmerComparison:
                k: 3
                n_sequences: 50

    """

    def __init__(self, original_dataset: Dataset = None, generated_dataset: Dataset = None, result_path: Path = None,
                 name: str = None, number_of_processes: int = 1, model: GenerativeModel = None, k: int = 3,
                 n_sequences: int = 50, sequence_type: SequenceType = SequenceType.AMINO_ACID,
                 region_type: RegionType = RegionType.IMGT_CDR3):
        """
        The arguments defined below are set at runtime by the instruction.

        Args:
            original_dataset (Dataset): a dataset object (can be repertoire, receptor or sequence dataset, depending
                on the specific report) provided as input to the TrainGenModel instruction
            generated_dataset (Dataset): a dataset object as produced from the generative model after being trained on
                the original dataset
            result_path (Path): location where the results (plots, tables, etc.) will be stored
            name (str): user-defined name of the report used in the HTML overview automatically generated by the
                platform from the key used to define the report in the YAML
            number_of_processes (int): how many processes should be created at once to speed up the analysis.
                For personal machines, 4 or 8 is usually a good choice.
            model (GenerativeModel): trained generative model from the instruction
            k (int): kmer length used when building the kmer distributions (default 3)
            n_sequences (int): number of top-contributing sequences passed on to the KL evaluator (default 50)
            sequence_type (SequenceType): which sequence representation to read from the datasets
                (default amino acid)
            region_type (RegionType): which sequence region to read from the datasets (default IMGT CDR3)
        """
        super().__init__(name=name, number_of_processes=number_of_processes)
        self.original_dataset = original_dataset
        self.generated_dataset = generated_dataset
        self.model = model
        self.result_path = result_path
        self.k = k
        self.n_sequences = n_sequences
        self.region_type = region_type
        self.sequence_type = sequence_type

    @staticmethod
    def get_title():
        return "KLTrainGenModel reports"

    @classmethod
    def build_object(cls, **kwargs):
        """
        Creates the object of the subclass of the Report class from the parameters so that it can be used in the
        analysis. Depending on the type of the report, the parameters provided here will be provided in parsing time,
        while the other necessary parameters (e.g., subset of the data from which the report should be created) will
        be provided at runtime. For more details, see specific direct subclasses of this class, describing different
        types of reports.

        Args:
            **kwargs: keyword arguments that will be provided by users in the specification (if immuneML is used as a
                command line tool) or in the dictionary when calling the method from the code, and which should be
                used to create the report object

        Returns:
            the object of the appropriate report class
        """
        return cls(**kwargs)

    def _compute_kl_divergence(self):
        """
        Builds the kmer-based KL evaluator for the original and generated datasets.

        Returns:
            KLEvaluator: the evaluator; call ``true_kl()`` / ``simulated_kl()`` on it to obtain the divergence values
        """
        return self._get_kmer_kl_evaluator()

    def _get_kmer_kl_evaluator(self):
        """
        Extracts k-mers from the configured sequence field of both datasets and wraps them in a KLEvaluator.

        The kmer model for each side is estimated with ``estimate_kmer_model`` using a prior (pseudo) count of 1.0.

        Returns:
            KLEvaluator: evaluator comparing original (first argument) against generated (second argument) kmers
        """
        # the field name depends only on region/sequence type, so resolve it once for both datasets
        field_name = bnp_util.get_sequence_field_name(self.region_type, self.sequence_type)
        original_kmers, generated_kmers = (bnp.sequence.get_kmers(getattr(dataset.data, field_name), self.k)
                                           for dataset in (self.original_dataset, self.generated_dataset))
        estimator = partial(estimate_kmer_model, prior_count=1.0)
        return KLEvaluator(original_kmers, generated_kmers, estimator, n_sequences=self.n_sequences)

    def _generate(self) -> ReportResult:
        """
        Writes the tables of worst-fitting sequences and the KL contribution plots, and reports both KL directions.

        Returns:
            ReportResult object which encapsulates all outputs (figure, table, and text files) so that they can be
            conveniently linked to in the final output of instructions
        """
        PathBuilder.build(self.result_path)
        evaluator = self._get_kmer_kl_evaluator()

        tables = [self._write_output_table(evaluator.get_worst_true_sequences(),
                                           self.result_path / "worst_true_sequences.tsv",
                                           name="Original sequences that don't fit the generated model"),
                  self._write_output_table(evaluator.get_worst_simulated_sequences(),
                                           self.result_path / "worst_simulated_sequences.tsv",
                                           name="Generated sequences that don't fit with the original model")]

        figures = [self._plot_simulated(evaluator=evaluator),
                   self._plot_original(evaluator=evaluator)]

        info_text = ("Estimated KL divergence between the kmer distributions in the original and generated datasets, "
                     "together with the sequences that contribute the most to the divergence.\n"
                     "KL(original || generated) = {:.2f}, KL(generated || original) = {:.2f}").format(
            evaluator.true_kl(), evaluator.simulated_kl())

        return ReportResult(name=self.name,
                            info=info_text,
                            output_figures=figures,
                            output_tables=[table for table in tables if table is not None])

    def _write_plot(self, figure, file_name: str, output_name: str) -> ReportOutput:
        """
        Writes ``figure`` as HTML into ``result_path`` and wraps the file in a ReportOutput.

        ``figure`` is whatever the evaluator's plot methods return; it only needs a ``write_html`` method
        (presumably a plotly Figure — TODO confirm in KLEvaluator).
        """
        file_path = self.result_path / file_name
        figure.write_html(str(file_path))
        return ReportOutput(path=file_path, name=output_name)

    def _plot_simulated(self, evaluator: KLEvaluator):
        """Saves the plot of generated sequences contributing most to the KL divergence."""
        return self._write_plot(evaluator.simulated_plot(), "bad_simulated_sequences.html",
                                "Generated sequences that contribute most to the KL divergence")

    def _plot_original(self, evaluator: KLEvaluator):
        """Saves the plot of original sequences contributing most to the KL divergence."""
        return self._write_plot(evaluator.original_plot(), "bad_original_sequences.html",
                                "Original sequences that contribute most to the KL divergence")