import logging
import os
import pickle
from pathlib import Path
import numpy as np
from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
from immuneML.encodings.abundance_encoding.AbundanceEncoderHelper import AbundanceEncoderHelper
from immuneML.pairwise_repertoire_comparison.ComparisonDataBatch import ComparisonDataBatch
from immuneML.util.Logger import log
from immuneML.util.PathBuilder import PathBuilder
class ComparisonData:
    """
    Batched item-by-repertoire matrix that is built incrementally and kept on disk.

    While repertoires are processed, items are collected in temporary pickled dictionaries
    (``tmp_batch_<i>.pickle`` under ``self.path``) of the form ``{item: {repertoire_id: value}}``,
    each holding at most ``sequence_batch_size`` items. Once all repertoires are processed,
    every temporary batch is converted into a dense matrix with one row per item and one column
    per entry of ``repertoire_ids``, wrapped in a ComparisonDataBatch, stored, and the temporary
    files are removed.

    When filled through process_dataset / process_repertoire, every batch except the last one holds
    exactly ``sequence_batch_size`` items; get_item_vector, get_repertoire_vector and
    get_repertoire_vectors rely on this to translate between global item positions and batches.
    """

    @log
    def __init__(self, repertoire_ids: list, comparison_attributes, sequence_batch_size: int = 10000, path: Path = None):
        """
        Args:
            repertoire_ids: identifiers defining the columns (and their order) of the final matrices
            comparison_attributes: passed to AbundanceEncoderHelper.get_matching_func_for_repertoire to obtain
                the function that extracts items from a repertoire
            sequence_batch_size: maximum number of items (rows) per batch; must be at least 1
            path: parent directory; data is kept in its "comparison_data" sub-directory. Despite the default,
                it has to be provided, since it is joined with the sub-directory name right away.

        Raises:
            ValueError: if sequence_batch_size is smaller than 1 (batching could never make progress)
        """
        if sequence_batch_size < 1:
            raise ValueError(f"ComparisonData: sequence_batch_size must be a positive integer, got {sequence_batch_size}.")

        self.path = PathBuilder.build(path / "comparison_data")
        self.sequence_batch_size = sequence_batch_size
        self.item_count = 0
        self.comparison_attributes = comparison_attributes
        self.repertoire_ids = repertoire_ids
        self.iteration_repertoire_ids = None
        self.batches = []
        self.tmp_batch_paths = []

    def __len__(self):
        """Returns the number of items added so far."""
        return self.item_count

    def __iter__(self):
        """Iterates over item vectors, restricted to the repertoires set via set_iteration_repertoire_ids (all columns if unset)."""
        return self.get_item_vectors(self.iteration_repertoire_ids)

    def set_iteration_repertoire_ids(self, iteration_repertoire_ids):
        """Defines the subset and order of repertoires to iterate over"""
        self.iteration_repertoire_ids = iteration_repertoire_ids

    def get_item_names(self):
        """Returns a numpy array with the items of all batches, in batch order."""
        return np.array([item for batch in self.batches for item in batch.get_items()])

    def get_item_vectors(self, repertoire_ids: list = None):
        """
        Yields the rows of all batch matrices one by one, in batch order.

        Args:
            repertoire_ids: columns to include in each row (in this order); None yields complete rows
        """
        for batch in self.get_batches(columns=repertoire_ids):
            for item_index in range(batch.shape[0]):
                yield batch[item_index]

    def get_repertoire_vectors(self, identifiers: list):
        """
        Collects the full columns of several repertoires in a single pass over the batches.

        Args:
            identifiers: repertoire ids to collect

        Returns:
            dict mapping each identifier to an array of length item_count
        """
        repertoire_vectors = {identifier: np.zeros(self.item_count) for identifier in identifiers}

        for batch_index, batch in enumerate(self.get_batches(columns=identifiers, return_dict=True)):
            # all batches before this one are assumed to be full, see the class docstring
            start = batch_index * self.sequence_batch_size
            for identifier in identifiers:
                column = batch[identifier]
                repertoire_vectors[identifier][start: start + column.shape[0]] = column

        return repertoire_vectors

    def get_repertoire_vector(self, identifier: str):
        """Returns the full column (array of length item_count) of the repertoire with the given identifier."""
        repertoire_vector = np.zeros(self.item_count)

        for batch_index, batch in enumerate(self.get_batches(columns=[identifier])):
            # all batches before this one are assumed to be full, see the class docstring
            start = batch_index * self.sequence_batch_size
            repertoire_vector[start: start + batch.shape[0]] = batch[:, 0]

        return repertoire_vector

    def get_item_vector(self, index: int):
        """
        Returns the matrix row of the item at the given global position.

        Args:
            index: non-negative position of the item across all batches
        """
        # integer arithmetic only: no float rounding as with int(index / batch_size)
        batch_index, index_in_batch = divmod(index, self.sequence_batch_size)
        return self.batches[batch_index].get_matrix()[index_in_batch]

    def get_batches(self, columns: list = None, return_dict: bool = False):
        """Yields get_batch(index, columns, return_dict) for every batch, in order."""
        for index in range(len(self.batches)):
            yield self.get_batch(index, columns, return_dict)

    def get_batch(self, index: int, columns: list = None, return_dict: bool = False):
        """
        Loads one batch and returns its matrix, optionally restricted to some repertoires.

        Args:
            index: position of the batch in self.batches
            columns: repertoire ids to select (in this order); None returns the complete matrix
            return_dict: if True (and columns are given), return {repertoire_id: column vector} instead of a matrix;
                ignored when columns is None

        Returns:
            the complete matrix, a matrix with the selected columns, or a dict of column vectors
        """
        stored_batch = self.batches[index]
        loaded_batch = stored_batch.load()

        if columns is None:
            return loaded_batch.get_matrix()

        column_indices = [stored_batch.repertoire_index_mapping[col] for col in columns]
        matrix = loaded_batch.get_matrix()

        if return_dict:
            return {col: matrix[:, column_index] for col, column_index in zip(columns, column_indices)}
        return matrix[:, column_indices]

    @log
    def process_dataset(self, dataset: RepertoireDataset):
        """
        Adds the items of every repertoire in the dataset and then builds the final batch matrices.

        Repertoires are registered under str(repertoire.identifier).
        """
        extract_fn = AbundanceEncoderHelper.get_matching_func_for_repertoire(self.comparison_attributes)
        repertoire_count = dataset.get_example_count()

        for index, repertoire in enumerate(dataset.get_data()):
            self.process_repertoire(repertoire, str(repertoire.identifier), extract_fn)
            logging.info("Repertoire {} ({}/{}) processed.".format(repertoire.identifier, index + 1, repertoire_count))
            logging.info(f"Currently, there are {self.item_count} items in the comparison data matrix.")

        self.merge_tmp_batches_to_matrix()

    def merge_tmp_batches_to_matrix(self):
        """
        Converts every temporary batch into a stored ComparisonDataBatch and deletes the temporary files.

        Each matrix has one row per item of the temporary batch and one column per entry of self.repertoire_ids
        (column-major, zeros where a repertoire has no value for the item). self.tmp_batch_paths is left
        unchanged, so this is meant to be called once, after all repertoires have been processed.
        """
        # all columns of each repertoire id; ids are normally unique, but a repeated id gets all of its columns filled
        columns_by_repertoire = {}
        for column, repertoire_id in enumerate(self.repertoire_ids):
            columns_by_repertoire.setdefault(repertoire_id, []).append(column)

        for index in range(len(self.tmp_batch_paths)):
            batch = self.load_tmp_batch(index)
            items = list(batch)

            # allocate exactly one row per item instead of a full-size matrix that is sliced and copied afterwards
            matrix = np.zeros((len(items), len(self.repertoire_ids)), order='F')
            for row, item in enumerate(items):
                # visit only the stored entries; ids that are not among self.repertoire_ids are ignored
                for repertoire_id, value in batch[item].items():
                    for column in columns_by_repertoire.get(repertoire_id, ()):
                        matrix[row, column] = value

            repertoire_index_mapping = {rep_id: ind for ind, rep_id in enumerate(self.repertoire_ids)}
            comp_data_batch = ComparisonDataBatch(matrix=matrix, items=items, repertoire_index_mapping=repertoire_index_mapping,
                                                  path=self.path, identifier=index)
            comp_data_batch.store()
            self.batches.append(comp_data_batch)

        for path in self.tmp_batch_paths:
            os.remove(path)

    @log
    def process_repertoire(self, repertoire, repertoire_id: str, extract_items_fn):
        """
        Extracts the items of one repertoire and registers them for repertoire_id.

        Items already present in a temporary batch are marked there; the remaining ones are appended as new items.
        """
        items = extract_items_fn(repertoire)
        new_items = self.filter_existing_items(items, repertoire_id)
        self.add_items_for_repertoire(new_items, repertoire_id)

    def filter_existing_items(self, items: list, repertoire_id: str) -> list:
        """
        Marks repertoire_id on all items that already exist in a temporary batch.

        Returns:
            list of the items that were not found in any temporary batch
        """
        new_items = items
        for batch_index, batch in enumerate(self.get_tmp_batches()):
            new_items = self._remove_existing_items(new_items, batch, batch_index, repertoire_id)
        return new_items

    def _match_items_to_batch(self, items, batch):
        """
        Splits items into those missing from the batch and those already in it.

        Returns:
            tuple (set of items not in batch, boolean array of ones with one entry per item in batch, set of items in batch)
        """
        unique_items = set(items)
        keep = unique_items.difference(batch)
        item_to_update = unique_items.intersection(batch)
        value = np.ones(len(item_to_update), dtype=np.bool_)
        return keep, value, item_to_update

    def _remove_existing_items(self, new_items: list, batch: dict, batch_index: int, repertoire_id: str) -> list:
        """
        Marks repertoire_id on the items of new_items that are in the batch, stores the batch and returns the other items.
        """
        new_items_to_keep, values, items_to_update = self._match_items_to_batch(new_items, batch)

        for item, value in zip(items_to_update, values):
            batch[item][repertoire_id] = value

        # the batch is written back unconditionally, also when no item matched
        self.store_tmp_batch(batch, batch_index)
        return list(new_items_to_keep)

    def store_tmp_batch(self, batch: dict, batch_index: int):
        """
        Pickles a temporary batch to ``tmp_batch_<batch_index>.pickle`` in self.path.

        An existing batch is overwritten; a batch_index equal to the current number of temporary batches
        registers a new one.

        Raises:
            KeyError: if batch_index is larger than the current number of temporary batches
        """
        known_batch_count = len(self.tmp_batch_paths)
        if batch_index > known_batch_count:
            raise KeyError("ComparisonData: batch_index: {} does not exist. tmp_batches length: {}"
                           .format(batch_index, known_batch_count))

        batch_path = self.path / f'tmp_batch_{batch_index}.pickle'
        with batch_path.open('wb') as file:
            pickle.dump(batch, file)

        if batch_index == known_batch_count:
            self.tmp_batch_paths.append(batch_path)

    def get_tmp_batches(self):
        """Yields every temporary batch (dict), loaded from disk one at a time."""
        for i in range(len(self.tmp_batch_paths)):
            yield self.load_tmp_batch(i)

    def load_tmp_batch(self, batch_index: int) -> dict:
        """
        Loads a temporary batch from disk.

        Returns:
            the unpickled batch, or an empty dict if no temporary batch is registered yet or the file does not exist
        """
        # NOTE: pickle.load can execute arbitrary code; these files are the ones written by store_tmp_batch,
        # so the directory at self.path has to be trusted.
        paths = self.tmp_batch_paths
        if paths and paths[batch_index].is_file():
            with paths[batch_index].open("rb") as file:
                return pickle.load(file)
        return {}

    def add_items_for_repertoire(self, items: list, repertoire_id: str):
        """
        Appends items as new entries with value 1 for repertoire_id.

        The last temporary batch is topped up to sequence_batch_size items first; the remaining items are written
        to new temporary batches of at most sequence_batch_size items each. item_count grows by len(items).
        Items are expected not to be present in any temporary batch yet (see filter_existing_items).
        """
        last_batch_index = max(len(self.tmp_batch_paths) - 1, 0)
        batch = self.load_tmp_batch(last_batch_index)
        self.item_count += len(items)

        item_index = 0
        while len(batch) < self.sequence_batch_size and item_index < len(items):
            batch[items[item_index]] = {repertoire_id: 1}
            item_index += 1
        self.store_tmp_batch(batch, last_batch_index)

        # walk over the remainder in steps instead of re-slicing the list after every batch
        for start in range(item_index, len(items), self.sequence_batch_size):
            last_batch_index += 1
            chunk = items[start: start + self.sequence_batch_size]
            self.store_tmp_batch({item: {repertoire_id: 1} for item in chunk}, last_batch_index)