from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from immuneML.data_model.dataset.Dataset import Dataset
from immuneML.hyperparameter_optimization.HPSetting import HPSetting
from immuneML.ml_methods.MLMethod import MLMethod
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.ml_reports.MLReport import MLReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
class ConfounderAnalysis(MLReport):
    """
    A report that plots the numbers of false positives and false negatives with respect to each value of
    the metadata features specified by the user. This allows checking whether a given machine learning model makes more
    misclassifications for some values of a metadata feature than for the others.

    Arguments:

        metadata_labels (list): A list of the metadata features to use as a basis for the calculations

    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        my_confounder_report:
            ConfounderAnalysis:
                metadata_labels:
                    - age
                    - sex

    """

    @classmethod
    def build_object(cls, **kwargs):
        """
        Validate the user-supplied parameters and build the report.

        Exactly the keys 'metadata_labels' (a list of strings) and 'name' (a string) are expected;
        validation is delegated to ParameterValidator.
        """
        location = ConfounderAnalysis.__name__

        ParameterValidator.assert_keys(kwargs.keys(), ['metadata_labels', 'name'], location, location)
        ParameterValidator.assert_type_and_value(kwargs['metadata_labels'], list, location, 'metadata_labels')
        ParameterValidator.assert_all_type_and_value(kwargs['metadata_labels'], str, location, 'metadata_labels')
        ParameterValidator.assert_type_and_value(kwargs['name'], str, location, 'name')

        return ConfounderAnalysis(metadata_labels=kwargs['metadata_labels'], name=kwargs['name'])

    def __init__(self, metadata_labels: List[str], train_dataset: Dataset = None, test_dataset: Dataset = None,
                 method: MLMethod = None,
                 result_path: Path = None, name: str = None, hp_setting: HPSetting = None, label=None):
        super().__init__(train_dataset, test_dataset, method, result_path, name, hp_setting, label)
        # names of the metadata features for which misclassifications are counted
        self.metadata_labels = metadata_labels

    def _generate(self) -> ReportResult:
        """
        Predict on the test dataset, count FP/FN per metadata level and export the results.

        One bar chart per (metadata feature, metric) pair is placed on a shared figure written to
        'plots.html'; one CSV file per metadata feature is written next to it.

        Returns:
            ReportResult holding the figure and every CSV table that was written.
        """
        PathBuilder.build(self.result_path)

        # make predictions
        predictions = self.method.predict(self.test_dataset.encoded_data, self.label)[self.label]  # label = disease
        true_labels = self.test_dataset.get_metadata(self.metadata_labels + [self.label])

        metrics = ["FP", "FN"]
        # grid layout: one row per metadata feature, one column per metric
        plot = make_subplots(rows=len(self.metadata_labels), cols=len(metrics))
        tables = []

        for row, meta_label in enumerate(self.metadata_labels, start=1):
            table = {}
            levels = None

            for col, metric in enumerate(metrics, start=1):
                counts = self._metrics(metric=metric, label=self.label, meta_label=meta_label,
                                       predictions=predictions, true_labels=true_labels)
                table[metric] = counts[metric]
                # the levels do not depend on the metric, so any iteration's copy can be exported
                levels = counts[meta_label]

                plot.add_trace(go.Bar(x=counts[meta_label], y=counts[metric]), row=row, col=col)
                # metadata levels are treated as categories even when they look numeric
                plot.update_xaxes(title_text=meta_label, row=row, col=col, type='category')
                # counts are whole numbers, hence integer ticks starting at zero
                plot.update_yaxes(title_text=metric, row=row, col=col, rangemode="nonnegative", tick0=0, dtick=1)

            table[meta_label] = levels
            tables.append(pd.DataFrame(table))

        plot.update_traces(marker_color=px.colors.sequential.Teal[3], showlegend=False)

        filename = self.result_path / "plots.html"
        plot.write_html(str(filename))

        table_paths = self._write_results_table(tables, self.metadata_labels)

        # report every table that was written, not only the one for the first metadata feature
        return ReportResult(name=self.name,
                            output_figures=[ReportOutput(filename)],
                            output_tables=[ReportOutput(path) for path in table_paths])

    def _write_results_table(self, plotting_data, labels) -> List[Path]:
        """
        Write one CSV file per metadata feature into the result directory.

        Arguments:
            plotting_data: data frames to export, in the same order as `labels`
            labels: metadata feature names; each file is named '<label>.csv'

        Returns:
            The paths of the written files, in the order of `labels`.
        """
        filepaths = []

        for table, label in zip(plotting_data, labels):
            filepath = self.result_path / f"{label}.csv"
            table.to_csv(filepath, index=False)
            filepaths.append(filepath)

        return filepaths

    @staticmethod
    def _metrics(metric, label, meta_label, predictions, true_labels):
        """
        Count the misclassifications of one kind at every level of a metadata feature.

        Arguments:
            metric: "FP" counts samples whose prediction compares greater than the true label;
                any other value counts samples whose prediction compares less than it
            label: key of the true class labels in `true_labels`
            meta_label: key of the metadata feature in `true_labels`
            predictions: predicted classes, aligned by position with the entries of `true_labels`
            true_labels: mapping from label / metadata feature name to the per-sample values

        Returns:
            A data frame with a column named `metric` holding the counts and a column named
            `meta_label` holding the unique levels of the metadata feature (as ordered by np.unique).
        """
        # NOTE(review): the greater/less comparison assumes the classes are encoded so that the positive
        # class compares greater than the negative one (e.g. 1/0 or True/False) - confirm against callers
        predicted = np.asarray(predictions)
        actual = np.asarray(true_labels[label])

        # indices of samples at which misclassification occurred
        compare = np.greater if metric == "FP" else np.less
        misclassified_inds = np.nonzero(compare(predicted, actual))[0].tolist()

        # metadata values of the misclassified samples
        metadata_values = np.asarray(true_labels[meta_label])
        misclassified_values = metadata_values[misclassified_inds]

        # number of metric occurrences at each metadata level
        unique_levels = np.unique(metadata_values)
        metric_vals = [np.count_nonzero(misclassified_values == level) for level in unique_levels]

        return pd.DataFrame({metric: metric_vals, meta_label: unique_levels})