# Source code for immuneML.reports.ml_reports.ConfusionMatrix

import logging
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import yaml
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix

from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.ml_methods.classifiers.MLMethod import MLMethod
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.ml_reports.MLReport import MLReport
from immuneML.util.PathBuilder import PathBuilder


class ConfusionMatrix(MLReport):
    """
    A report that plots the confusion matrix for a trained ML method. Supports both binary and multiclass
    classification.

    **Specification arguments:**

    - alternative_label (str): optionally, the confusion matrix can be split between different values of an
      alternative label. This may be useful to compare performance across different data subsets (e.g., batches,
      sources). If specified, separate confusion matrices will be generated for each value of the alternative
      label. Default is None.

    **YAML specification:**

    .. code-block:: yaml

        definitions:
            reports:
                my_conf_mat_report: ConfusionMatrix

    """

    @classmethod
    def build_object(cls, **kwargs):
        """Build the report from keyword arguments (as parsed from the YAML specification)."""
        return cls(**kwargs)

    def __init__(self, train_dataset: Dataset = None, test_dataset: Dataset = None, method: MLMethod = None,
                 result_path: Path = None, name: str = None, hp_setting=None, label=None,
                 number_of_processes: int = 1, alternative_label: str = None):
        super().__init__(train_dataset=train_dataset, test_dataset=test_dataset, method=method,
                         result_path=result_path, name=name, hp_setting=hp_setting, label=label,
                         number_of_processes=number_of_processes)
        # name of a metadata field used to split the test set into per-value confusion matrices (optional)
        self.alternative_label = alternative_label

    def _generate(self) -> ReportResult:
        """
        Compute, plot and store the confusion matrix of the method on the test dataset.

        Writes ``confusion_matrix.html`` (or whatever path PlotlyUtil reports back), ``confusion_matrix.csv`` and,
        if hp_setting is set, ``settings.yaml`` into result_path. If alternative_label is set, additionally writes
        one CSV per alternative-label value plus a combined subplot figure; failures in that optional part are
        logged as a warning and do not prevent the main report from being returned.

        Returns:
            ReportResult with the confusion matrix figure(s) and CSV table(s).
        """
        PathBuilder.build(self.result_path)

        y_true = np.asarray(self.test_dataset.encoded_data.labels[self.label.name])
        # predict's result is indexed by label name; coerce to an array so boolean masks can be applied later
        y_pred = np.asarray(self.method.predict(self.test_dataset.encoded_data, self.label)[self.label.name])

        cm_df = self._compute_confusion_matrix(y_true, y_pred, self.label.values)

        heatmap_path = self._plot_confusion_matrix(cm_df)
        table_path = self._write_table(cm_df, self.result_path / "confusion_matrix.csv")

        output_figures = [ReportOutput(heatmap_path)]
        output_tables = [ReportOutput(table_path, "Confusion matrix CSV")]

        self._write_settings()

        if self.alternative_label is not None:
            alt_tables = []
            try:
                alt_label_filename = self._generate_for_alternative_label(y_true, y_pred, alt_tables)
                output_figures.append(ReportOutput(alt_label_filename,
                                                   f"Confusion matrix split by {self.alternative_label}"))
                # only surface the per-group tables when the whole alternative-label part succeeded
                output_tables.extend(alt_tables)
            except Exception as e:
                # deliberate best-effort: the split view is optional, the main confusion matrix is still reported
                logging.warning(f"Could not generate confusion matrix for alternative label "
                                f"'{self.alternative_label}': {e}")

        return ReportResult(self.name, info=f"Confusion matrix for {type(self.method).__name__}",
                            output_tables=output_tables, output_figures=output_figures)

    def _generate_for_alternative_label(self, y_true, y_pred, output_tables: list = None) -> Path:
        """
        Plot one confusion matrix per value of ``self.alternative_label`` in a two-column subplot grid.

        Args:
            y_true: true classes for the test examples, in the same order as the dataset metadata
                (assumed; both come from self.test_dataset).
            y_pred: predicted classes, aligned with y_true.
            output_tables: optional list; if given, a ReportOutput for every per-value CSV is appended to it.

        Returns:
            Path of the written figure, as reported by PlotlyUtil.write_image_to_file.

        Raises:
            ValueError: if the alternative label has no non-missing values.
        """
        alt_labels = self.test_dataset.get_metadata([self.alternative_label], return_df=True)
        alt_label_column = alt_labels[self.alternative_label]
        # drop missing values: NaN never compares equal, so it would only yield an empty, misleading panel,
        # and None mixed with other types makes sorted() fail
        alt_label_values = sorted(alt_label_column.dropna().unique())
        if not alt_label_values:
            raise ValueError(f"no non-missing values found for alternative label '{self.alternative_label}'")

        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        n_rows = (len(alt_label_values) + 1) // 2
        fig = make_subplots(cols=2, rows=n_rows, subplot_titles=[str(value) for value in alt_label_values],
                            vertical_spacing=self._vertical_spacing(n_rows), shared_xaxes=True,
                            shared_yaxes=True, x_title="Predicted Label", y_title='True Label')

        labels = self.label.values
        for subplot_index, alt_label_value in enumerate(alt_label_values):
            mask = (alt_label_column == alt_label_value).values.astype(bool)
            cm_df = self._compute_confusion_matrix(y_true[mask], y_pred[mask], labels)

            fig.add_trace(self._make_heatmap(cm_df), row=subplot_index // 2 + 1, col=subplot_index % 2 + 1)

            table_path = self._write_table(cm_df, self.result_path / f"confusion_matrix_{alt_label_value}.csv")
            if output_tables is not None:
                output_tables.append(ReportOutput(table_path, f"Confusion matrix CSV for "
                                                              f"{self.alternative_label} = {alt_label_value}"))

        fig.update_layout(title_text=f"Confusion matrix across {self.alternative_label} values",
                          template="plotly_white")
        filename = self.result_path / f"confusion_matrix_{self.alternative_label}.html"
        # mirror _plot_confusion_matrix: rely on the path returned by write_image_to_file instead of assuming
        # the requested one was written
        return PlotlyUtil.write_image_to_file(fig, filename, self.test_dataset.get_example_count())

    def _plot_confusion_matrix(self, cm_df: pd.DataFrame) -> Path:
        """Plot a single confusion matrix heatmap and return the path PlotlyUtil wrote it to."""
        fig = go.Figure(self._make_heatmap(cm_df))
        fig.update_layout(title_text="Confusion Matrix", xaxis_title="Predicted class", yaxis_title="True class",
                          template="plotly_white")
        filename = self.result_path / "confusion_matrix.html"
        return PlotlyUtil.write_image_to_file(fig, filename)

    @staticmethod
    def _compute_confusion_matrix(y_true, y_pred, labels) -> pd.DataFrame:
        """Return the confusion matrix with true classes as rows and predicted classes as columns."""
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        return pd.DataFrame(cm, index=labels, columns=labels)

    @staticmethod
    def _write_table(cm_df: pd.DataFrame, table_path: Path) -> Path:
        """Write a confusion matrix DataFrame to CSV with a 'true/predicted' corner header; return table_path."""
        cm_df.rename_axis('true/predicted').reset_index().to_csv(table_path, index=False)
        return table_path

    @staticmethod
    def _make_heatmap(cm_df: pd.DataFrame) -> go.Heatmap:
        """Build the annotated heatmap trace shared by the single and the split confusion matrix plots."""
        # heatmap rows (z[i]) are drawn at y[i] and columns at x[j]: rows of cm_df are true classes (y axis),
        # columns are predicted classes (x axis)
        return go.Heatmap(z=cm_df.values, texttemplate="%{text}", text=cm_df.values, colorscale='Viridis',
                          hovertemplate="True value: %{y}<br>Predicted value: %{x}"
                                        "<br>Count: %{z}<extra></extra>",
                          showscale=False,
                          x=[str(lbl) for lbl in cm_df.columns.tolist()],
                          y=[str(lbl) for lbl in cm_df.index.tolist()])

    @staticmethod
    def _vertical_spacing(n_rows: int, default: float = 0.1) -> float:
        """
        Return a subplot vertical spacing that plotly's make_subplots accepts for n_rows rows.

        make_subplots rejects vertical_spacing > 1 / (rows - 1); the spacing is additionally capped so that the
        gaps take at most half of the figure height. For up to 6 rows this returns ``default`` unchanged.
        """
        if n_rows <= 1:
            return default
        return min(default, 0.5 / (n_rows - 1))

    def _write_settings(self):
        """Store the preprocessing/encoder/ML method names of the hyperparameter setting, if one is set."""
        if self.hp_setting is None:
            return
        file_path = self.result_path / "settings.yaml"
        with file_path.open("w") as file:
            yaml.dump({"preprocessing": self.hp_setting.preproc_sequence_name,
                       "encoder": self.hp_setting.encoder_name,
                       "ml_method": self.hp_setting.ml_method_name}, file)

    def check_prerequisites(self):
        """Return True if a test dataset and a label are available; otherwise log a warning and return False."""
        if self.test_dataset is None or self.label is None:
            logging.warning("ConfusionMatrixReport requires a test dataset and a specified label.")
            return False
        return True