Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ python correct.py --path data/uspto_50k
where `--path` specifies the path to the reaction file without file ending. This will create the file data/uspto_50k_corrected.csv, which now contains an additional column `template`, holding the extracted, canonicalized and corrected template. The above command is the same as

```
python correct.py --path data/uspto_50k --reaction_column rxn_smiles --name template --nproc 20 --drop_extra_cols --data_format csv
python correct.py --path data/uspto_50k --reaction_column rxn_smiles --name template --nproc 20 --data_format csv
```

where `--reaction_column rxn_smiles` specifies the name of the column containing reaction SMILES, `--name template` sets the name of the column for the extracted templates in the output file (here to "template"), `--nproc 20` parallelizes the program over 20 processes, `--drop_extra_cols` causes additional helper columns during extraction (canonical reactant SMILES, templates at radius 0 and 1) to be dropped before saving the dataframe to file, and `--data_format csv` specifies the input format of the data, as well as the output format.
where `--reaction_column rxn_smiles` specifies the name of the column containing reaction SMILES, `--name template` sets the name of the column for the extracted templates in the output file (here to "template"), `--nproc 20` parallelizes the program over 20 processes, and `--data_format csv` specifies the input format of the data, as well as the output format. Per default, additional helper columns during extraction (canonical reactant SMILES, templates at radius 0 and 1) are dropped before saving the dataframe to file. To prevent this, use the flag `--keep_extra_cols`.

### Use to retrain a template relevance model

Expand All @@ -71,7 +71,7 @@ AiZynthFinder template and policy model files are available in the folder `aizyn

### Contact

For questions, feedback, concerns or wishes, contact Esther at eheid@mit.edu.
For questions, feedback, concerns or wishes, contact Esther at eheid@mit.edu or raise an Issue on Github.

### Copyright

Expand Down
4 changes: 2 additions & 2 deletions correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def parse_arguments():
parser.add_argument('--reaction_column', dest='reaction_column', default='rxn_smiles')
parser.add_argument('--name', dest='name', default='template')
parser.add_argument('--nproc', dest='nproc', type=int, default=20)
parser.add_argument('--drop_extra_cols', dest='drop_extra_cols', action='store_false', default=True)
parser.add_argument('--keep_extra_cols', dest='keep_extra_cols', action='store_true', default=False)
parser.add_argument('--data_format', dest='data_format', default='csv')

return parser.parse_args()
Expand All @@ -18,6 +18,6 @@ def parse_arguments():
reaction_column = args.reaction_column,
name=args.name,
nproc=args.nproc,
drop_extra_cols = args.drop_extra_cols,
drop_extra_cols = not args.keep_extra_cols,
data_format=args.data_format,
save=True)
7 changes: 6 additions & 1 deletion templatecorr/correct_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,18 @@ def templates_from_file(path, reaction_column = "rxn_smiles", name="template", n

data = data.dropna(subset=[name,name+"_r0",name+"_r1"])

data[name+'_uncorrected']=data[name]

print("Hierarchically correcting templates...")
data[name+"_r1"] = correct_all_templates(data,name+"_r0",name+"_r1", nproc)
data[name] = correct_all_templates(data,name+"_r1",name, nproc)

if drop_extra_cols:
data = data.drop(columns=["canonical_reac_smiles", name+"_r0",name+"_r1"])
data = data.drop(columns=["canonical_reac_smiles", name+"_r0",name+"_r1",name+"_uncorrected"])

if "new_t" in data.keys():
data = data.drop(columns=["new_t"])

if save:
if data_format == 'csv':
data.to_csv(path+"_corrected."+data_format, index=False)
Expand Down