Source code for immuneML.reports.encoding_reports.MotifTestSetPerformance
import logging
import shutil
from pathlib import Path
from typing import Type
import numpy as np
from immuneML.IO.dataset_import.DataImport import DataImport
from immuneML.IO.dataset_import.DatasetImportParams import DatasetImportParams
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.dsl.DefaultParamsLoader import DefaultParamsLoader
from immuneML.dsl.import_parsers.ImportParser import ImportParser
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.encodings.motif_encoding.MotifEncoder import MotifEncoder
from immuneML.encodings.motif_encoding.PositionalMotifHelper import PositionalMotifHelper
from immuneML.environment.Label import Label
from immuneML.environment.LabelConfiguration import LabelConfiguration
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.encoding_reports.EncodingReport import EncodingReport
from immuneML.util.MotifPerformancePlotHelper import MotifPerformancePlotHelper
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.ReflectionHandler import ReflectionHandler
class MotifTestSetPerformance(EncodingReport):
"""
    This report can be used to show the performance of a learned set of motifs using the :py:obj:`~immuneML.encodings.motif_encoding.MotifEncoder.MotifEncoder`
    on an independent test set of unseen data.

    It is recommended to first run the report :py:obj:`~immuneML.reports.data_reports.MotifGeneralizationAnalysis.MotifGeneralizationAnalysis`
    in order to calibrate the optimal recall thresholds and plot the performance of motifs on the training and validation sets.

    **Specification arguments:**

    - test_dataset (dict): parameters for importing a SequenceDataset to use as an independent test set. By default,
      the import parameters 'is_repertoire' and 'paired' will be set to False to ensure a SequenceDataset is imported.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_motif_report:
                    MotifTestSetPerformance:
                        test_dataset:
                            format: AIRR # choose any valid import format
                            params:
                                path: path/to/files/
                                is_repertoire: False # is_repertoire must be False to import a SequenceDataset
                                paired: False # paired must be False to import a SequenceDataset
                                # optional other parameters...
"""
def __init__(self, dataset: Dataset = None, result_path: Path = None,
test_dataset_import_cls: Type[DataImport] = None,
test_dataset_import_params: DatasetImportParams = None,
training_set_name: str = None, test_set_name: str = None,
split_by_motif_size: bool = None,
highlight_motifs_path: str = None, highlight_motifs_name: str = None,
min_points_in_window: int = None,
smoothing_constant1: float = None, smoothing_constant2: float = None,
keep_test_dataset: bool = None,
number_of_processes: int = 1, name: str = None):
super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name)
self.test_dataset_import_cls = test_dataset_import_cls
self.test_dataset_import_params = test_dataset_import_params
self.keep_test_dataset = keep_test_dataset
self.split_by_motif_size = split_by_motif_size
self.training_set_name = training_set_name
self.test_set_name = test_set_name
self.highlight_motifs_path = highlight_motifs_path
self.highlight_motifs_name = highlight_motifs_name
self.min_points_in_window = min_points_in_window
self.smoothing_constant1 = smoothing_constant1
self.smoothing_constant2 = smoothing_constant2
@classmethod
def build_object(cls, **kwargs):
location = MotifTestSetPerformance.__name__
import_cls, test_dataset_import_params = MotifTestSetPerformance._parse_dataset_params(kwargs)
kwargs["test_dataset_import_cls"] = import_cls
kwargs["test_dataset_import_params"] = test_dataset_import_params
del kwargs["test_dataset"]
ParameterValidator.assert_type_and_value(kwargs["split_by_motif_size"], bool, location, "split_by_motif_size")
ParameterValidator.assert_type_and_value(kwargs["training_set_name"], str, location, "training_set_name")
ParameterValidator.assert_type_and_value(kwargs["test_set_name"], str, location, "test_set_name")
ParameterValidator.assert_type_and_value(kwargs["keep_test_dataset"], bool, location, "keep_test_dataset")
ParameterValidator.assert_type_and_value(kwargs["min_points_in_window"], int, location, "min_points_in_window",
min_inclusive=1)
ParameterValidator.assert_type_and_value(kwargs["smoothing_constant1"], (int, float), location,
"smoothing_constant1", min_exclusive=0)
ParameterValidator.assert_type_and_value(kwargs["smoothing_constant2"], (int, float), location,
"smoothing_constant2", min_exclusive=0)
if "highlight_motifs_path" in kwargs and kwargs["highlight_motifs_path"] is not None:
PositionalMotifHelper.check_motif_filepath(kwargs["highlight_motifs_path"], location,
"highlight_motifs_path")
ParameterValidator.assert_type_and_value(kwargs["highlight_motifs_name"], str, location,
"highlight_motifs_name")
return MotifTestSetPerformance(**kwargs)
@staticmethod
def _parse_dataset_params(kwargs):
location = MotifTestSetPerformance.__name__
ParameterValidator.assert_type_and_value(kwargs["test_dataset"], dict, location, "test_dataset")
ParameterValidator.assert_keys_present(kwargs["test_dataset"].keys(), ["format", "params"], location,
"test_dataset")
ParameterValidator.assert_type_and_value(kwargs["test_dataset"]["format"], str, location, "test_dataset/format")
import_cls = ReflectionHandler.get_class_by_name("{}Import".format(kwargs["test_dataset"]["format"]))
default_params = DefaultParamsLoader.load(ImportParser.keyword, kwargs["test_dataset"]["format"])
params_dict = {**default_params, **kwargs["test_dataset"]["params"]}
test_dataset_import_params = DatasetImportParams.build_object(**params_dict)
if test_dataset_import_params.is_repertoire:
logging.warning(f"{location}: This report only allows the reference dataset to be of type SequenceDataset. "
"Setting 'test_dataset/params/is_repertoire' to False...")
test_dataset_import_params.is_repertoire = False
if test_dataset_import_params.paired:
logging.warning(f"{location}: This report only allows the reference dataset to be of type SequenceDataset. "
"Setting 'test_dataset/params/paired' to False...")
test_dataset_import_params.paired = False
return import_cls, test_dataset_import_params
def _generate(self) -> ReportResult:
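        # import the independent test set, encode it with the motifs learned on the training data,
        # and compare motif performance (precision/recall and combined precision per TP count)
        # between the training and test sets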
test_dataset = self._get_test_dataset()
test_encoded_data = self._encode_test_data(test_dataset)
training_plotting_data, test_plotting_data = MotifPerformancePlotHelper.get_plotting_data(
self.dataset.encoded_data,
test_encoded_data.encoded_data,
self.highlight_motifs_path,
self.highlight_motifs_name)
training_plotting_data["motif_size"] = training_plotting_data["feature_names"].apply(
PositionalMotifHelper.get_motif_size)
test_plotting_data["motif_size"] = test_plotting_data["feature_names"].apply(
PositionalMotifHelper.get_motif_size)
output_tables, output_plots = self._get_report_outputs(training_plotting_data, test_plotting_data)
if not self.keep_test_dataset:
shutil.rmtree(self.test_dataset_import_params.result_path)
return ReportResult(name=self.name,
info="Performance of motifs on an independent test set",
output_figures=output_plots,
output_tables=output_tables)
def _get_report_outputs(self, training_plotting_data, test_plotting_data):
if self.split_by_motif_size:
return self._construct_and_plot_data_per_motif_size(training_plotting_data, test_plotting_data)
else:
return self._construct_and_plot_data(training_plotting_data, test_plotting_data)
def _construct_and_plot_data_per_motif_size(self, training_plotting_data, test_plotting_data):
output_tables, output_plots = [], []
for motif_size in sorted(set(training_plotting_data["motif_size"])):
sub_training_plotting_data = training_plotting_data[training_plotting_data["motif_size"] == motif_size]
sub_test_plotting_data = test_plotting_data[test_plotting_data["motif_size"] == motif_size]
sub_output_tables, sub_output_plots = self._construct_and_plot_data(sub_training_plotting_data,
sub_test_plotting_data,
motif_size=motif_size)
output_tables.extend(sub_output_tables)
output_plots.extend(sub_output_plots)
return output_tables, output_plots
def _construct_and_plot_data(self, training_plotting_data, test_plotting_data, motif_size=None):
training_combined_precision = self._get_combined_precision(training_plotting_data)
test_combined_precision = self._get_combined_precision(test_plotting_data)
motif_size_suffix = f"_motif_size={motif_size}" if motif_size is not None else ""
motifs_name = f"motifs of length {motif_size}" if motif_size is not None else "motifs"
output_tables = MotifPerformancePlotHelper.write_output_tables(self, training_plotting_data, test_plotting_data,
training_combined_precision,
test_combined_precision, motifs_name=motifs_name,
file_suffix=motif_size_suffix)
output_plots = MotifPerformancePlotHelper.write_plots(self, training_plotting_data, test_plotting_data,
training_combined_precision, test_combined_precision,
training_tp_cutoff="auto", test_tp_cutoff="auto",
motifs_name=motifs_name, file_suffix=motif_size_suffix)
return output_tables, output_plots
def _get_combined_precision(self, plotting_data):
return MotifPerformancePlotHelper.get_combined_precision(plotting_data,
min_points_in_window=self.min_points_in_window,
smoothing_constant1=self.smoothing_constant1,
smoothing_constant2=self.smoothing_constant2)
def _write_output_tables(self, training_plotting_data, test_plotting_data, training_combined_precision,
test_combined_precision, file_suffix=""):
results_table_name = "Confusion matrix and precision/recall scores for significant motifs on the {}"
combined_precision_table_name = "Combined precision scores of motifs on the {} for each TP value on the " + str(
self.training_set_name)
train_results_table = self._write_output_table(training_plotting_data,
self.result_path / f"training_set_scores{file_suffix}.csv",
results_table_name.format(self.training_set_name))
test_results_table = self._write_output_table(test_plotting_data,
self.result_path / f"test_set_scores{file_suffix}.csv",
results_table_name.format(self.test_set_name))
training_combined_precision_table = self._write_output_table(training_combined_precision,
self.result_path / f"training_combined_precision{file_suffix}.csv",
combined_precision_table_name.format(
self.training_set_name))
test_combined_precision_table = self._write_output_table(test_combined_precision,
self.result_path / f"test_combined_precision{file_suffix}.csv",
combined_precision_table_name.format(
self.test_set_name))
return [table for table in [train_results_table, test_results_table, training_combined_precision_table,
test_combined_precision_table] if table is not None]
def _plot_precision_per_tp(self, file_path, plotting_data, combined_precision, dataset_type, tp_cutoff,
motifs_name="motifs"):
return MotifPerformancePlotHelper.plot_precision_per_tp(file_path, plotting_data, combined_precision,
dataset_type,
training_set_name=self.training_set_name,
motifs_name=motifs_name,
tp_cutoff=tp_cutoff,
highlight_motifs_name=self.highlight_motifs_name)
def _plot_precision_recall(self, file_path, plotting_data, min_recall=None, min_precision=None, dataset_type=None,
motifs_name="motifs"):
return MotifPerformancePlotHelper.plot_precision_recall(file_path, plotting_data, min_recall=min_recall,
min_precision=min_precision,
dataset_type=dataset_type, motifs_name=motifs_name,
highlight_motifs_name=self.highlight_motifs_name)
def _encode_test_data(self, test_dataset):
encoder = self._get_encoder()
params = EncoderParams(result_path=self.result_path / "encoded_test_dataset",
label_config=self._get_label_config(),
pool_size=self.number_of_processes,
learn_model=False)
return encoder.encode(test_dataset, params)
def _get_encoder(self):
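        # build a MotifEncoder that re-uses the motif file learned on the training data,
        # so no new motifs are learned on the test set (learn_model=False during encoding)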
encoder = MotifEncoder(label=self._get_label_name(),
name=f"motif_encoder_{self.name}")
encoder.learned_motif_filepath = self.dataset.encoded_data.info["learned_motif_filepath"]
return encoder
def _get_test_dataset_y_true(self, test_dataset):
label_name = self._get_label_name()
positive_class = self._get_positive_class()
y_true = [sequence.get_attribute(label_name) == positive_class for sequence in test_dataset.get_data()]
return np.array(y_true)
def _get_motifs(self):
motif_names = self._get_motif_names()
return [PositionalMotifHelper.string_to_motif(name, "&", "-") for name in motif_names]
def _get_motif_names(self):
return list(self.dataset.encoded_data.feature_annotations.feature_names)
def _get_test_dataset(self):
self._set_result_path()
test_dataset = self._import_test_dataset()
self._check_test_dataset(test_dataset)
return test_dataset
def _check_test_dataset(self, test_dataset):
self._check_sequence_length(test_dataset)
self._check_dataset_label(test_dataset)
def _check_sequence_length(self, test_dataset):
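        # positional motifs are defined on fixed sequence positions, so every test sequence
        # must have the same length as the sequences in the original (training) dataset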
legal_length = self._get_legal_sequence_length()
for sequence in test_dataset.get_data():
assert len(sequence.get_sequence()) == legal_length, \
(f"{MotifTestSetPerformance.__name__}: the length of the sequences in the test dataset is required "
f"to match the length of the original dataset ({legal_length}). Found sequence of length: "
f"{len(sequence.get_sequence())}")
def _get_legal_sequence_length(self):
return self.dataset.data.cdr3_aa.lengths[0]
def _check_dataset_label(self, test_dataset):
label_name = self._get_label_name()
label_values = set(self.dataset.encoded_data.labels[label_name])
assert label_name in list(test_dataset.dynamic_fields.keys()), \
f"{self.__class__.__name__}: Label {label_name} must be set for the test dataset."
test_dataset_label_values = set(test_dataset.get_metadata([label_name])[label_name])
assert label_values == test_dataset_label_values, \
(f"{self.__class__.__name__}: Label {label_name} with classes {', '.join(label_values)} must be set for "
f"the test dataset. Instead, it has the following classes: {', '.join(test_dataset_label_values)}.")
def _get_label_config(self):
return LabelConfiguration([self._get_label()])
def _get_label(self):
label_name = self._get_label_name()
label_values = list(set(self.dataset.encoded_data.labels[label_name]))
positive_class = self._get_positive_class()
return Label(name=label_name, values=label_values, positive_class=positive_class)
def _get_label_name(self):
return list(self.dataset.encoded_data.labels.keys())[0]
def _get_positive_class(self):
return self.dataset.encoded_data.info["positive_class"]
def _set_result_path(self):
self.test_dataset_import_params.result_path = self.result_path / f"test_dataset_{self.name}"
def _import_test_dataset(self):
return self.test_dataset_import_cls(self.test_dataset_import_params,
f"test_dataset_{self.name}").import_dataset()
def check_prerequisites(self) -> bool:
location = MotifTestSetPerformance.__name__
if self.dataset.encoded_data is None or self.dataset.encoded_data.info is None:
logging.warning(f"{location}: the dataset is not encoded, skipping this report...")
return False
elif self.dataset.encoded_data.encoding != MotifEncoder.__name__:
logging.warning(
f"{location}: the dataset encoding ({self.dataset.encoded_data.encoding}) "
f"does not match the required encoding ({MotifEncoder.__name__}), skipping this report...")
return False
elif self.dataset.encoded_data.feature_annotations is None:
logging.warning(f"{location}: missing feature annotations for {MotifEncoder.__name__},"
f"skipping this report...")
return False
else:
return True
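
# A minimal usage sketch (not part of the original module): it shows how this report could be
# built programmatically with the same parameters the YAML specification above provides.
# In a normal immuneML run, the DSL constructs the report and the instruction supplies the
# MotifEncoder-encoded dataset and the result path; the values below are placeholders.
def _example_build_and_run(encoded_dataset, result_path: Path):
    report = MotifTestSetPerformance.build_object(
        test_dataset={"format": "AIRR",  # any valid import format
                      "params": {"path": "path/to/test/files/",  # placeholder path
                                 "is_repertoire": False, "paired": False}},
        training_set_name="training set", test_set_name="test set",
        split_by_motif_size=False, keep_test_dataset=False,
        min_points_in_window=50, smoothing_constant1=5, smoothing_constant2=10,
        highlight_motifs_path=None, highlight_motifs_name="highlight",
        name="my_motif_report")
    report.dataset = encoded_dataset  # a SequenceDataset already encoded with MotifEncoder
    report.result_path = result_path
    # in a real run the framework drives the report; here its methods are called directly
    if report.check_prerequisites():
        return report._generate()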