From 3a3fcb180b079e069c5d64c4bb66a010312779fa Mon Sep 17 00:00:00 2001
From: Nathan Morgan
Date: Thu, 24 Oct 2024 14:54:11 -0400
Subject: [PATCH] add separate test path option

---
 chemprop/cli/train.py | 45 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 3 deletions(-)

diff --git a/chemprop/cli/train.py b/chemprop/cli/train.py
index a27aff11b..9b5f856cb 100644
--- a/chemprop/cli/train.py
+++ b/chemprop/cli/train.py
@@ -469,6 +469,18 @@ def add_train_args(parser: ArgumentParser) -> ArgumentParser:
         help="Seed for PyTorch randomness (e.g., random initial weights)",
     )
 
+    test_args = parser.add_argument_group("separate test args")
+    test_args.add_argument("--separate-test-path", type=Path)
+    test_args.add_argument("--separate-test-descriptors-path", type=Path)
+    test_args.add_argument("--separate-test-atom-features-path", type=Path)
+    test_args.add_argument("--separate-test-bond-features-path", type=Path)
+    test_args.add_argument("--separate-test-atom-descriptors-path", type=Path)
+    test_args.add_argument("--separate-test-no-header-row", action="store_true")
+    test_args.add_argument("--separate-test-smiles-columns", nargs="+")
+    test_args.add_argument("--separate-test-reaction-columns", nargs="+")
+    test_args.add_argument("--separate-test-target-columns", nargs="+")
+    test_args.add_argument("--separate-test-ignore-columns", nargs="+")
+
     return parser
 
 
@@ -772,13 +784,40 @@ def build_splits(args, format_kwargs, featurization_kwargs):
         splitting_mols = [datapoint.rct for datapoint in splitting_data]
     else:
         splitting_mols = [datapoint.mol for datapoint in splitting_data]
-    train_indices, val_indices, test_indices = make_split_indices(
-        splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
-    )
+
+    if args.split_sizes[2] == 0:
+        split_sizes = (args.split_sizes[0], 0.0, args.split_sizes[1])
+        train_indices, test_indices, val_indices = make_split_indices(
+            splitting_mols, args.split, split_sizes, args.data_seed, args.num_replicates
+        )
+    else:
+        train_indices, val_indices, test_indices = make_split_indices(
+            splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
+        )
     train_data, val_data, test_data = split_data_by_indices(
         all_data, train_indices, val_indices, test_indices
     )
+
+    if args.separate_test_path:
+        test_data = build_data_from_files(
+            p_data=args.separate_test_path,
+            no_header_row=args.separate_test_no_header_row or args.no_header_row,
+            smiles_cols=args.separate_test_smiles_columns or args.smiles_columns,
+            rxn_cols=args.separate_test_reaction_columns or args.reaction_columns,
+            target_cols=args.separate_test_target_columns or args.target_columns,
+            ignore_cols=args.separate_test_ignore_columns or args.ignore_columns,
+            splits_col=None,
+            weight_col=None,
+            bounded=args.metrics and any(["bounded" in metrics for metrics in args.metrics]),
+            p_descriptors=args.separate_test_descriptors_path,
+            p_atom_feats=args.separate_test_atom_features_path,
+            p_bond_feats=args.separate_test_bond_features_path,
+            p_atom_descs=args.separate_test_atom_descriptors_path,
+            **featurization_kwargs,
+        )
+        test_data = [test_data] * len(train_data)
+
    for i_split in range(len(train_data)):
        sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])]
        logger.info(f"train/val/test split_{i_split} sizes: {sizes}")
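
Note (not part of the patch): the build_splits hunk quietly reorders the split fractions when the configured test fraction is zero, which is easy to miss on a first read. The sketch below is a minimal stand-alone illustration of that index bookkeeping; make_split_indices_stub is a hypothetical stand-in for chemprop's make_split_indices, which additionally supports multiple split strategies and replicates.

# Minimal sketch of the split-size reordering in build_splits; the splitter
# here is a toy stand-in, not chemprop's actual make_split_indices.
def make_split_indices_stub(n, sizes):
    """Return (train, val, test) index lists for fractions `sizes` over n items."""
    n_train, n_val = int(sizes[0] * n), int(sizes[1] * n)
    idx = list(range(n))
    return idx[:n_train], idx[n_train : n_train + n_val], idx[n_train + n_val :]

split_sizes = (0.8, 0.2, 0.0)  # like args.split_sizes with a zero test fraction

if split_sizes[2] == 0:
    # Pass the requested validation fraction in the test slot, then unpack so
    # the empty middle list lands in test_idx and the last list becomes val_idx.
    train_idx, test_idx, val_idx = make_split_indices_stub(
        10, (split_sizes[0], 0.0, split_sizes[1])
    )
else:
    train_idx, val_idx, test_idx = make_split_indices_stub(10, split_sizes)

print(train_idx, val_idx, test_idx)  # test_idx stays empty; --separate-test-path fills it

For the new argument group, each --separate-test-* column and format flag falls back to its training-data counterpart via `or` when left unset, so a separate test file that shares the training file's layout only needs --separate-test-path.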