class OneHotReceptorEncoder(OneHotEncoder):
    """
    One-hot encoder for receptor datasets.

    Encoded receptor data is represented as a matrix with dimensions:
        [receptors, chains, sequence_lengths, one_hot_characters]

    When use_positional_info is true, the last 3 indices in one_hot_characters
    represent the positional information:
        - start position (high when close to start)
        - middle position (high in the middle of the sequence)
        - end position (high when close to end)
    """

    def _encode_new_dataset(self, dataset, params: EncoderParams):
        """
        Encode ``dataset`` and return a clone of it carrying the encoding.

        The encoding is computed from the original dataset first; the dataset
        is then cloned and the result is attached as ``encoded_data`` on the
        clone.
        """
        new_encoding = self._encode_data(dataset, params)
        result = dataset.clone()
        result.encoded_data = new_encoding
        return result

    def _encode_data(self, dataset: ReceptorDataset, params: EncoderParams):
        """
        Build the EncodedData object for a receptor dataset.

        Rows of ``dataset.data`` are split into first-chain and second-chain
        rows using the masks from ``EncoderHelper.get_receptor_chain_masks``.
        Each chain is one-hot encoded separately (padded to the longest
        sequence in the whole dataset), and the two results are stacked along
        axis 1, the chain axis. When ``self.flatten`` is set, examples are
        reshaped to one row per receptor and the nested feature names are
        flattened in the same (chain, position, dimension) order.

        NOTE(review): the receptor count is taken as ``len(data) // 2``, which
        assumes every receptor contributes exactly two rows (one per chain).
        Confirm against get_receptor_chain_masks.
        """
        data = dataset.data
        receptor_ids, chains, mask1, mask2 = EncoderHelper.get_receptor_chain_masks(dataset)
        first_chain_rows = data[mask1]
        second_chain_rows = data[mask2]

        n_receptors = len(data) // 2
        seq_field = self._get_seq_field_name(params)
        # Both chains are padded to the longest sequence found in either chain.
        longest = max(getattr(data, seq_field).lengths)

        if params.encode_labels:
            labels = self._get_labels(data, mask1, params)
        else:
            labels = None

        per_chain = [
            self._encode_sequence_list(
                rows,
                pad_n_sequences=n_receptors,
                pad_sequence_len=longest,
                params=params,
            )
            for rows in (first_chain_rows, second_chain_rows)
        ]
        examples = np.stack(per_chain, axis=1)

        feature_names = self._get_feature_names(longest, chains)

        if self.flatten:
            examples = examples.reshape((n_receptors, 2 * longest * len(self.onehot_dimensions)))
            feature_names = [
                name
                for chain_level in feature_names
                for position_level in chain_level
                for name in position_level
            ]

        return EncodedData(
            examples=examples,
            labels=labels,
            example_ids=receptor_ids,
            feature_names=feature_names,
            encoding=OneHotEncoder.__name__,
            info={"chain_names": chains},
        )

    def _get_feature_names(self, max_seq_len, chains):
        """
        Return nested feature names indexed as [chain][position][dimension].

        Each name has the form ``"{chain}_{position}_{dimension}"``, where the
        position runs over ``range(max_seq_len)`` and the dimension over
        ``self.onehot_dimensions``.
        """
        all_names = []
        for chain in chains:
            chain_names = []
            for position in range(max_seq_len):
                chain_names.append([f"{chain}_{position}_{dim}" for dim in self.onehot_dimensions])
            all_names.append(chain_names)
        return all_names

    def _get_labels(self, data, mask1, params: EncoderParams):
        """
        Return a dict mapping each configured label name to a list of values.

        Values are read from the pandas view of ``data`` and restricted to
        the rows selected by ``mask1`` (the first-chain rows), giving one value
        per receptor.
        """
        label_names = params.label_config.get_labels_by_name()
        frame = data.topandas()
        label_values = {}
        for label in label_names:
            label_values[label] = frame[label].values[mask1].tolist()
        return label_values