Source code for immuneML.api.galaxy.build_train_gen_model_specs

import argparse
import sys
from pathlib import Path

from immuneML.api.galaxy.Util import Util
from immuneML.data_model.bnp_util import write_yaml
from immuneML.util.PathBuilder import PathBuilder


[docs] def parse_command_line_arguments(args): parser = argparse.ArgumentParser(description="Tool for building specification for training generative " "models in Galaxy") parser.add_argument("-e", "--gen_example_count", required=True, help="Number of examples to generate.") parser.add_argument("-c", "--chain_type", required=True, choices=["TRA", "TRB", "IGH", "IGK", "IGL", "TRD", "TRG"], help="Chain type of the sequences to generate.") parser.add_argument("-m", "--generative_method", choices=["SoNNia", "SimpleLSTM", "PWM", "SimpleVAE"], required=True, help="Which generative model should be trained.") parser.add_argument("-s", "--sequence_length_report", choices=["True", "False"], default="False", help="Whether to run the SequenceLengthDistribution report.") parser.add_argument("-q", "--amino_acid_report", choices=["True", "False"], default="False", help="Whether to run the AminoAcidFrequencyDistribution report.") parser.add_argument("-k", "--kl_gen_model_report", choices=["True", "False"], default="False", help="Whether to run the KLKmerComparison report.") parser.add_argument("-t", "--training_percentage", type=float, required=True, help="The percentage of data used for training.") parser.add_argument("-x", "--export_dataset_type", choices=["generated_dataset", "combined_dataset"], default="generated_dataset", help="Whether to export only the generated dataset, or the combined training+(test+)generated dataset.") parser.add_argument("-o", "--output_path", required=True, help="Output location for the generated yaml file (directory).") parser.add_argument("-f", "--file_name", default="specs.yaml", help="Output file name for the yaml file. Default name is 'specs.yaml' if not specified.") return parser.parse_args(args)
[docs] def build_specs(parsed_args): reports = {} if parsed_args.sequence_length_report == "True": reports['sequence_length_distribution'] = "SequenceLengthDistribution" if parsed_args.amino_acid_report == "True": reports['amino_acid_frequency_distribution'] = "AminoAcidFrequencyDistribution" if parsed_args.kl_gen_model_report == "True": reports['kl_kmer_comparison'] = "KLKmerComparison" gen_model_args = {"locus": parsed_args.chain_type} if parsed_args.generative_method == "SimpleVAE": reports['generative_model_overview'] = "VAESummary" gen_model_args['region_type'] = 'imgt_cdr3' elif parsed_args.generative_method == "PWM": reports['generative_model_overview'] = "PWMSummary" elif parsed_args.generative_method == "SimpleLSTM": gen_model_args['region_type'] = 'imgt_cdr3' if parsed_args.generative_method == "SoNNia": gen_model_args["default_model_name"] = f"human{parsed_args.chain_type}" specs = { "definitions": { "datasets": { "dataset": { "format": "AIRR", "params": {"dataset_file": Util.discover_dataset_path()} } }, "reports": reports, "ml_methods": { "generative_model": { parsed_args.generative_method: gen_model_args } } }, "instructions": { f"train_{parsed_args.generative_method}": { "type": "TrainGenModel", "dataset": "dataset", "method": "generative_model", "number_of_processes": 8, "gen_examples_count": int(parsed_args.gen_example_count), "training_percentage": float(parsed_args.training_percentage) / 100, "export_generated_dataset": True if parsed_args.export_dataset_type == "generated_dataset" else False, "export_combined_dataset": True if parsed_args.export_dataset_type == "combined_dataset" else False, "reports": list(reports.keys()), } } } return specs
[docs] def main(args): parsed_args = parse_command_line_arguments(args) specs = build_specs(parsed_args) PathBuilder.build(parsed_args.output_path) output_location = Path(parsed_args.output_path) / parsed_args.file_name write_yaml(output_location, specs) return str(output_location)
if __name__ == "__main__": main(sys.argv[1:])