import logging
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
import plotly.express as px
from immuneML.data_model.EncodedData import EncodedData
from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.dsl.definition_parsers.MLParser import MLParser
from immuneML.ml_methods.dim_reduction.DimRedMethod import DimRedMethod
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.encoding_reports.EncodingReport import EncodingReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder
[docs]
class DimensionalityReduction(EncodingReport):
    """
    This report visualizes the data obtained by dimensionality reduction. The data points can be highlighted by label of
    interest. It is also possible to specify labels that contain lists of values (e.g., HLA), in which case the data points
    will be duplicated (so that each point refers to one HLA allele) and jittered slightly to improve visibility
    before being highlighted by the concrete HLA allele values.

    **Specification arguments:**

    - labels (list): names of the label to use for highlighting data points; or None

    - dim_red_method (str): dimensionality reduction method to be used for plotting; if set, in a workflow, this
      dimensionality reduction will be used for plotting instead of any other set in the workflow; if None, it will
      visualize the encoded data of reduced dimensionality if set

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                rep1:
                    DimensionalityReduction:
                        labels: [epitope, source]
                        dim_red_method:
                            PCA:
                                n_components: 2

    """

    @classmethod
    def build_object(cls, **kwargs):
        """
        Build the report from parsed specification arguments.

        Recognized keys: ``dim_red_method`` (a single-entry dict mapping the method class name to its
        parameters, or None / 'None'), ``labels`` (list of str, or None) and the deprecated ``label`` (str).
        All remaining keys are forwarded unchanged to the constructor.
        """
        method_spec = kwargs.get("dim_red_method")
        if method_spec and method_spec != 'None':
            cls_name = list(method_spec.keys())[0]
            method = MLParser.parse_any_model("dim_red_method", method_spec, cls_name)[0]
        else:
            method = None

        location = f"DimensionalityReduction ({kwargs.get('name', '')})"
        init_kwargs = {key: value for key, value in kwargs.items() if key != "label"}

        # backwards compatibility: to be removed from next major version
        if "label" in kwargs:
            ParameterValidator.warn_deprecated_parameter("label", "labels", location)
            ParameterValidator.assert_type_and_value(kwargs["label"], str, location, "label")
            labels = [kwargs["label"]]
        else:
            # `labels` is documented as optional (None => plot without highlighting), so only
            # validate it when it was actually provided.
            labels = kwargs.get("labels")
            if labels is not None:
                ParameterValidator.assert_type_and_value(labels, list, location, "labels")
                ParameterValidator.assert_all_type_and_value(labels, str, location, "labels")

        return DimensionalityReduction(**{**init_kwargs, "dim_red_method": method, 'labels': labels})

    def __init__(self, dataset: Dataset = None, batch_size: int = 1, result_path: Path = None,
                 name: str = None, labels: list = None, dim_red_method: DimRedMethod = None):
        """
        Args:
            dataset: dataset whose encoded data is visualized.
            batch_size: accepted as a constructor argument; not used anywhere in this class.
            result_path: directory where the CSV table and the HTML figures are written.
            name: report name, used in the result and in error messages.
            labels: names of the labels used to color the points, or None for an uncolored plot.
            dim_red_method: method applied to the dataset before plotting; when None, the
                dimensionality-reduced data already stored on the encoded dataset is plotted.
        """
        super().__init__(dataset=dataset, result_path=result_path, name=name)
        self._labels = labels
        self._dim_red_method = dim_red_method
        self._dimension_names = ['dimension_1', 'dimension_2'] if self._dim_red_method is None \
            else self._dim_red_method.get_dimension_names()
        # " dim_red" is a placeholder that _generate() replaces with the method name (or removes).
        self.info = ("This report visualizes the encoded data after applying dimensionality reduction dim_red,"
                     " optionally colored by labels of interest.")

    def check_prerequisites(self):
        """Return True if the dataset is encoded and either reduced data or a reduction method is available."""
        return (isinstance(self.dataset.encoded_data, EncodedData) and
                (self.dataset.encoded_data.dimensionality_reduced_data is not None or self._dim_red_method is not None))

    def _generate(self) -> ReportResult:
        """Obtain the 2D data, store it (with labels, when available) as CSV and create the figures."""
        dim_reduced_data = self._get_dim_reduced_data()
        data_labels = self._get_label_data()

        PathBuilder.build(self.result_path)

        df = pd.DataFrame({'example_id': self.dataset.get_example_ids(),
                           self._dimension_names[0]: dim_reduced_data[:, 0],
                           self._dimension_names[1]: dim_reduced_data[:, 1]})

        if hasattr(self.dataset, 'get_metadata_fields') and 'subject_id' in self.dataset.get_metadata_fields():
            df['subject_id'] = self.dataset.get_metadata(['subject_id'], return_df=True)['subject_id']

        # Only add label columns that were actually retrieved; _plot() colors by the label
        # columns present in the table and falls back to an uncolored plot otherwise.
        if data_labels is not None:
            df[self._labels] = data_labels

        table_path = self.result_path / 'dimensionality_reduced_data.csv'
        df.to_csv(table_path, index=False)

        report_output_figures = self._safe_plot(df=df, output_written=True)

        dim_red_text = f" ({self._dim_red_method.__class__.__name__})" if self._dim_red_method else ""

        return ReportResult(name=self.name, info=self.info.replace(" dim_red", dim_red_text),
                            output_figures=report_output_figures,
                            output_tables=[ReportOutput(table_path, 'data after dimensionality reduction')])

    def _get_dim_reduced_data(self):
        """Return the two-column reduced data, computing it with the configured method when one is set."""
        if self._dim_red_method:
            assert self.dataset.encoded_data.examples is not None, \
                f"{DimensionalityReduction.__name__}: data not encoded, report will not be made."
            dim_reduced_data = self._dim_red_method.fit_transform(self.dataset)
        else:
            assert self.dataset.encoded_data.dimensionality_reduced_data is not None
            dim_reduced_data = self.dataset.encoded_data.dimensionality_reduced_data

        assert dim_reduced_data.shape[1] == 2, \
            (f"{DimensionalityReduction.__name__}: {self.name}: dimensionality reduced data is not 2d (got: "
             f"{dim_reduced_data.shape}), so it cannot be plotted.")

        return dim_reduced_data

    def _get_label_data(self):
        """Return the metadata columns for the requested labels, or None if none were requested or found."""
        if not self._labels:
            return None

        try:
            return self.dataset.get_metadata(self._labels, return_df=True)[self._labels]
        except (AttributeError, TypeError, KeyError):
            # KeyError covers a label that is absent from the returned metadata table.
            logging.warning(f"Labels {self._labels} not found in the dataset. Skipping label coloring in the plot.")
            return None

    def _plot(self, df: pd.DataFrame) -> List[ReportOutput]:
        """Create one scatter plot per available label, or a single uncolored plot when there is none."""
        PathBuilder.build(self.result_path)

        labels = [label for label in (self._labels or []) if label in df.columns]

        if not labels:
            return [self._plot_without_labels(df)]

        return [self._plot_label(df, label, labels) for label in labels]

    def _plot_label(self, df: pd.DataFrame, label: str, labels: list) -> ReportOutput:
        """Write the scatter plot colored by `label` and return its report output."""
        df_copy = self._parse_labels_with_lists(df, label)
        unique_values = df_copy[label].unique()

        hover_data = list(self._dimension_names) + labels
        if 'subject_id' in df_copy.columns:
            hover_data += ['subject_id']
        elif 'example_id' in df_copy.columns:
            hover_data += ['example_id']

        if len(unique_values) <= 15:
            # Few distinct values: treat the label as categorical with a discrete palette.
            df_copy[label] = df_copy[label].astype('category')
            figure = px.scatter(df_copy, x=self._dimension_names[0], y=self._dimension_names[1], color=label,
                                color_discrete_sequence=px.colors.qualitative.Vivid,
                                hover_data=hover_data,
                                category_orders={label: self._sort_categories(unique_values)})
        else:
            figure = px.scatter(df_copy, x=self._dimension_names[0], y=self._dimension_names[1], color=label,
                                hover_data=hover_data)

        figure.update_layout(template="plotly_white", showlegend=True)
        figure.update_traces(opacity=.6)

        file_path = self.result_path / f"dimensionality_reduction_{label}.html"
        file_path = PlotlyUtil.write_image_to_file(figure, file_path)

        return ReportOutput(path=file_path,
                            name="Data visualization after dimensionality reduction "
                                 "(highlighted by {})".format(label))

    def _plot_without_labels(self, df: pd.DataFrame) -> ReportOutput:
        """Write the scatter plot of the bare points and return its report output."""
        figure = px.scatter(df, x=self._dimension_names[0], y=self._dimension_names[1])
        figure.update_layout(template="plotly_white")
        figure.update_traces(opacity=.6)

        file_path = self.result_path / "dimensionality_reduction.html"
        file_path = PlotlyUtil.write_image_to_file(figure, file_path)

        return ReportOutput(path=file_path, name="Data visualization after dimensionality reduction")

    @staticmethod
    def _sort_categories(values) -> list:
        """Sort category values, ordering by string form when the values are not mutually comparable."""
        try:
            return sorted(values)
        except TypeError:
            # e.g. numeric label values mixed with the 'unknown' placeholder
            return sorted(values, key=str)

    def _parse_labels_with_lists(self, df: pd.DataFrame, label: str) -> pd.DataFrame:
        """
        Return a copy of `df` in which list-valued entries of `label` are expanded to one row per value.

        When the expansion duplicated at least one point, both coordinates get a small uniform random
        jitter (0.5% of the smaller axis range) so that the duplicates do not overlap exactly.
        """
        df_long = df.copy()
        df_long[label] = df_long[label].apply(parse_list_column)

        # Inspect every row: the first one may be a missing value while later ones hold lists.
        holds_sequence = df_long[label].apply(lambda value: isinstance(value, (list, tuple)))
        if not holds_sequence.any():
            return df_long

        n_points = len(df_long)
        df_long = df_long.explode(label)

        # Single-valued entries explode to exactly one row each; nothing overlaps, so keep the
        # coordinates untouched in that case.
        if len(df_long) > n_points:
            x_name, y_name = self._dimension_names[0], self._dimension_names[1]

            # Compute jitter based on the range of each axis
            x_range = df_long[x_name].max() - df_long[x_name].min()
            y_range = df_long[y_name].max() - df_long[y_name].min()
            jitter_strength = 0.005 * min(x_range, y_range)

            # Apply jitter
            df_long[x_name] += np.random.uniform(-jitter_strength, jitter_strength, size=len(df_long))
            df_long[y_name] += np.random.uniform(-jitter_strength, jitter_strength, size=len(df_long))

        return df_long
[docs]
def parse_list_column(value):
    """
    Parse a label value that may hold several comma-separated entries.

    - Missing values (None, NaN, empty or separator-only strings, empty list/tuple) give the
      placeholder string 'unknown'.
    - Lists and tuples are returned unchanged.
    - Strings are split on commas into a list of stripped items. One pair of surrounding quotes
      and one pair of surrounding brackets/parentheses are removed first, and single quotes are
      removed from each item, so "['a', 'b']", "(a, b)", '"a,b"' and "a, b" all give ['a', 'b'].
      A string without commas gives a one-element list.
    - Any other value (numbers, booleans, ...) is returned unchanged, including 0 and False.
    """
    if isinstance(value, (list, tuple)):
        return value if value else 'unknown'

    # pd.isna() returns an array for array-likes, so only ask it about scalars.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return 'unknown'

    if not isinstance(value, str):
        return value

    text = value.strip()
    for opening, closing in (('"', '"'), ('\'', '\''), ('[', ']'), ('(', ')')):
        if text.startswith(opening) and text.endswith(closing):
            text = text[1:-1].strip()

    items = [item.strip().replace('\'', '') for item in text.split(',') if item.strip()]

    # An empty list would turn into NaN when the column is exploded; report it as missing instead.
    return items if items else 'unknown'