import copy
from pathlib import Path
from typing import List

import pandas as pd

from immuneML.IO.dataset_import.PickleImport import PickleImport
from immuneML.data_model.dataset.Dataset import Dataset
from immuneML.data_model.dataset.ReceptorDataset import ReceptorDataset
from immuneML.data_model.dataset.RepertoireDataset import RepertoireDataset
from immuneML.data_model.receptor.Receptor import Receptor
from immuneML.data_model.repertoire.Repertoire import Repertoire
from immuneML.simulation.SimulationState import SimulationState
from immuneML.util.FilenameHandler import FilenameHandler
from immuneML.util.PathBuilder import PathBuilder
from immuneML.workflows.steps.Step import Step
class SignalImplanter(Step):
    """Workflow step that implants simulated signals into a dataset.

    The dataset is walked in order and split into consecutive slices, one per
    implanting of the simulation; the size of each slice is
    ``int(dataset_implanting_rate * example_count)``. Elements inside a slice
    get the signals of that implanting, elements after the last slice get no
    signal. Every processed element carries one boolean metadata entry per
    signal of the simulation state.
    """

    DATASET_NAME = "simulated_dataset"

    @staticmethod
    def run(simulation_state: SimulationState = None):
        """Return the simulated dataset for ``simulation_state``.

        If a file already exists at the expected dataset location under
        ``simulation_state.result_path`` it is imported instead of running the
        implanting again; otherwise the signals are implanted.

        Args:
            simulation_state: state holding the dataset, the signals, the
                simulation (implantings) and the result path. Despite the
                ``None`` default it is required: it is dereferenced immediately.

        Returns:
            The imported or newly created dataset.
        """
        path = simulation_state.result_path / FilenameHandler.get_dataset_name(SignalImplanter.__name__)

        if path.is_file():
            # NOTE(review): this goes through PickleImport - presumably unpickling the
            # file, so it should only ever point at output this tool wrote itself.
            return PickleImport.import_dataset({"path": path}, SignalImplanter.DATASET_NAME)

        return SignalImplanter._implant_signals_in_dataset(simulation_state)

    @staticmethod
    def _implant_signals_in_dataset(simulation_state: SimulationState = None) -> Dataset:
        """Dispatch to repertoire- or receptor-level implanting based on the dataset type."""
        # presumably creates the result directory if it is missing - confirm in PathBuilder
        PathBuilder.build_from_objects(simulation_state.result_path)

        if isinstance(simulation_state.dataset, RepertoireDataset):
            return SignalImplanter._implant_signals_in_repertoires(simulation_state)
        return SignalImplanter._implant_signals_in_receptors(simulation_state)

    @staticmethod
    def _build_labels(simulation_state: SimulationState) -> dict:
        """Merge the labels of the input dataset with one ``[True, False]`` label per signal.

        Labels are keyed by ``signal.id`` so that they match the metadata keys
        written on the processed receptors and repertoires. A signal id that
        collides with an existing label overrides it.
        """
        existing_labels = simulation_state.dataset.labels
        labels = dict(existing_labels) if existing_labels is not None else {}
        labels.update({signal.id: [True, False] for signal in simulation_state.signals})
        return labels

    @staticmethod
    def _implant_signals_in_receptors(simulation_state: SimulationState) -> Dataset:
        """Implant signals into every receptor and wrap the result in a new ReceptorDataset."""
        processed_receptors = SignalImplanter._implant_signals(simulation_state, SignalImplanter._process_receptor)
        processed_dataset = ReceptorDataset.build_from_objects(receptors=processed_receptors,
                                                               file_size=simulation_state.dataset.file_size,
                                                               name=simulation_state.dataset.name,
                                                               path=simulation_state.result_path)
        # Keyed by signal.id (not by the signal object) to agree with the repertoire
        # variant and with the metadata keys set in _process_receptor.
        processed_dataset.labels = SignalImplanter._build_labels(simulation_state)
        return processed_dataset

    @staticmethod
    def _implant_signals_in_repertoires(simulation_state: SimulationState = None) -> Dataset:
        """Implant signals into every repertoire and build a new RepertoireDataset.

        A ``metadata.csv`` describing the processed repertoires is written to the
        result path and attached to the returned dataset.
        """
        PathBuilder.build_from_objects(simulation_state.result_path / "repertoires")

        processed_repertoires = SignalImplanter._implant_signals(simulation_state, SignalImplanter._process_repertoire)
        metadata_file = SignalImplanter._create_metadata_file(processed_repertoires, simulation_state)

        return RepertoireDataset(repertoires=processed_repertoires,
                                 labels=SignalImplanter._build_labels(simulation_state),
                                 name=simulation_state.dataset.name,
                                 metadata_file=metadata_file)

    @staticmethod
    def _implant_signals(simulation_state: SimulationState, process_element_func):
        """Apply ``process_element_func`` to every dataset element with its implanting.

        Args:
            simulation_state: state providing the dataset and the implantings.
            process_element_func: callable invoked as
                ``process_element_func(index, element, implanting, simulation_state)``;
                ``implanting`` is ``None`` for elements past the last slice.

        Returns:
            List of processed elements, in dataset order.
        """
        implantings = simulation_state.simulation.implantings
        simulation_limits = SignalImplanter._prepare_simulation_limits(implantings,
                                                                       simulation_state.dataset.get_example_count())

        implanting_index = 0
        # With no implantings at all, every element is processed without a signal.
        current_implanting = implantings[0] if implantings else None

        processed_elements = []
        for index, element in enumerate(simulation_state.dataset.get_data()):
            # `while`, not `if`: an implanting whose slice is empty (rate * count rounds
            # down to 0) has the same limit as its predecessor and must be skipped
            # entirely instead of being applied to one element.
            while current_implanting is not None and index >= simulation_limits[current_implanting.name]:
                implanting_index += 1
                current_implanting = implantings[implanting_index] if implanting_index < len(implantings) else None

            processed_elements.append(process_element_func(index, element, current_implanting, simulation_state))

        return processed_elements

    @staticmethod
    def _process_receptor(index, receptor, implanting, simulation_state) -> Receptor:
        """Return the receptor with the implanting's signals applied (or a clone if none).

        Any signal of the simulation state that is not yet present in the
        receptor's metadata is recorded there as ``False``. ``index`` is unused
        and exists to match the callback signature of ``_implant_signals``.
        """
        if implanting is None:
            new_receptor = receptor.clone()
        else:
            new_receptor = receptor
            for signal in implanting.signals:
                new_receptor = signal.implant_in_receptor(new_receptor, implanting.is_noise)

        for signal in simulation_state.signals:
            if signal.id not in new_receptor.metadata:
                new_receptor.metadata[signal.id] = False

        return new_receptor

    @staticmethod
    def _process_repertoire(index, repertoire, current_implanting, simulation_state) -> Repertoire:
        """Return the repertoire with signals implanted, or a signal-free rebuild of it.

        Without an implanting, a new repertoire is built from the same sequences
        and metadata under ``<result_path>/repertoires`` and every signal is
        marked ``False`` in its metadata.
        """
        if current_implanting is not None:
            return SignalImplanter._implant_in_repertoire(index, repertoire, current_implanting, simulation_state)

        new_repertoire = Repertoire.build_from_sequence_objects(repertoire.sequences,
                                                                simulation_state.result_path / "repertoires",
                                                                repertoire.metadata)
        for signal in simulation_state.signals:
            new_repertoire.metadata[f"{signal.id}"] = False

        return new_repertoire

    @staticmethod
    def _create_metadata_file(processed_repertoires: List[Repertoire], simulation_state) -> Path:
        """Write ``metadata.csv`` for the processed repertoires and return its path.

        The file has one row per repertoire built from its metadata (without the
        ``field_list`` entry, when present) plus a ``filename`` column holding
        the repertoire's ``data_filename``.
        """
        path = simulation_state.result_path / "metadata.csv"

        metadata = pd.DataFrame([repertoire.metadata for repertoire in processed_repertoires])
        # errors="ignore": do not fail when no repertoire metadata contains 'field_list'
        metadata = metadata.drop(columns="field_list", errors="ignore")
        metadata["filename"] = [repertoire.data_filename for repertoire in processed_repertoires]
        metadata.to_csv(path, index=False)

        return path

    @staticmethod
    def _implant_in_repertoire(index, repertoire, implanting, simulation_state) -> Repertoire:
        """Implant all signals of ``implanting`` into a deep copy of ``repertoire``.

        Metadata of the result marks each implanted signal as ``True`` (``False``
        when the implanting is noise) and every other signal of the simulation
        state as ``False``. ``index`` is unused.
        """
        new_repertoire = copy.deepcopy(repertoire)
        repertoires_path = simulation_state.result_path / "repertoires/"

        for signal in implanting.signals:
            new_repertoire = signal.implant_to_repertoire(repertoire=new_repertoire,
                                                          repertoire_implanting_rate=implanting.repertoire_implanting_rate,
                                                          path=repertoires_path)

        # Metadata is set only after all signals are implanted, since each implanting
        # call returns the repertoire object to continue with.
        signal_present = not implanting.is_noise
        for signal in implanting.signals:
            new_repertoire.metadata[f"{signal.id}"] = signal_present

        for signal in simulation_state.signals:
            if signal not in implanting.signals:
                new_repertoire.metadata[f"{signal.id}"] = False

        return new_repertoire

    @staticmethod
    def _prepare_simulation_limits(simulation: list, element_count: int) -> dict:
        """Return, per implanting name, the exclusive end index of its slice of the dataset.

        Each implanting covers ``int(dataset_implanting_rate * element_count)``
        consecutive elements; the returned value is the running total of these
        sizes, i.e. the index of the first element that no longer belongs to the
        implanting.
        """
        slice_sizes = {item.name: int(item.dataset_implanting_rate * element_count) for item in simulation}

        limits = {}
        running_total = 0
        for item_name, slice_size in slice_sizes.items():
            running_total += slice_size
            limits[item_name] = running_total

        return limits