Source code for immuneML.reports.ml_reports.BinaryFeaturePrecisionRecall

import logging
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px

from sklearn.metrics import precision_score, recall_score, accuracy_score, balanced_accuracy_score

from immuneML.ml_methods.util.Util import Util
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.encodings.motif_encoding.MotifEncoder import MotifEncoder
from immuneML.hyperparameter_optimization.HPSetting import HPSetting
from immuneML.ml_methods.classifiers.BinaryFeatureClassifier import BinaryFeatureClassifier
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.ml_reports.MLReport import MLReport
from immuneML.util.PathBuilder import PathBuilder




class BinaryFeaturePrecisionRecall(MLReport):
    """
    Plots the precision and recall scores for each feature added to the collection of features selected by the
    BinaryFeatureClassifier.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_report: BinaryFeaturePrecisionRecall

    """
    @classmethod
    def build_object(cls, **kwargs):
        return BinaryFeaturePrecisionRecall(**kwargs)
    def __init__(self, train_dataset: Dataset = None, test_dataset: Dataset = None,
                 method: BinaryFeatureClassifier = None, result_path: Path = None, name: str = None,
                 hp_setting: HPSetting = None, label=None, number_of_processes: int = 1):
        super().__init__(train_dataset=train_dataset, test_dataset=test_dataset, method=method,
                         result_path=result_path, name=name, hp_setting=hp_setting, label=label,
                         number_of_processes=number_of_processes)

    def _generate(self):
        PathBuilder.build(self.result_path)

        encoded_train_data, encoded_val_data = self._split_train_val_data(self.train_dataset.encoded_data)
        encoded_test_data = self.test_dataset.encoded_data

        plotting_data_train = self._compute_plotting_data(encoded_train_data)
        plotting_data_val = self._compute_plotting_data(encoded_val_data)
        plotting_data_test = self._compute_plotting_data(encoded_test_data)

        train_table = self._write_plotting_data(plotting_data_train, dataset_type="training")
        val_table = self._write_plotting_data(plotting_data_val, dataset_type="validation")
        test_table = self._write_plotting_data(plotting_data_test, dataset_type="test")

        train_fig = self._safe_plot(plotting_data=plotting_data_train, dataset_type="training")
        val_fig = self._safe_plot(plotting_data=plotting_data_val, dataset_type="validation")
        test_fig = self._safe_plot(plotting_data=plotting_data_test, dataset_type="test")

        return ReportResult(self.name,
                            info="Precision and recall scores for each subset of learned binary motifs",
                            output_tables=[table for table in [train_table, val_table, test_table] if table is not None],
                            output_figures=[fig for fig in [train_fig, val_fig, test_fig] if fig is not None])

    def _split_train_val_data(self, encoded_train_val_data):
        if self.method.train_indices and self.method.val_indices:
            encoded_train_data = Util.subset_encoded_data(encoded_train_val_data, self.method.train_indices)
            encoded_val_data = Util.subset_encoded_data(encoded_train_val_data, self.method.val_indices)
        else:
            encoded_train_data = encoded_train_val_data
            encoded_val_data = None

        return encoded_train_data, encoded_val_data

    def _compute_plotting_data(self, encoded_data):
        if encoded_data is None:
            return None

        rule_tree_indices = self.method.rule_tree_indices
        data = {"n_rules": [], "precision": [], "recall": [], "accuracy": [], "balanced_accuracy": []}

        y_true_bool = np.array([cls == self.label.positive_class for cls in encoded_data.labels[self.label.name]])

        if self.method.keep_all:
            rules_range = range(len(rule_tree_indices), len(rule_tree_indices) + 1)
        else:
            rules_range = range(1, len(rule_tree_indices) + 1)

        for n_rules in rules_range:
            rule_subtree = rule_tree_indices[:n_rules]
            y_pred_bool = self.method._get_rule_tree_predictions_bool(encoded_data, rule_subtree)

            data["n_rules"].append(n_rules)
            data["precision"].append(precision_score(y_true_bool, y_pred_bool))
            data["recall"].append(recall_score(y_true_bool, y_pred_bool))
            data["accuracy"].append(accuracy_score(y_true_bool, y_pred_bool))
            data["balanced_accuracy"].append(balanced_accuracy_score(y_true_bool, y_pred_bool))

        return pd.DataFrame(data)

    def _write_plotting_data(self, plotting_data, dataset_type):
        if plotting_data is not None:
            return self._write_output_table(plotting_data, self.result_path / f"{dataset_type}_performance.tsv",
                                            name=f"{dataset_type.title()} set performance of every subset of binary features")

    def _plot(self, plotting_data, dataset_type):
        fig = px.line(plotting_data, x="recall", y="precision", range_x=[0, 1.01], range_y=[0, 1.01],
                      template="plotly_white", hover_data=["n_rules"],
                      color_discrete_sequence=px.colors.diverging.Tealrose, markers=True)
        fig.update_traces(marker={'size': 4})

        file_path = self.result_path / f"{dataset_type}_precision_recall.html"
        fig.write_html(str(file_path))

        return ReportOutput(path=file_path,
                            name=f"Precision and recall scores on the {dataset_type} set for motif subsets")
    def check_prerequisites(self):
        location = BinaryFeaturePrecisionRecall.__name__
        run_report = True

        if not isinstance(self.method, BinaryFeatureClassifier):
            logging.warning(f"{location} report can only be created for {BinaryFeatureClassifier.__name__}, but got "
                            f"{type(self.method).__name__} instead. {location} report will not be created.")
            run_report = False

        if self.train_dataset.encoded_data is None or self.train_dataset.encoded_data.examples is None \
                or self.train_dataset.encoded_data.feature_names is None \
                or self.train_dataset.encoded_data.encoding != MotifEncoder.__name__:
            logging.warning(f"{location}: this report can only be created for a dataset encoded with the "
                            f"{MotifEncoder.__name__}. Report {self.name} will not be created.")
            run_report = False

        if hasattr(self.method, "keep_all") and self.method.keep_all:
            logging.warning(f"{location}: keep_all was set to True for ML method {self.method.name}, "
                            f"only one data point will be plotted.")

        return run_report
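
The per-subset scoring performed in _compute_plotting_data can be illustrated outside of immuneML. The following sketch is not part of the module above; it uses hypothetical toy data and assumes, for illustration only, that a rule subset predicts the positive class whenever any of its binary features is present (a stand-in for the classifier's _get_rule_tree_predictions_bool).

import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical toy data: 6 examples x 3 binary motif features, plus the true class per example.
examples = np.array([[1, 0, 0],
                     [0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0],
                     [1, 1, 0],
                     [0, 0, 0]], dtype=bool)
y_true = np.array([True, True, False, False, True, False])

# Hypothetical order in which features were added (analogue of rule_tree_indices).
rule_order = [0, 1, 2]

# Score each growing prefix of the rule list, as the report does for each n_rules.
for n_rules in range(1, len(rule_order) + 1):
    subset = rule_order[:n_rules]
    y_pred = examples[:, subset].any(axis=1)  # assumed OR-combination of the selected features
    print(n_rules, precision_score(y_true, y_pred), recall_score(y_true, y_pred))

Each printed row corresponds to one point on the precision-recall curve that the report draws per dataset type, with n_rules available as hover data in the resulting plot.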