Source code for immuneML.reports.gen_model_reports.VAESummary

import shutil
from pathlib import Path

import pandas as pd
import torch
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA

from immuneML.data_model.datasets.Dataset import Dataset
from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
from immuneML.ml_methods.generative_models.SimpleVAE import SimpleVAE
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.gen_model_reports.GenModelReport import GenModelReport
from immuneML.util.ParameterValidator import ParameterValidator
from immuneML.util.PathBuilder import PathBuilder


class VAESummary(GenModelReport):
    """
    This report provides the summary of the trained SimpleVAE and shows the following:

    - plots of the latent space after applying PCA to reduce the data to 2 dimensions, highlighted by V and J gene

    - plots the histogram for each latent dimension

    - plots loss per epoch

    **Specification arguments:**

    - dim_dist_cols (int): how many columns to use to plot the histograms of latent dimensions (either this or
      dim_dist_rows has to be set, or both)

    - dim_dist_rows (int): how many rows to use to plot the histogram of latent dimensions (either this or
      dim_dist_cols has to be set, or both)

    **YAML specification:**

    .. indent with spaces
    .. code-block:: yaml

        definitions:
            reports:
                my_vae_summary:
                    VAESummary:
                        dim_dist_cols: 4
                        dim_dist_rows: None

    """

    @classmethod
    def build_object(cls, **kwargs):
        """
        Build a VAESummary from the specification parameters.

        Both 'dim_dist_cols' and 'dim_dist_rows' keys have to be present in kwargs and at least one of them has to
        be an int; 'name' is optional and defaults to "VAESummary".
        """
        name = kwargs.get("name", "VAESummary")

        ParameterValidator.assert_keys_present(list(kwargs.keys()), ['dim_dist_cols', 'dim_dist_rows'],
                                               VAESummary.__name__, 'parameters')

        # kept as an assertion so that the raised exception type (AssertionError) stays the same for callers
        assert isinstance(kwargs['dim_dist_cols'], int) or isinstance(kwargs['dim_dist_rows'], int), \
            f"{cls.__name__}: at least one of dim_dist_cols or dim_dist_rows has to be set."

        return VAESummary(name=name, dim_dist_cols=kwargs['dim_dist_cols'], dim_dist_rows=kwargs['dim_dist_rows'])

    def __init__(self, dim_dist_rows: int, dim_dist_cols: int, dataset: Dataset = None, model: GenerativeModel = None,
                 result_path: Path = None, name: str = None):
        """Store the histogram grid configuration; the remaining arguments are passed on to the parent report."""
        super().__init__(dataset, model, result_path, name)
        self.dim_dist_rows = dim_dist_rows
        self.dim_dist_cols = dim_dist_cols

    def _generate(self) -> ReportResult:
        """Write the two tables and up to three figures to result_path and collect them in a ReportResult."""
        PathBuilder.build(self.result_path)

        latent_space_table_out = self._prepare_latent_space()
        train_progress_table_out = self._prepare_training_progress()

        figures = [
            self._safe_plot(output_written=True, plot_callable='_plot_latent_space',
                            latent_space_table=latent_space_table_out),
            self._safe_plot(output_written=True, plot_callable='_plot_latent_dimension_distributions'),
            self._safe_plot(output_written=True, plot_callable='_plot_training_progress',
                            training_progress_table=train_progress_table_out),
        ]

        # figures for which _safe_plot returned None are left out of the result
        output_figures = [figure for figure in figures if figure is not None]

        return ReportResult(name=self.name, info='Summary of the fitted VAE model.', output_figures=output_figures,
                            output_tables=[latent_space_table_out, train_progress_table_out])

    def _encode_dataset(self) -> torch.Tensor:
        """
        Embed the examples of the dataset into the latent space of the model without tracking gradients.

        The loader is requested with the full example count as the second argument, so a single batch is expected;
        if more batches are returned, they are concatenated in loader order along the first dimension.

        Raises:
            ValueError: if the data loader does not yield any batch.
        """
        # NOTE(review): the third argument is presumably the shuffle flag; the row order of the embeddings has to
        # match the order of the dataset since gene labels are attached by position - confirm against SimpleVAE
        data_loader = self.model.encode_dataset(self.dataset, self.dataset.get_example_count(), False)

        batches = []
        with torch.no_grad():
            for cdr3_input, v_gene_input, j_gene_input in data_loader:
                batches.append(self.model.model.encoding_func(cdr3_input, v_gene_input, j_gene_input))

        if not batches:
            raise ValueError(f"{VAESummary.__name__}: the dataset could not be encoded since no examples were "
                             f"returned by the data loader.")

        return batches[0] if len(batches) == 1 else torch.cat(batches, dim=0)

    def _prepare_latent_space(self) -> ReportOutput:
        """Reduce the embeddings to 2 principal components and store them with V and J genes as a CSV file."""
        embeddings = self._encode_dataset()
        components = PCA(n_components=2).fit_transform(embeddings.numpy())

        df = pd.DataFrame(data=components, columns=['PC1', 'PC2'])
        # keep only the part of the call before the first '*'
        df['v_gene'] = [v_call.split("*")[0] for v_call in self.dataset.data.v_call.tolist()]
        df['j_gene'] = [j_call.split("*")[0] for j_call in self.dataset.data.j_call.tolist()]

        path = self.result_path / 'latent_space_2_component_PCA.csv'
        df.to_csv(str(path), index=False)

        return ReportOutput(path, f'principal component analysis on the data embedded into '
                                  f'{self.model.latent_dim} dimensional space')

    def _prepare_training_progress(self) -> ReportOutput:
        """Copy the loss file of the model (model.loss_path) to the result path."""
        path = self.result_path / 'training_losses.csv'
        shutil.copyfile(str(self.model.loss_path), str(path))
        return ReportOutput(path, 'Loss per epoch')

    def _plot_latent_space(self, latent_space_table: ReportOutput) -> ReportOutput:
        """
        Plot the PCA-reduced latent space as two stacked scatter plots: one trace per V gene in the first row and
        one trace per J gene in the second row; the table written by _prepare_latent_space is read back from disk.
        """
        encoded_data = pd.read_csv(latent_space_table.path)

        fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True,
                            subplot_titles=['V genes', 'J genes'], x_title='PC1', y_title='PC2')

        panels = [('v_gene', self.model.unique_v_genes, 'V genes'),
                  ('j_gene', self.model.unique_j_genes, 'J genes')]

        # the subplot row number is also used as the legend group, so each panel has its own legend section
        for row, (column, genes, title) in enumerate(panels, start=1):
            for gene in genes:
                gene_df = encoded_data[encoded_data[column] == gene]
                fig.add_trace(
                    go.Scatter(x=gene_df['PC1'].values, y=gene_df['PC2'].values, mode='markers', opacity=0.5,
                               name=gene, marker=dict(colorscale='Viridis', line_width=1), legendgroup=row,
                               legendgrouptitle_text=title),
                    row=row, col=1)

        fig.update_layout(template='plotly_white', legend_tracegroupgap=30)

        path = self.result_path / 'latent_space_PCA.html'
        fig.write_html(path)

        return ReportOutput(path, f'principal component analysis on the data embedded into {self.model.latent_dim} '
                                  f'dimensional space')

    def _plot_latent_dimension_distributions(self) -> ReportOutput:
        """
        Plot one histogram per latent dimension on a dim_dist_rows x dim_dist_cols grid.

        If one of the two grid parameters is None, it is derived from the latent dimension of the model and the
        other parameter. The derived value is kept local, so the configuration of the report is not modified.

        Raises:
            AssertionError: if rows * cols does not equal the latent dimension of the model.
        """
        latent_dim = self.model.latent_dim
        n_rows, n_cols = self.dim_dist_rows, self.dim_dist_cols

        if n_rows is None:
            n_rows = latent_dim // n_cols
        elif n_cols is None:
            n_cols = latent_dim // n_rows

        assert n_rows * n_cols == latent_dim, \
            (f"{VAESummary.__name__}: cannot plot latent dimension distribution since dim_dist_cols and dim_dist_rows "
             f"parameters do not match the latent dimension of the VAE.")

        embeddings = self._encode_dataset()

        fig = make_subplots(rows=n_rows, cols=n_cols, shared_xaxes=True, shared_yaxes=True)

        # fill the grid row by row: dimension i goes to row i // n_cols and column i % n_cols (both 0-based)
        for i in range(latent_dim):
            row, col = divmod(i, n_cols)
            fig.add_trace(go.Histogram(x=embeddings[:, i].numpy(), name=f'dim {i + 1}'), row=row + 1, col=col + 1)

        fig.update_layout(template='plotly_white')

        path = self.result_path / 'latent_dim_dist.html'
        fig.write_html(path)

        return ReportOutput(path=path,
                            name=f'latent dimension distribution for examples from dataset {self.dataset.name}')

    def _plot_training_progress(self, training_progress_table: ReportOutput) -> ReportOutput:
        """Plot the 'loss' column against the 'epoch' column of the training progress table as a line plot."""
        df = pd.read_csv(str(training_progress_table.path))

        fig = px.line(df, x='epoch', y='loss', markers=True)
        fig.update_layout(template='plotly_white')

        path = self.result_path / 'loss_per_epoch.html'
        fig.write_html(path)

        return ReportOutput(path=path, name='loss per epoch')

    def check_prerequisites(self) -> bool:
        """Return True only if the model of the report is a SimpleVAE."""
        return isinstance(self.model, SimpleVAE)