Source code for immuneML.util.MotifPerformancePlotHelper

import logging

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.stats import lognorm

from immuneML.encodings.motif_encoding.PositionalMotifHelper import PositionalMotifHelper
from immuneML.reports.ReportOutput import ReportOutput


class MotifPerformancePlotHelper:
    @staticmethod
    def get_plotting_data(training_encoded_data, test_encoded_data, highlight_motifs_path=None, highlight_motifs_name="highlight"):
        training_feature_annotations = MotifPerformancePlotHelper._get_annotated_feature_annotations(training_encoded_data, highlight_motifs_path, highlight_motifs_name)
        test_feature_annotations = MotifPerformancePlotHelper._get_annotated_feature_annotations(test_encoded_data, highlight_motifs_path, highlight_motifs_name)

        training_feature_annotations["training_TP"] = training_feature_annotations["TP"]
        test_feature_annotations = MotifPerformancePlotHelper.merge_train_test_feature_annotations(training_feature_annotations, test_feature_annotations)

        return training_feature_annotations, test_feature_annotations
    @staticmethod
    def _get_annotated_feature_annotations(encoded_data, highlight_motifs_path, highlight_motifs_name):
        feature_annotations = encoded_data.feature_annotations.copy()
        MotifPerformancePlotHelper._annotate_confusion_matrix(feature_annotations)
        MotifPerformancePlotHelper._annotate_highlight(feature_annotations, highlight_motifs_path, highlight_motifs_name)

        return feature_annotations

    @staticmethod
    def _annotate_confusion_matrix(feature_annotations):
        feature_annotations["precision"] = feature_annotations.apply(
            lambda row: 0 if row["TP"] == 0 else row["TP"] / (row["TP"] + row["FP"]), axis="columns")

        feature_annotations["recall"] = feature_annotations.apply(
            lambda row: 0 if row["TP"] == 0 else row["TP"] / (row["TP"] + row["FN"]), axis="columns")

    @staticmethod
    def _annotate_highlight(feature_annotations, highlight_motifs_path, highlight_motifs_name):
        feature_annotations["highlight"] = MotifPerformancePlotHelper._get_highlight(feature_annotations, highlight_motifs_path, highlight_motifs_name)

    @staticmethod
    def _get_highlight(feature_annotations, highlight_motifs_path, highlight_motifs_name):
        if highlight_motifs_path is not None:
            highlight_motifs = PositionalMotifHelper.read_motifs_from_file(highlight_motifs_path)

            motifs = [PositionalMotifHelper.string_to_motif(motif, value_sep="&", motif_sep="-")
                      for motif in feature_annotations["feature_names"]]

            return [highlight_motifs_name if MotifPerformancePlotHelper._is_highlight_motif(motif, highlight_motifs) else "Motif"
                    for motif in motifs]
        else:
            return ["Motif"] * len(feature_annotations)

    @staticmethod
    def _is_highlight_motif(motif, highlight_motifs):
        for highlight_motif in highlight_motifs:
            if motif == highlight_motif:
                return True

            # a longer motif also counts as highlighted when it extends one of the highlight motifs
            if len(motif[0]) > len(highlight_motif[0]):
                if MotifPerformancePlotHelper.is_sub_motif(highlight_motif, motif):
                    return True

        return False
    @staticmethod
    def is_sub_motif(short_motif, long_motif):
        assert len(long_motif[0]) > len(short_motif[0])

        long_motif_dict = {long_motif[0][i]: long_motif[1][i] for i in range(len(long_motif[0]))}

        for idx, aa in zip(short_motif[0], short_motif[1]):
            if idx in long_motif_dict.keys():
                if long_motif_dict[idx] != aa:
                    return False
            else:
                return False

        return True
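    # Example (hypothetical motifs, not from the original source): motifs are
    # (indices, amino_acids) tuples. With short_motif = ([1, 3], ["A", "G"]) and
    # long_motif = ([1, 2, 3], ["A", "C", "G"]), every index/amino-acid pair of the
    # short motif occurs in the long motif, so is_sub_motif(short_motif, long_motif)
    # returns True; changing "G" to "Y" at index 3 in either motif would make it False.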
    @staticmethod
    def merge_train_test_feature_annotations(training_feature_annotations, test_feature_annotations):
        training_info_to_merge = training_feature_annotations[["feature_names", "training_TP"]].copy()
        test_info_to_merge = test_feature_annotations.copy()

        merged_train_test_info = training_info_to_merge.merge(test_info_to_merge)

        return merged_train_test_info
    @staticmethod
    def get_combined_precision(plotting_data, min_points_in_window, smoothing_constant1, smoothing_constant2):
        group_by_tp = plotting_data.groupby("training_TP")

        combined_precision = group_by_tp["TP"].sum() / (group_by_tp["TP"].sum() + group_by_tp["FP"].sum())

        df = pd.DataFrame({"training_TP": list(combined_precision.index),
                           "combined_precision": list(combined_precision)})

        df["smooth_combined_precision"] = MotifPerformancePlotHelper._smooth_combined_precision(
            list(combined_precision.index),
            list(combined_precision),
            list(group_by_tp["TP"].count()),
            min_points_in_window, smoothing_constant1, smoothing_constant2)

        return df
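    # Explanatory note (not in the original source): "combined precision" pools the
    # confusion-matrix counts of all motifs sharing the same training_TP value,
    # i.e. sum(TP) / (sum(TP) + sum(FP)) per group, rather than averaging the
    # per-motif precision values. Pooling weights each motif by its prediction
    # counts, so groups dominated by a few well-supported motifs are not distorted
    # by many low-count motifs.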
    @staticmethod
    def _smooth_combined_precision(x, y, weights, min_points_in_window, smoothing_constant1, smoothing_constant2):
        smoothed_y = []

        for i in range(len(x)):
            scale = MotifPerformancePlotHelper._get_lognorm_scale(x, i, weights, min_points_in_window, smoothing_constant1, smoothing_constant2)

            lognorm_for_this_x = lognorm.pdf(x, s=0.1, loc=x[i] - scale, scale=scale)

            smoothed_y.append(sum(lognorm_for_this_x * y) / sum(lognorm_for_this_x))

        return smoothed_y

    @staticmethod
    def _get_lognorm_scale(x, i, weights, min_points_in_window, smoothing_constant1, smoothing_constant2):
        window_size = MotifPerformancePlotHelper._determine_window_size(x, i, weights, min_points_in_window)
        return window_size * smoothing_constant1 + smoothing_constant2

    @staticmethod
    def _determine_window_size(x, i, weights, min_points_in_window):
        x_rng = 0
        n_data_points = weights[i]

        if sum(weights) < min_points_in_window:
            logging.warning(f"{MotifPerformancePlotHelper.__name__}: min_points_in_window ({min_points_in_window}) is larger than "
                            f"the total number of points in the plot ({sum(weights)}). "
                            f"Setting min_points_in_window to {sum(weights)} instead...")
            min_points_in_window = sum(weights)

        # grow the window symmetrically around x[i] until it covers enough data points
        while n_data_points < min_points_in_window:
            x_rng += 1

            to_select = [j for j in range(len(x)) if (x[i] - x_rng) <= x[j] <= (x[i] + x_rng)]
            lower_index = min(to_select)
            upper_index = max(to_select)

            n_data_points = sum(weights[lower_index:upper_index + 1])

        return x_rng
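    # Note on the smoothing above (explanatory, not in the original source): each
    # smoothed value is a weighted average of all combined-precision values, with
    # weights drawn from a lognormal kernel centered near x[i]. The kernel scale
    # grows with the local window size, so sparse regions of the TP axis (few
    # motifs per TP value) are smoothed more aggressively than dense regions.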
    @staticmethod
    def plot_precision_per_tp(file_path, plotting_data, combined_precision, dataset_type, training_set_name, tp_cutoff,
                              motifs_name="motifs", highlight_motifs_name="highlight"):
        # make 'base figure' with 1 point
        fig = px.scatter(plotting_data, y=[0], x=[0], range_y=[-0.01, 1.01], log_x=True, template="plotly_white")

        # hide 'base figure' point
        fig.update_traces(marker=dict(size=12, opacity=0), selector=dict(mode='markers'))

        # add data points (needs to be a separate trace to show up in the legend)
        fig.add_trace(go.Scatter(x=plotting_data["training_TP"], y=plotting_data["precision"], mode='markers',
                                 name="Motif precision", marker=dict(symbol="circle", color="#74C4C4")),
                      secondary_y=False)

        # add combined precision
        fig.add_trace(go.Scatter(x=combined_precision["training_TP"], y=combined_precision["combined_precision"],
                                 mode='markers+lines', name="Combined precision",
                                 marker=dict(symbol="diamond", color=px.colors.diverging.Tealrose[0])),
                      secondary_y=False)

        # add highlighted motifs
        plotting_data_highlight = plotting_data[plotting_data["highlight"] != "Motif"]
        if len(plotting_data_highlight) > 0:
            fig.add_trace(go.Scatter(x=plotting_data_highlight["training_TP"], y=plotting_data_highlight["precision"],
                                     mode='markers', name=f"{highlight_motifs_name} precision",
                                     marker=dict(symbol="circle", color="#F5C144")),
                          secondary_y=False)

        # add smoothed combined precision
        if "smooth_combined_precision" in combined_precision:
            fig.add_trace(go.Scatter(x=combined_precision["training_TP"], y=combined_precision["smooth_combined_precision"],
                                     marker=dict(color=px.colors.diverging.Tealrose[-1]), name="Combined precision, smoothed",
                                     mode="lines", line_shape='spline', line={'smoothing': 1.3}),
                          secondary_y=False)

        # add vertical TP cutoff line
        if tp_cutoff is not None:
            if tp_cutoff == "auto":
                tp_cutoff = min(plotting_data["training_TP"])
            fig.add_vline(x=tp_cutoff, line_dash="dash")

        tickvals = MotifPerformancePlotHelper._get_log_x_axis_ticks(plotting_data, tp_cutoff)

        fig.update_layout(xaxis=dict(tickvals=tickvals),
                          xaxis_title=f"True positive predictions ({training_set_name})",
                          yaxis_title=f"Precision ({dataset_type})",
                          showlegend=True)

        fig.write_html(str(file_path))

        return ReportOutput(
            path=file_path,
            name=f"Precision scores on the {dataset_type} for {motifs_name} found at each true positive count of the {training_set_name}.",
        )
    @staticmethod
    def _get_log_x_axis_ticks(plotting_data, tp_cutoff):
        ticks = []

        min_val, max_val = min(plotting_data["training_TP"]), max(plotting_data["training_TP"])

        # powers of 10 strictly between the minimal and maximal TP value
        i = 1
        while i < max_val:
            if i > min_val:
                ticks.append(i)
            i *= 10

        ticks.append(min_val)
        ticks.append(max_val)

        if tp_cutoff is not None:
            ticks.append(tp_cutoff)

        return sorted(ticks)
    @staticmethod
    def plot_precision_recall(file_path, plotting_data, min_recall=None, min_precision=None, dataset_type=None,
                              motifs_name="motifs", highlight_motifs_name="highlight"):
        fig = px.scatter(plotting_data,
                         y="precision", x="recall", hover_data=["feature_names"],
                         range_x=[0, 1.01], range_y=[0, 1.01], color="highlight",
                         color_discrete_map={"Motif": px.colors.qualitative.Pastel[0],
                                             highlight_motifs_name: px.colors.qualitative.Pastel[1]},
                         labels={
                             "precision": f"Precision ({dataset_type})",
                             "recall": f"Recall ({dataset_type})",
                             "feature_names": "Motif",
                         }, template="plotly_white")

        if min_precision is not None and min_precision > 0:
            fig.add_hline(y=min_precision, line_dash="dash")

        if min_recall is not None and min_recall > 0:
            fig.add_vline(x=min_recall, line_dash="dash")

        fig.write_html(str(file_path))

        return ReportOutput(
            path=file_path,
            name=f"Precision versus recall of significant {motifs_name} on the {dataset_type}",
        )
    @staticmethod
    def write_output_tables(report_obj, training_plotting_data, test_plotting_data, training_combined_precision,
                            test_combined_precision, motifs_name="motifs", file_suffix=""):
        results_table_name = f"Confusion matrix and precision/recall scores for significant {motifs_name}" + " on the {} set"
        combined_precision_table_name = f"Combined precision scores of {motifs_name}" + " on the {} set for each TP value on the " + str(report_obj.training_set_name)

        train_results_table = report_obj._write_output_table(training_plotting_data,
                                                             report_obj.result_path / f"training_set_scores{file_suffix}.csv",
                                                             results_table_name.format(report_obj.training_set_name))
        test_results_table = report_obj._write_output_table(test_plotting_data,
                                                            report_obj.result_path / f"test_set_scores{file_suffix}.csv",
                                                            results_table_name.format(report_obj.test_set_name))
        training_combined_precision_table = report_obj._write_output_table(training_combined_precision,
                                                                           report_obj.result_path / f"training_combined_precision{file_suffix}.csv",
                                                                           combined_precision_table_name.format(report_obj.training_set_name))
        test_combined_precision_table = report_obj._write_output_table(test_combined_precision,
                                                                       report_obj.result_path / f"test_combined_precision{file_suffix}.csv",
                                                                       combined_precision_table_name.format(report_obj.test_set_name))

        return [table for table in [train_results_table, test_results_table, training_combined_precision_table,
                                    test_combined_precision_table] if table is not None]
    @staticmethod
    def write_plots(report_obj, training_plotting_data, test_plotting_data, training_combined_precision, test_combined_precision,
                    training_tp_cutoff, test_tp_cutoff, motifs_name="motifs", file_suffix=""):
        training_tp_plot = report_obj._safe_plot(plot_callable="_plot_precision_per_tp", plotting_data=training_plotting_data,
                                                 combined_precision=training_combined_precision, dataset_type=report_obj.training_set_name,
                                                 file_path=report_obj.result_path / f"training_precision_per_tp{file_suffix}.html",
                                                 motifs_name=motifs_name, tp_cutoff=training_tp_cutoff)
        test_tp_plot = report_obj._safe_plot(plot_callable="_plot_precision_per_tp", plotting_data=test_plotting_data,
                                             combined_precision=test_combined_precision, dataset_type=report_obj.test_set_name,
                                             file_path=report_obj.result_path / f"test_precision_per_tp{file_suffix}.html",
                                             motifs_name=motifs_name, tp_cutoff=test_tp_cutoff)
        training_pr_plot = report_obj._safe_plot(plot_callable="_plot_precision_recall", plotting_data=training_plotting_data,
                                                 dataset_type=report_obj.training_set_name,
                                                 file_path=report_obj.result_path / f"training_precision_recall{file_suffix}.html",
                                                 motifs_name=motifs_name)
        test_pr_plot = report_obj._safe_plot(plot_callable="_plot_precision_recall", plotting_data=test_plotting_data,
                                             dataset_type=report_obj.test_set_name,
                                             file_path=report_obj.result_path / f"test_precision_recall{file_suffix}.html",
                                             motifs_name=motifs_name)

        return [plot for plot in [training_tp_plot, test_tp_plot, training_pr_plot, test_pr_plot] if plot is not None]
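

# Minimal usage sketch (illustrative, not part of the original module). The
# DataFrame below mimics the merged feature annotations that get_plotting_data
# would produce; the column names follow the code above, but all values are made up.
if __name__ == "__main__":
    example_plotting_data = pd.DataFrame({
        "training_TP": [1, 1, 2, 5, 10],  # true positive counts on the training set
        "TP": [1, 2, 2, 4, 9],            # per-motif true positives on the evaluated set
        "FP": [4, 3, 1, 1, 0],            # per-motif false positives on the evaluated set
    })

    # compute pooled and smoothed precision per training_TP value
    combined = MotifPerformancePlotHelper.get_combined_precision(
        example_plotting_data, min_points_in_window=3,
        smoothing_constant1=1.0, smoothing_constant2=0.5)

    print(combined)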