From c984e93060d9fd6e675b9117bed1584b45cea931 Mon Sep 17 00:00:00 2001 From: Esther Heid Date: Wed, 16 Feb 2022 16:20:58 +0100 Subject: [PATCH] bugfix keep extra columns argument --- README.md | 6 +++--- correct.py | 4 ++-- templatecorr/correct_templates.py | 7 ++++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 9d86b71..b34be0d 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,10 @@ python correct.py --path data/uspto_50k where `--path` specifies the path to the reaction file without file ending. This will create the file data/uspto_50k_corrected.csv, which now contains an additional column `template`, holding the extracted, canonicalized and corrected template. The above command is the same as ``` -python correct.py --path data/uspto_50k --reaction_column rxn_smiles --name template --nproc 20 --drop_extra_cols --data_format csv +python correct.py --path data/uspto_50k --reaction_column rxn_smiles --name template --nproc 20 --data_format csv ``` -where `--reaction_column rxn_smiles` specifies the name of the column containing reaction SMILES, `--name template` sets the name of the column for the extracted templates in the output file (here to "template"), `--nproc 20` parallelizes the program over 20 processes, `--drop_extra_cols` causes additional helper columns during extraction (canonical reactant SMILES, templates at radius 0 and 1) to be dropped before saving the dataframe to file, and `--data_format csv` specifies the input format of the data, as well as the output format. +where `--reaction_column rxn_smiles` specifies the name of the column containing reaction SMILES, `--name template` sets the name of the column for the extracted templates in the output file (here to "template"), `--nproc 20` parallelizes the program over 20 processes, and `--data_format csv` specifies the input format of the data, as well as the output format. 
By default, additional helper columns during extraction (canonical reactant SMILES, templates at radius 0 and 1) are dropped before saving the dataframe to file. To prevent this, use the flag `--keep_extra_cols`. ### Use to retrain a template relevance model @@ -71,7 +71,7 @@ AiZynthFinder template and policy model files are available in the folder `aizyn ### Contact -For questions, feedback, concerns or wishes, contact Esther at eheid@mit.edu. +For questions, feedback, concerns or wishes, contact Esther at eheid@mit.edu or raise an issue on GitHub. ### Copyright diff --git a/correct.py b/correct.py index 4dd7715..8950d3a 100644 --- a/correct.py +++ b/correct.py @@ -7,7 +7,7 @@ def parse_arguments(): parser.add_argument('--reaction_column', dest='reaction_column', default='rxn_smiles') parser.add_argument('--name', dest='name', default='template') parser.add_argument('--nproc', dest='nproc', type=int, default=20) - parser.add_argument('--drop_extra_cols', dest='drop_extra_cols', action='store_false', default=True) + parser.add_argument('--keep_extra_cols', dest='keep_extra_cols', action='store_true', default=False) parser.add_argument('--data_format', dest='data_format', default='csv') return parser.parse_args() @@ -18,6 +18,6 @@ def parse_arguments(): reaction_column = args.reaction_column, name=args.name, nproc=args.nproc, - drop_extra_cols = args.drop_extra_cols, + drop_extra_cols = not args.keep_extra_cols, data_format=args.data_format, save=True) diff --git a/templatecorr/correct_templates.py b/templatecorr/correct_templates.py index eb992b5..7dce488 100644 --- a/templatecorr/correct_templates.py +++ b/templatecorr/correct_templates.py @@ -205,13 +205,18 @@ def templates_from_file(path, reaction_column = "rxn_smiles", name="template", n data = data.dropna(subset=[name,name+"_r0",name+"_r1"]) + data[name+'_uncorrected']=data[name] + print("Hierarchically correcting templates...") data[name+"_r1"] = correct_all_templates(data,name+"_r0",name+"_r1", nproc) 
data[name] = correct_all_templates(data,name+"_r1",name, nproc) if drop_extra_cols: - data = data.drop(columns=["canonical_reac_smiles", name+"_r0",name+"_r1"]) + data = data.drop(columns=["canonical_reac_smiles", name+"_r0",name+"_r1",name+"_uncorrected"]) + if "new_t" in data.keys(): + data = data.drop(columns=["new_t"]) + if save: if data_format == 'csv': data.to_csv(path+"_corrected."+data_format, index=False)