import logging
import warnings
from pathlib import Path
import pandas as pd
import plotly.express as px
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.encoding_reports.FeatureReport import FeatureReport
from immuneML.util.ParameterValidator import ParameterValidator
class FeatureComparison(FeatureReport):
    """
    Encoding a dataset results in a numeric matrix, where the rows are examples (e.g., sequences, receptors, repertoires)
    and the columns are features. For example, when :ref:`KmerFrequency` encoder is used, the features are the
    k-mers (AAA, AAC, etc.) and the feature values are the frequencies per k-mer.

    This report separates the examples based on a binary metadata label, and plots the mean feature value
    of each feature in one example group against the other example group (for example: plot the feature
    value of 'sick' repertoires on the x axis, and 'healthy' repertoires on the y axis to spot consistent differences).
    The plot can be separated into different colors or facets using other metadata labels
    (for example: plot the average feature values of 'cohort1', 'cohort2' and 'cohort3' in different colors to spot biases).

    Alternatively, when plotting features without comparing them across a binary label, see:
    :py:obj:`~immuneML.reports.encoding_reports.FeatureValueBarplot.FeatureValueBarplot` report to plot
    a simple bar chart per feature (average across examples).
    Or :py:obj:`~immuneML.reports.encoding_reports.FeatureDistribution.FeatureDistribution` report to plot
    the distribution of each feature across examples, rather than only showing the mean value in a bar plot.

    Example output:

    .. image:: _static/images/reports/feature_comparison_zoom.png
       :alt: Feature comparison zoomed in plot with VLEQ highlighted
       :width: 650

    **Specification arguments:**

    - comparison_label (str): Mandatory label. This label is used to split the encoded data matrix and define the x and y axes of the plot.
      This label is only allowed to have 2 classes (for example: sick and healthy, binding and non-binding).

    - color_grouping_label (str): Optional label that is used to color the points in the scatterplot. This cannot be the same as comparison_label.

    - row_grouping_label (str): Optional label that is used to group scatterplots into different row facets. This cannot be the same as comparison_label.

    - column_grouping_label (str): Optional label that is used to group scatterplots into different column facets. This cannot be the same as comparison_label.

    - show_error_bar (bool): Whether to show the error bar (standard deviation) for the points, both in the x and y dimension.

    - log_scale (bool): Whether to plot the x and y axes in log10 scale (log_scale = True) or continuous scale (log_scale = False). By default, log_scale is False.

    - keep_fraction (float): The total number of features may be very large and only the features differing significantly across
      comparison labels may be of interest. When the keep_fraction parameter is set below 1, only the fraction of features that
      differs the most across comparison labels is kept for plotting (note that the produced .csv file still contains all data).
      By default, keep_fraction is 1, meaning that all features are plotted.

    - opacity (float): A value between 0 and 1 setting the opacity for data points, making it easier to see if there are overlapping points.

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_comparison_report:
                    FeatureComparison: # compare the different classes defined in the label disease
                        comparison_label: disease

    """

    @classmethod
    def build_object(cls, **kwargs):
        """
        Validate the report parameters and construct a FeatureComparison report.

        Args:
            **kwargs: report parameters as given in the specification (see the class docstring).

        Returns:
            FeatureComparison: the configured report. The validated ``log_scale`` and ``keep_fraction``
            (converted to float) are forwarded to the constructor.

        Raises:
            AssertionError: if comparison_label is missing or reused as a grouping label; invalid
                keep_fraction / log_scale values are rejected by ParameterValidator.
        """
        location = "FeatureComparison"

        comparison_label = kwargs.get("comparison_label")
        color_grouping_label = kwargs.get("color_grouping_label")
        row_grouping_label = kwargs.get("row_grouping_label")
        column_grouping_label = kwargs.get("column_grouping_label")
        # Default to False (the documented default and the constructor default); defaulting to None
        # made the bool validation below fail whenever log_scale was omitted.
        log_scale = kwargs.get("log_scale", False)
        keep_fraction = float(kwargs["keep_fraction"]) if "keep_fraction" in kwargs else 1.0

        ParameterValidator.assert_type_and_value(keep_fraction, float, location, "keep_fraction", min_inclusive=0, max_inclusive=1)
        ParameterValidator.assert_type_and_value(log_scale, bool, location, "log_scale")

        assert comparison_label is not None, "FeatureComparison: the parameter 'comparison_label' must be set in order to be able to compare across this label"
        assert comparison_label != color_grouping_label, f"FeatureComparison: comparison label {comparison_label} can not be used as color_grouping_label"
        assert comparison_label != row_grouping_label, f"FeatureComparison: comparison label {comparison_label} can not be used as row_grouping_label"
        assert comparison_label != column_grouping_label, f"FeatureComparison: comparison label {comparison_label} can not be used as column_grouping_label"

        # Forward the validated values instead of the raw kwargs, so e.g. keep_fraction="0.5" arrives as 0.5.
        return FeatureComparison(**{**kwargs, "log_scale": log_scale, "keep_fraction": keep_fraction})

    def __init__(self, dataset: Dataset = None, result_path: Path = None, comparison_label: str = None,
                 color_grouping_label: str = None, row_grouping_label=None, column_grouping_label=None, opacity: float = 0.7,
                 show_error_bar=True, log_scale: bool = False, keep_fraction: float = 1.0, number_of_processes: int = 1, name: str = None):
        super().__init__(dataset=dataset, result_path=result_path, color_grouping_label=color_grouping_label,
                         row_grouping_label=row_grouping_label, column_grouping_label=column_grouping_label,
                         number_of_processes=number_of_processes, name=name)
        self.comparison_label = comparison_label
        self.show_error_bar = show_error_bar
        self.log_scale = log_scale
        self.keep_fraction = keep_fraction
        self.opacity = opacity
        self.result_name = "feature_comparison"
        self.name = name

    def _generate(self):
        """Build the report result via the inherited FeatureReport machinery and attach a description."""
        result = self._generate_report_result()
        result.info = "Compares the feature values in a given encoded data matrix across two values for a metadata label. Each point in the resulting scatterplot represents one feature, and the values on the x and y axes are the average feature values across examples of two different classes. "
        return result

    def _plot(self, data_long_format) -> ReportOutput:
        """
        Aggregate the feature values per class of comparison_label and write the scatterplot as HTML.

        Args:
            data_long_format: long-format data frame; presumably with a 'feature' column, a 'value' column,
                the comparison label and any grouping labels as columns (TODO: confirm against FeatureReport).

        Returns:
            ReportOutput: path to the written ``feature_comparison.html`` file.
        """
        # Optional labels may be None; drop them and deduplicate while keeping a deterministic column order.
        groupby_cols = list(dict.fromkeys(
            col for col in [self.comparison_label, self.x, self.color, self.facet_row, self.facet_column] if col))

        plotting_data = data_long_format.groupby(groupby_cols, as_index=False).agg({"value": ["mean", self.std]})
        # Flatten the (column, aggregation) MultiIndex: ("value", "mean") -> "valuemean", (label, "") -> label.
        # NOTE(review): the "valuestd" name used below relies on self.std being named "std".
        plotting_data.columns = plotting_data.columns.map("".join)

        unique_label_values = plotting_data[self.comparison_label].unique()
        assert len(unique_label_values) == 2, f"FeatureComparison: comparison label {self.comparison_label} does not have 2 values; {unique_label_values}"
        class_x, class_y = unique_label_values

        # Put the statistics of both classes side by side (suffix _x for class_x, _y for class_y)
        # for every feature / grouping-label combination.
        merge_labels = [label for label in ["feature", self.color, self.facet_row, self.facet_column] if label]
        plotting_data = pd.merge(plotting_data.loc[plotting_data[self.comparison_label] == class_x],
                                 plotting_data.loc[plotting_data[self.comparison_label] == class_y],
                                 on=merge_labels, suffixes=("_x", "_y"))

        if plotting_data.shape[0] == 0:
            logging.warning(f"{FeatureComparison.__name__}: there is no overlap between the (combination of) values of {merge_labels} for "
                            f"different values of the {self.comparison_label}.")

        if self.keep_fraction < 1:
            plotting_data = self._filter_keep_fraction(plotting_data)

        error_x = "valuestd_x" if self.show_error_bar else None
        error_y = "valuestd_y" if self.show_error_bar else None

        figure = px.scatter(plotting_data, x="valuemean_x", y="valuemean_y", error_x=error_x, error_y=error_y,
                            color=self.color, facet_row=self.facet_row, facet_col=self.facet_column, hover_name="feature",
                            log_x=self.log_scale, log_y=self.log_scale, opacity=self.opacity,
                            labels={
                                "valuemean_x": f"Average feature values for {self.comparison_label} = {class_x}",
                                "valuemean_y": f"Average feature values for {self.comparison_label} = {class_y}",
                            }, template='plotly_white',
                            color_discrete_sequence=px.colors.diverging.Tealrose)

        self.add_diagonal(figure)

        file_path = self.result_path / f"{self.result_name}.html"
        figure.write_html(str(file_path))

        return ReportOutput(path=file_path, name=f"Comparison of feature values across {self.comparison_label}")

    def add_diagonal(self, figure):
        """
        Add a dashed diagonal line, drawn below the data, to the figure layout.

        NOTE(review): the line is drawn in paper coordinates from the bottom-left (0, 0) to the top-right (1, 1)
        corner of the plotting area, so it coincides with y = x only when both axes show the same range.
        """
        figure.update_layout(shapes=[{'type': "line", 'line': dict(color="#B0C2C7", dash="dash"), 'yref': 'paper', 'xref': 'paper', 'y0': 0,
                                      'y1': 1, 'x0': 0, 'x1': 1, 'layer': 'below'}])

    def _filter_keep_fraction(self, plotting_data):
        """
        Keep only the fraction ``self.keep_fraction`` of rows with the largest absolute difference between
        the two class means, ordered from largest to smallest difference. The input frame is not modified.
        """
        abs_diff = (plotting_data["valuemean_x"] - plotting_data["valuemean_y"]).abs()
        order = abs_diff.sort_values(ascending=False, kind="stable").index
        keep_nrows = round(plotting_data.shape[0] * self.keep_fraction)
        return plotting_data.loc[order].head(keep_nrows)

    def check_prerequisites(self):
        """
        Check whether the report can be generated for the current dataset.

        Emits a warning and returns False if the dataset is not encoded, the encoded examples are not
        2-dimensional, comparison_label is not a dataset label or does not have exactly 2 values, or a
        grouping label is invalid (equal to comparison_label or not a dataset label).
        """
        location = self.__class__.__name__

        if self.dataset.encoded_data is None or self.dataset.encoded_data.examples is None:
            warnings.warn(
                f"{location}: this report can only be created for an encoded dataset. {location} report will not be created.")
            return False

        if len(self.dataset.encoded_data.examples.shape) != 2:
            warnings.warn(
                f"{location}: this report can only be created for a 2-dimensional encoded dataset. {location} report will not be created.")
            return False

        legal_labels = list(self.dataset.get_label_names())
        if self.comparison_label not in legal_labels:
            warnings.warn(
                f"{location}: comparison_label was not defined. {location} report will not be created.")
            return False

        comparison_values = set(self.dataset.get_metadata([self.comparison_label])[self.comparison_label])
        if len(comparison_values) != 2:
            warnings.warn(
                f"{location}: comparison label {self.comparison_label} does not have 2 values: {comparison_values}. {location} report will not be created.")
            return False

        legal_labels.remove(self.comparison_label)

        run_report = True
        for label_param in [self.color, self.facet_row, self.facet_column]:
            if label_param is None:
                continue
            if label_param == self.comparison_label:
                warnings.warn(
                    f"{location}: comparison label '{self.comparison_label}' can not be used in other fields. {location} report will not be created.")
                run_report = False
            # elif: the comparison label was removed from legal_labels above, so it would otherwise be reported twice.
            elif label_param not in legal_labels:
                warnings.warn(
                    f"{location}: undefined label '{label_param}'. Legal options are: {legal_labels}. {location} report will not be created.")
                run_report = False

        return run_report