From 4610d2aca6513aa9bbf5159e2295ecd0dab47663 Mon Sep 17 00:00:00 2001 From: lpardey Date: Tue, 16 Jul 2024 21:58:56 +0000 Subject: [PATCH 01/18] update filenames --- rdock-utils/rdock_utils/{rbhtfinder => rbhtfinder_original} | 0 .../tests/fixtures/rbhtfinder/{rbhtfinder_input.txt => input.txt} | 0 .../fixtures/rbhtfinder/{rbhtfinder_output.txt => output.txt} | 0 .../rbhtfinder/{rbhtfinder_threshold.txt => threshold.txt} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename rdock-utils/rdock_utils/{rbhtfinder => rbhtfinder_original} (100%) rename rdock-utils/tests/fixtures/rbhtfinder/{rbhtfinder_input.txt => input.txt} (100%) rename rdock-utils/tests/fixtures/rbhtfinder/{rbhtfinder_output.txt => output.txt} (100%) rename rdock-utils/tests/fixtures/rbhtfinder/{rbhtfinder_threshold.txt => threshold.txt} (100%) diff --git a/rdock-utils/rdock_utils/rbhtfinder b/rdock-utils/rdock_utils/rbhtfinder_original similarity index 100% rename from rdock-utils/rdock_utils/rbhtfinder rename to rdock-utils/rdock_utils/rbhtfinder_original diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_input.txt b/rdock-utils/tests/fixtures/rbhtfinder/input.txt similarity index 100% rename from rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_input.txt rename to rdock-utils/tests/fixtures/rbhtfinder/input.txt diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_output.txt b/rdock-utils/tests/fixtures/rbhtfinder/output.txt similarity index 100% rename from rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_output.txt rename to rdock-utils/tests/fixtures/rbhtfinder/output.txt diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_threshold.txt b/rdock-utils/tests/fixtures/rbhtfinder/threshold.txt similarity index 100% rename from rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_threshold.txt rename to rdock-utils/tests/fixtures/rbhtfinder/threshold.txt From d07ac5f5d9156b2371d3260a88b5e1f49b24a79f Mon Sep 17 00:00:00 2001 From: lpardey Date: Wed, 17 Jul 2024 22:26:55 
+0000 Subject: [PATCH 02/18] fix original rbhtfinder (wasn't working) add basic test integration (all passing) generate a new input file with tabs as delimiter (original has spaces as delimiter) --- rdock-utils/requirements.txt | 3 +- .../tests/fixtures/rbhtfinder/input_tabs.txt | 101 ++++++++++++++++++ .../fixtures/rbhtfinder/original_output.txt | 7 ++ rdock-utils/tests/rbhtfinder/__init__.py | 0 rdock-utils/tests/rbhtfinder/conftest.py | 5 + .../tests/rbhtfinder/test_integration.py | 16 +++ 6 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt create mode 100644 rdock-utils/tests/fixtures/rbhtfinder/original_output.txt create mode 100644 rdock-utils/tests/rbhtfinder/__init__.py create mode 100644 rdock-utils/tests/rbhtfinder/conftest.py create mode 100644 rdock-utils/tests/rbhtfinder/test_integration.py diff --git a/rdock-utils/requirements.txt b/rdock-utils/requirements.txt index a113d2c3..1b76bf47 100644 --- a/rdock-utils/requirements.txt +++ b/rdock-utils/requirements.txt @@ -1,2 +1,3 @@ numpy==1.26.2 -openbabel==3.1.1.1 \ No newline at end of file +openbabel==3.1.1.1 +pandas==2.2.2 \ No newline at end of file diff --git a/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt b/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt new file mode 100644 index 00000000..3b261f8f --- /dev/null +++ b/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt @@ -0,0 +1,101 @@ +REC _TITLE1 TOTAL INTER INTRA RESTR VDW +001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 +002 mol00 2.595 -0.601 -1.152 4.347 -11.001 +003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 +004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 +005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 +006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 +007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 +008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 +009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 +010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 +011 mol01 -14.764 -12.244 
-3.069 0.550 -16.362 +012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 +013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 +014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 +015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 +016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 +017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 +018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 +019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 +020 mol01 -15.504 -12.806 -3.140 0.442 -12.497 +021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 +022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 +023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 +024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 +025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 +026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 +027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 +028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 +029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 +030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 +031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 +032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 +033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 +034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 +035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 +036 mol03 -16.844 -13.744 -3.429 0.329 -14.531 +037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 +038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 +039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 +040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 +041 mol04 -7.776 -6.296 -2.270 0.790 -16.535 +042 mol04 6.644 5.519 -0.566 1.691 -0.734 +043 mol04 -3.363 -7.773 0.964 3.446 -13.299 +044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 +045 mol04 -2.875 -5.317 0.643 1.799 -13.852 +046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 +047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 +048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 +049 mol04 -8.574 -9.947 1.073 0.301 -18.351 +050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 +051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 +052 mol05 -5.265 -9.689 0.036 4.387 -16.474 +053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 +054 
mol05 -3.375 -5.926 -1.281 3.832 -14.547 +055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 +056 mol05 -12.771 -15.095 1.703 0.621 -17.161 +057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 +058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 +059 mol05 -3.387 -7.638 1.574 2.678 -16.308 +060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 +061 mol06 -15.764 -17.717 0.853 1.101 -21.131 +062 mol06 -2.956 -7.275 0.313 4.006 -14.833 +063 mol06 -6.103 -12.909 2.281 4.526 -17.262 +064 mol06 1.370 -1.589 -0.619 3.579 -9.989 +065 mol06 0.980 -14.709 0.605 15.084 -20.358 +066 mol06 3.784 -6.808 8.337 2.255 -14.995 +067 mol06 -5.845 -12.679 2.130 4.704 -17.065 +068 mol06 -5.255 -12.309 4.456 2.598 -17.557 +069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 +070 mol06 -8.737 -13.409 3.272 1.400 -17.974 +071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 +072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 +073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 +074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 +075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 +076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 +077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 +078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 +079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 +080 mol07 -2.593 -7.660 3.906 1.162 -10.076 +081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 +082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 +083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 +084 mol08 -24.337 -1.398 -32.330 9.391 -13.655 +085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 +086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 +087 mol08 -10.119 5.962 -25.259 9.178 -7.399 +088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 +089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 +090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 +091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 +092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 +093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 +094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 +095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 +096 mol09 -29.025 -10.924 
-25.192 7.091 -17.062 +097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 +098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 +099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 +100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 diff --git a/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt b/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt new file mode 100644 index 00000000..eba0b51b --- /dev/null +++ b/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt @@ -0,0 +1,7 @@ +FILTER1 NSTEPS1 THR1 PERC1 FILTER2 NSTEPS2 THR2 PERC2 TOP5_INTER ENRICH_INTER TOP5_RESTR ENRICH_RESTR TIME +INTER 3 -10.00 90.00 RESTR 5 1.00 60.00 40.00 0.67 80.00 1.33 0.7800 +INTER 3 -10.00 90.00 RESTR 5 6.00 90.00 100.00 1.11 80.00 0.89 0.9300 +INTER 3 -5.00 100.00 RESTR 5 1.00 70.00 40.00 0.57 100.00 1.43 0.8500 +INTER 3 -5.00 100.00 RESTR 5 6.00 100.00 100.00 1.00 100.00 1.00 1.0000 +INTER 3 0.00 100.00 RESTR 5 1.00 70.00 40.00 0.57 100.00 1.43 0.8500 +INTER 3 0.00 100.00 RESTR 5 6.00 100.00 100.00 1.00 100.00 1.00 1.0000 diff --git a/rdock-utils/tests/rbhtfinder/__init__.py b/rdock-utils/tests/rbhtfinder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rdock-utils/tests/rbhtfinder/conftest.py b/rdock-utils/tests/rbhtfinder/conftest.py new file mode 100644 index 00000000..2a442e05 --- /dev/null +++ b/rdock-utils/tests/rbhtfinder/conftest.py @@ -0,0 +1,5 @@ +from ..conftest import FIXTURES_FOLDER + +RBHTFINDER_FIXTURES_FOLDER = FIXTURES_FOLDER / "rbhtfinder" +EXPECTED_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "output.txt") +RESULT_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "original_output.txt") diff --git a/rdock-utils/tests/rbhtfinder/test_integration.py b/rdock-utils/tests/rbhtfinder/test_integration.py new file mode 100644 index 00000000..5b5466ac --- /dev/null +++ b/rdock-utils/tests/rbhtfinder/test_integration.py @@ -0,0 +1,16 @@ +import pytest + +from rdock_utils.rbhtfinder.main import main + +from .conftest import EXPECTED_OUTPUT_FILE, 
RESULT_OUTPUT_FILE + + +def test_do_nothing(): + with pytest.raises(SystemExit): + main() + + +def test_integration(): + # result = main() + with open(EXPECTED_OUTPUT_FILE, "r") as expected_file, open(RESULT_OUTPUT_FILE, "r") as result_file: + assert result_file.readlines() == expected_file.readlines() From 7ad462d81afe527af3518f56b79e0177d90c4ed1 Mon Sep 17 00:00:00 2001 From: lpardey Date: Wed, 17 Jul 2024 23:02:59 +0000 Subject: [PATCH 03/18] add parser (WIP) --- .../rdock_utils/rbhtfinder/__init__.py | 0 rdock-utils/rdock_utils/rbhtfinder/main.py | 445 ++++++++++++++++++ rdock-utils/rdock_utils/rbhtfinder/parser.py | 163 +++++++ .../rdock_utils/rbhtfinder/rbhtfinder.py | 0 4 files changed, 608 insertions(+) create mode 100644 rdock-utils/rdock_utils/rbhtfinder/__init__.py create mode 100644 rdock-utils/rdock_utils/rbhtfinder/main.py create mode 100644 rdock-utils/rdock_utils/rbhtfinder/parser.py create mode 100644 rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py diff --git a/rdock-utils/rdock_utils/rbhtfinder/__init__.py b/rdock-utils/rdock_utils/rbhtfinder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rdock-utils/rdock_utils/rbhtfinder/main.py b/rdock-utils/rdock_utils/rbhtfinder/main.py new file mode 100644 index 00000000..62e6638b --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/main.py @@ -0,0 +1,445 @@ +import numpy as np + +try: + import pandas as pd +except ImportError: + pd = None +import argparse +import itertools +import multiprocessing +import os +from collections import Counter +from functools import partial + +Filter = dict[str, float] + + +def apply_threshold(scored_poses, column, steps, threshold): + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
+ """ + # minimum score after `steps` per molecule + mins = np.min(scored_poses[:, :steps, column], axis=1) + # return those molecules where the minimum score is less than the threshold + passing_molecules = np.where(mins < threshold)[0] + return passing_molecules + + +def prepare_array(sdreport_array: np.ndarray, name_column: int) -> np.ndarray: + """ + Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses + """ + # print(sdreport_array.shape[1]) + # if name_column >= sdreport_array.shape[1]: + # raise IndexError( + # f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}" + # ) + + # find points in the array where the name_column changes (i.e. we are dealing with a new molecule) and split the array + split_indices = ( + np.where( + sdreport_array[:, name_column] + != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) + )[0] + + 1 + ) + split_array = np.split(sdreport_array, split_indices) + + modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] + number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array + + split_array_clean = sum( + [ + np.array_split(n, n.shape[0] / number_of_poses) + for n in split_array + if not n.shape[0] % number_of_poses and n.shape[0] + ], + [], + ) + + if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99: + print( + f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." 
+ ) + + molecule_array = np.array(split_array_clean) + # overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_array[:, :, name_column] = 0 + return np.array(molecule_array, dtype=float) + + +def calculate_results_for_filter_combination( + filter_combination, + molecule_array, + filters, + min_score_indices, + number_of_validation_mols, +): + """ + For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + """ + # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + mols_passed_threshold = list(range(molecule_array.shape[0])) + filter_percentages = [] + number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + for n, threshold in enumerate(filter_combination): + if n: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) + else: + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] + mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. 
passed all previous filters + n + for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) + if n in mols_passed_threshold + ] + filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) + perc_val = { + k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols + for k, v in min_score_indices.items() + } + return { + "filter_combination": filter_combination, + "perc_val": perc_val, + "filter_percentages": filter_percentages, + "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), + } + + +def write_output(results, filters, number_of_validation_mols, output_file, column_names): + """ + Print results as a table. The number of columns varies depending how many columns the user picked. + """ + with open(output_file, "w") as f: + # write header + for n in range(len(results[0]["filter_combination"])): + f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") + for n in results[0]["perc_val"]: + f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") + f.write(f"ENRICH_{column_names[n]}\t") + f.write("TIME\n") + + # write results + for result in results: + for n, threshold in enumerate(result["filter_combination"]): + f.write( + f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" + ) + for n in result["perc_val"]: + f.write(f"{result['perc_val'][n]*100:.2f}\t") + if result["filter_percentages"][-1]: + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") + else: + f.write("NaN\t") + f.write(f"{result['time']:.4f}\n") + return + + +def select_best_filter_combination(results, max_time, min_perc): + """ + Very debatable how to do this... 
+ Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" + (= percentage of validation compounds / percentage of all compounds); we select the + threshold with the highest enrichment factor + """ + min_max_values = {} + for col in results[0]["perc_val"].keys(): + vals = [result["perc_val"][col] for result in results] + min_max_values[col] = {"min": min(vals), "max": max(vals)} + time_vals = [result["time"] for result in results] + min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} + + combination_scores = [ + sum( + [ + ( + (result["perc_val"][col] - min_max_values[col]["min"]) + / (min_max_values[col]["max"] - min_max_values[col]["min"]) + ) + for col in results[0]["perc_val"].keys() + ] + + [ + (min_max_values["time"]["max"] - result["time"]) + / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) + ] + ) + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 + else 0 + for result in results + ] + return np.argmax(combination_scores) + + +def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): + with open(threshold_file, "w") as f: + # write number of filters to apply + f.write(f"{len(filters) + 1}\n") + # write each filter to a separate line + for n, filtr in enumerate(filters): + f.write( + f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' + ) + # write filter to terminate docking when NRUNS reaches the number of runs used in the input file + f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") + + # write final filters - find strictest filters for all columns and apply them again + filters_by_column = { + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] + for col in set([filtr["column"] for filtr in filters]) + } + # write number of filters (same as number of columns filtered on) 
+ f.write(f"{len(filters_by_column)}\n") + # write filter + for col, values in filters_by_column.items(): + f.write(f"- {column_names[col]} {min(values)},\n") + + +def parse_filter(filter_str: str) -> Filter: + parsed_filter = {} + for item in filter_str.split(","): + key, value = item.split("=") + parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) + parsed_filter["column"] -= 1 + return parsed_filter + + +def main(): + """ + Parse arguments; read in data; calculate filter combinations and apply them; print results + """ + parser = argparse.ArgumentParser( + description="""Estimate the results and computation time of an rDock high +throughput protocol. The following steps should be followed: +1) exhaustive docking of a small representative part of the entire + library. +2) Store the result of sdreport -t over that exhaustive docking run + in a file which will be the input of this script. +3) Run rbhtfinder, specifying -i and an arbitrary + number of filters specified using the -f option, for example + "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". This example + would simulate the effect of applying thresholds on column 6 after + 5 poses have been generated, for values between 0.5 and 1.0 (i.e. + 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). More than one threshold can be + specified, e.g., "-f column=4,steps=5,min=-12,max=-10,interval=1 + -f column=4,steps=15,min=-16,max=-15,interval=1" will test the + following combinations of thresholds on column 4: + 5 -10 15 -15 + 5 -11 15 -15 + 5 -12 15 -15 + 5 -10 15 -16 + 5 -11 15 -16 + 5 -12 15 -16 + The number of combinations will increase very rapidly, the more + filters are used and the larger the range of values specified for + each. It may be sensible to run rbhtfinder several times to explore + the effects of various filters independently. + + The output of the program consists of the following columns. 
+ FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME + SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 + SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 + The four columns are repeated for each filter specified with the -f + option: name of the column on which the filter is applied + (FILTER1), number of steps at which the threshold is applied + (NSTEPS1), value of the threshold (THR1) and the percentage of + poses which pass this filter (PERC1). Additional filters (FILTER2, + FILTER3 etc.) are listed in the order that they are applied (i.e. + by NSTEPS). + + The final columns provide some overall statistics for the + combination of thresholds specified in a row. TOP500_SCORE.INTER + gives the percentage of the top-scoring 500 poses, measured by + SCORE.INTER, from the whole of which are retained + after the thresholds are applied. This can be contrasted with the + final PERC column. The higher the ratio (the 'enrichment factor'), + the better the combination of thresholds. If thresholds are applied + on multiple columns, this column will be duplicated for each, e.g. + TOP500_SCORE.INTER and TOP500_SCORE.RESTR will give the percentage + of the top-scoring poses retained for both of these scoring + methods. The exact number of poses used for this validation can be + changed from the default 500 using the --validation flag. + ENRICH_SCORE.INTER gives the enrichment factor as a quick + rule-of-thumb to assess the best choice of thresholds. The final + column TIME provides an estimate of the time taken to perform + docking, as a proportion of the time taken for exhaustive docking. + This value should be below 0.1. + + After a combination of thresholds has been selected, they need to + be encoded into a threshold file which rDock can use as an input. + rbhtfinder attempts to help with this task by automatically + selecting a combination and writing a threshold file. 
The + combination chosen is that which provides the highest enrichment + factor, after all options with a TIME value over 0.1 are excluded. + This choice should not be blindly followed, so the threshold file + should be considered a template that the user modifies as needed. + + rbhtfinder requires NumPy. Installation of pandas is recommended, + but optional; if pandas is not available, loading the input file + for calculations will be considerably slower. + + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-i", + "--input", + help="Input from sdreport (tabular separated format).", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output", + help="Output file for report on threshold combinations.", + type=str, + required=True, + ) + parser.add_argument( + "-t", + "--threshold", + help="Threshold file used by rDock as input.", + type=str, + ) + parser.add_argument( + "-n", + "--name", + type=int, + default=1, # Actually, index of molecule name in input file is 1 by default + help="Index of column containing the molecule name. Default is 2.", + ) + parser.add_argument( + "-f", + "--filter", + nargs="+", + # action="append" removed in favor of simplicity and to avoid redundancy (unnecessary nested structure), + type=str, + help="Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", + ) + parser.add_argument( + "-v", + "--validation", + type=int, + default=500, + help="Top-scoring N molecules from input to use for validating threshold combinations. Default is 500.", + ) + parser.add_argument( + "--header", + action="store_true", + help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. 
COL4, COL5.", + ) + parser.add_argument( + "--max-time", + type=float, + default=0.1, + help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take.", + ) + parser.add_argument( + "--min-perc", + type=float, + default=1.0, + help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", + ) + # input + # python rdock_utils/rbhtfinder_original_copy -i tests/fixtures/rbhtfinder/input.txt -o tests/fixtures/rbhtfinder/output.txt -t tests/fixtures/rbhtfinder/threshold.txt -f column=4,steps=3,min=-10.0,max=0.0,interval=5.0 column=6,steps=5,min=1.0,max=5.0,interval=5.0 --max-time 1 --min-perc 1.0 -v 5 --header + + args = parser.parse_args() + + print(f"args.filter: {args.filter}") + # create filters dictionary from args.filter passed in + filters = [parse_filter(filter) for filter in args.filter] + print(f"parsed filters: {filters}") + # sort filters by step at which they are applied + filters.sort(key=lambda n: n["steps"]) + + # generates all possible combinations from filters provided + fils = [(filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) for filtr in filters] + print(f"fils: {fils}") + filter_combinations = list(itertools.product(*(np.arange(*n) for n in fils))) + print(f"combinations {filter_combinations}") + print(f"{len(filter_combinations)} combinations of filters calculated.") + + # remove redundant combinations, i.e. 
where filters for later steps are less or equally strict to earlier steps + filter_combinations = np.array(filter_combinations) + cols = [filtr["column"] for filtr in filters] + indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} + filter_combination_indices_to_keep = range(len(filter_combinations)) + for col, indices in indices_per_col.items(): + filter_combination_indices_to_keep = [ + n + for n, comb in enumerate(filter_combinations[:, indices]) + if list(comb) == sorted(comb, reverse=True) + and len(set(comb)) == comb.shape[0] + and n in filter_combination_indices_to_keep + ] + filter_combinations = filter_combinations[filter_combination_indices_to_keep] + + if len(filter_combinations): + print( + f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." + ) + else: + print("No filter combinations could be calculated - check the thresholds specified.") + exit(1) + + if pd: + # pandas is weird... 
i.e., skip line 0 if there's a header, else read all lines + header = 0 if args.header else None + sdreport_dataframe = pd.read_csv(args.input, sep="\t", header=header) + if args.header: + column_names = sdreport_dataframe.columns.values + else: + # use index names; add 1 to deal with zero-based numbering + column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] + sdreport_array = sdreport_dataframe.values + print(f"First few rows of the input array:\n{sdreport_array[:5]}") + else: # pd not available + np_array = np.loadtxt(args.input, dtype=str) + if args.header: + column_names = np_array[0] + sdreport_array = np_array[1:] + else: + column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] + sdreport_array = np_array + print("Data read in from input file.") + + # convert to 3D array (molecules x poses x columns) + molecule_array = prepare_array(sdreport_array, args.name) + + # find the top scoring compounds for validation of the filter combinations + min_score_indices = {} + for column in set(filtr["column"] for filtr in filters): + min_scores = np.min(molecule_array[:, :, column], axis=1) + min_score_indices[column] = np.argpartition(min_scores, args.validation)[: args.validation] + + results = [] + + pool = multiprocessing.Pool(os.cpu_count()) + results = pool.map( + partial( + calculate_results_for_filter_combination, + molecule_array=molecule_array, + filters=filters, + min_score_indices=min_score_indices, + number_of_validation_mols=args.validation, + ), + filter_combinations, + ) + + write_output(results, filters, args.validation, args.output, column_names) + + best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) + if args.threshold: + if best_filter_combination: + write_threshold_file( + filters, + filter_combinations[best_filter_combination], + args.threshold, + column_names, + molecule_array.shape[1], + ) + else: + print( + "Filter combinations defined are too strict or would take too 
long to run; no threshold file was written." + ) + exit(1) diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py new file mode 100644 index 00000000..50e85967 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -0,0 +1,163 @@ +import argparse +from dataclasses import dataclass + +Filter = dict[str, float] + + +@dataclass +class rbhtfinderConfig: + input: str + output: str + threshold: str + name: int + filters: list[str] + validation: int + header: bool + max_time: float + min_percentage: float + + def __post_init__(self): + self.filters = self.get_parsed_filters() + + def get_parsed_filters(self) -> list[Filter]: + parsed_filters = [self._parse_filter(filter) for filter in self.filters] + return parsed_filters + + @staticmethod + def _parse_filter(filter_str: str) -> Filter: + parsed_filter = {} + + for item in filter_str.split(","): + key, value = item.split("=") + parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) + # User inputs with 1-based numbering whereas python uses 0-based + parsed_filter["column"] -= 1 + + return parsed_filter + + +def get_parser() -> argparse.ArgumentParser: + description = """ + Estimate the results and computation time of an rDock high-throughput protocol. + + Steps: + 1. Perform exhaustive docking of a small representative part of the entire library. + 2. Store the result of sdreport -t from that exhaustive docking run in a file + , which will be the input of this script. + 3. Run rbhtfinder, specifying -i and an arbitrary number of filters + using the -f option, for example, "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". + This example would simulate the effect of applying thresholds on column 6 after 5 poses + have been generated, for values between 0.5 and 1.0 (i.e., 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). 
+ More than one threshold can be specified, e.g., "-f column=4,steps=5,min=-12,max=-10, + interval=1 column=4,steps=15,min=-16,max=-15,interval=1" will test the following + combinations of thresholds on column 4: + 5 -10 15 -15 + 5 -11 15 -15 + 5 -12 15 -15 + 5 -10 15 -16 + 5 -11 15 -16 + 5 -12 15 -16 + The number of combinations will increase very rapidly, the more filters are used and the + larger the range of values specified for each. It may be sensible to run rbhtfinder several + times to explore the effects of various filters independently. + + Output: + The output of the program consists of the following columns: + FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME + SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 + SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 + The four columns are repeated for each filter specified with the -f option: + name of the column on which the filter is applied (FILTER1), + number of steps at which the threshold is applied (NSTEPS1), + value of the threshold (THR1) + and the percentage of poses which pass this filter (PERC1). + Additional filters (FILTER2, FILTER3 etc.) are listed in the order that they are applied + (i.e., by NSTEPS). + + The final columns provide some overall statistics for the combination of thresholds + specified in a row. TOP500_SCORE.INTER gives the percentage of the top-scoring 500 poses, + measured by SCORE.INTER, from the whole of which are retained after the + thresholds are applied. This can be contrasted with the final PERC column. The higher the + ratio (the 'enrichment factor'), the better the combination of thresholds. If thresholds are + applied on multiple columns, this column will be duplicated for each, e.g. TOP500_SCORE.INTER + and TOP500_SCORE.RESTR will give the percentage of the top-scoring poses retained for both of + these scoring methods. The exact number of poses used for this validation can be changed from + the default 500 using the --validation flag. 
+ ENRICH_SCORE.INTER gives the enrichment factor as a quick rule-of-thumb to assess the best + choice of thresholds. The final column TIME provides an estimate of the time taken to perform + docking, as a proportion of the time taken for exhaustive docking. This value should be below + 0.1. + + After a combination of thresholds has been selected, they need to be encoded into a threshold + file which rDock can use as an input. rbhtfinder attempts to help with this task by + automatically selecting a combination and writing a threshold file. The combination chosen is + that which provides the highest enrichment factor, after all options with a TIME value over + 0.1 are excluded. This choice should not be blindly followed, so the threshold file should be + considered a template that the user modifies as needed. + + Requirements: + rbhtfinder requires NumPy. Installation of pandas is recommended, but optional; if pandas is + not available, loading the input file for calculations will be considerably slower. + """ + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "-i", + "--input", + help="Input from sdreport (tabular separated format).", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output", + help="Output file for report on threshold combinations.", + type=str, + required=True, + ) + parser.add_argument( + "-t", + "--threshold", + help="Threshold file used by rDock as input.", + type=str, + ) + parser.add_argument( + "-n", + "--name", + type=int, + default=1, # Actually, index of molecule name in input file is 1 by default + help="Index of column containing the molecule name. Default is 2.", + ) + parser.add_argument( + "-f", + "--filter", + nargs="+", + # action="append" removed in favor of simplicity and to avoid redundancy (unnecessary nested structure), + type=str, + help="Filter to apply, e.g. 
column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", + ) + parser.add_argument( + "-v", + "--validation", + type=int, + default=500, + help="Top-scoring N molecules from input to use for validating threshold combinations. Default is 500.", + ) + parser.add_argument( + "--header", + action="store_true", + help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5.", + ) + parser.add_argument( + "--max-time", + type=float, + default=0.1, + help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take.", + ) + parser.add_argument( + "--min-perc", + type=float, + default=1.0, + help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", + ) + + args = parser.parse_args() diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py new file mode 100644 index 00000000..e69de29b From 6661bd8b18a78099f1b8e8730fc0a93af11c0de8 Mon Sep 17 00:00:00 2001 From: lpardey Date: Thu, 18 Jul 2024 23:32:16 +0000 Subject: [PATCH 04/18] finish parser add rbhtfinder_original as python module add test integration --- rdock-utils/rdock_utils/rbhtfinder/main.py | 195 ++------ rdock-utils/rdock_utils/rbhtfinder/parser.py | 101 ++-- rdock-utils/rdock_utils/rbhtfinder_original | 73 +-- .../rdock_utils/rbhtfinder_original_copy.py | 443 ++++++++++++++++++ rdock-utils/tests/rbhtfinder/conftest.py | 39 +- .../tests/rbhtfinder/test_integration.py | 26 +- 6 files changed, 587 insertions(+), 290 deletions(-) create mode 100755 
rdock-utils/rdock_utils/rbhtfinder_original_copy.py diff --git a/rdock-utils/rdock_utils/rbhtfinder/main.py b/rdock-utils/rdock_utils/rbhtfinder/main.py index 62e6638b..d9876a73 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/main.py +++ b/rdock-utils/rdock_utils/rbhtfinder/main.py @@ -4,13 +4,14 @@ import pandas as pd except ImportError: pd = None -import argparse import itertools import multiprocessing import os from collections import Counter from functools import partial +from .parser import get_config + Filter = dict[str, float] @@ -204,156 +205,18 @@ def parse_filter(filter_str: str) -> Filter: return parsed_filter -def main(): - """ - Parse arguments; read in data; calculate filter combinations and apply them; print results - """ - parser = argparse.ArgumentParser( - description="""Estimate the results and computation time of an rDock high -throughput protocol. The following steps should be followed: -1) exhaustive docking of a small representative part of the entire - library. -2) Store the result of sdreport -t over that exhaustive docking run - in a file which will be the input of this script. -3) Run rbhtfinder, specifying -i and an arbitrary - number of filters specified using the -f option, for example - "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". This example - would simulate the effect of applying thresholds on column 6 after - 5 poses have been generated, for values between 0.5 and 1.0 (i.e. - 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). More than one threshold can be - specified, e.g., "-f column=4,steps=5,min=-12,max=-10,interval=1 - -f column=4,steps=15,min=-16,max=-15,interval=1" will test the - following combinations of thresholds on column 4: - 5 -10 15 -15 - 5 -11 15 -15 - 5 -12 15 -15 - 5 -10 15 -16 - 5 -11 15 -16 - 5 -12 15 -16 - The number of combinations will increase very rapidly, the more - filters are used and the larger the range of values specified for - each. 
It may be sensible to run rbhtfinder several times to explore - the effects of various filters independently. - - The output of the program consists of the following columns. - FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME - SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 - SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 - The four columns are repeated for each filter specified with the -f - option: name of the column on which the filter is applied - (FILTER1), number of steps at which the threshold is applied - (NSTEPS1), value of the threshold (THR1) and the percentage of - poses which pass this filter (PERC1). Additional filters (FILTER2, - FILTER3 etc.) are listed in the order that they are applied (i.e. - by NSTEPS). - - The final columns provide some overall statistics for the - combination of thresholds specified in a row. TOP500_SCORE.INTER - gives the percentage of the top-scoring 500 poses, measured by - SCORE.INTER, from the whole of which are retained - after the thresholds are applied. This can be contrasted with the - final PERC column. The higher the ratio (the 'enrichment factor'), - the better the combination of thresholds. If thresholds are applied - on multiple columns, this column will be duplicated for each, e.g. - TOP500_SCORE.INTER and TOP500_SCORE.RESTR will give the percentage - of the top-scoring poses retained for both of these scoring - methods. The exact number of poses used for this validation can be - changed from the default 500 using the --validation flag. - ENRICH_SCORE.INTER gives the enrichment factor as a quick - rule-of-thumb to assess the best choice of thresholds. The final - column TIME provides an estimate of the time taken to perform - docking, as a proportion of the time taken for exhaustive docking. - This value should be below 0.1. - - After a combination of thresholds has been selected, they need to - be encoded into a threshold file which rDock can use as an input. 
- rbhtfinder attempts to help with this task by automatically - selecting a combination and writing a threshold file. The - combination chosen is that which provides the highest enrichment - factor, after all options with a TIME value over 0.1 are excluded. - This choice should not be blindly followed, so the threshold file - should be considered a template that the user modifies as needed. - - rbhtfinder requires NumPy. Installation of pandas is recommended, - but optional; if pandas is not available, loading the input file - for calculations will be considerably slower. - - """, - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "-i", - "--input", - help="Input from sdreport (tabular separated format).", - type=str, - required=True, - ) - parser.add_argument( - "-o", - "--output", - help="Output file for report on threshold combinations.", - type=str, - required=True, - ) - parser.add_argument( - "-t", - "--threshold", - help="Threshold file used by rDock as input.", - type=str, - ) - parser.add_argument( - "-n", - "--name", - type=int, - default=1, # Actually, index of molecule name in input file is 1 by default - help="Index of column containing the molecule name. Default is 2.", - ) - parser.add_argument( - "-f", - "--filter", - nargs="+", - # action="append" removed in favor of simplicity and to avoid redundancy (unnecessary nested structure), - type=str, - help="Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", - ) - parser.add_argument( - "-v", - "--validation", - type=int, - default=500, - help="Top-scoring N molecules from input to use for validating threshold combinations. 
Default is 500.", - ) - parser.add_argument( - "--header", - action="store_true", - help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5.", - ) - parser.add_argument( - "--max-time", - type=float, - default=0.1, - help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take.", - ) - parser.add_argument( - "--min-perc", - type=float, - default=1.0, - help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", - ) - # input - # python rdock_utils/rbhtfinder_original_copy -i tests/fixtures/rbhtfinder/input.txt -o tests/fixtures/rbhtfinder/output.txt -t tests/fixtures/rbhtfinder/threshold.txt -f column=4,steps=3,min=-10.0,max=0.0,interval=5.0 column=6,steps=5,min=1.0,max=5.0,interval=5.0 --max-time 1 --min-perc 1.0 -v 5 --header - - args = parser.parse_args() - - print(f"args.filter: {args.filter}") +def main(argv: list[str] | None = None) -> None: + config = get_config(argv) + print(f"args.filter: {config.filters}") # create filters dictionary from args.filter passed in - filters = [parse_filter(filter) for filter in args.filter] - print(f"parsed filters: {filters}") - # sort filters by step at which they are applied - filters.sort(key=lambda n: n["steps"]) + # filters = [parse_filter(filter) for filter in args.filter] + # print(f"parsed filters: {filters}") # generates all possible combinations from filters provided - fils = [(filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) for filtr in filters] + fils = [ + (filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) + for filtr in config.filters + ] print(f"fils: {fils}") filter_combinations = list(itertools.product(*(np.arange(*n) for n in fils))) print(f"combinations 
{filter_combinations}") @@ -361,7 +224,7 @@ def main(): # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps filter_combinations = np.array(filter_combinations) - cols = [filtr["column"] for filtr in filters] + cols = [filtr["column"] for filtr in config.filters] indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} filter_combination_indices_to_keep = range(len(filter_combinations)) for col, indices in indices_per_col.items(): @@ -384,9 +247,9 @@ def main(): if pd: # pandas is weird... i.e., skip line 0 if there's a header, else read all lines - header = 0 if args.header else None - sdreport_dataframe = pd.read_csv(args.input, sep="\t", header=header) - if args.header: + header = 0 if config.header else None + sdreport_dataframe = pd.read_csv(config.input, sep="\t", header=header) + if config.header: column_names = sdreport_dataframe.columns.values else: # use index names; add 1 to deal with zero-based numbering @@ -394,8 +257,8 @@ def main(): sdreport_array = sdreport_dataframe.values print(f"First few rows of the input array:\n{sdreport_array[:5]}") else: # pd not available - np_array = np.loadtxt(args.input, dtype=str) - if args.header: + np_array = np.loadtxt(config.input, dtype=str) + if config.header: column_names = np_array[0] sdreport_array = np_array[1:] else: @@ -404,13 +267,13 @@ def main(): print("Data read in from input file.") # convert to 3D array (molecules x poses x columns) - molecule_array = prepare_array(sdreport_array, args.name) + molecule_array = prepare_array(sdreport_array, config.name) # find the top scoring compounds for validation of the filter combinations min_score_indices = {} - for column in set(filtr["column"] for filtr in filters): + for column in set(filtr["column"] for filtr in config.filters): min_scores = np.min(molecule_array[:, :, column], axis=1) - min_score_indices[column] = np.argpartition(min_scores, 
args.validation)[: args.validation] + min_score_indices[column] = np.argpartition(min_scores, config.validation)[: config.validation] results = [] @@ -419,22 +282,22 @@ def main(): partial( calculate_results_for_filter_combination, molecule_array=molecule_array, - filters=filters, + filters=config.filters, min_score_indices=min_score_indices, - number_of_validation_mols=args.validation, + number_of_validation_mols=config.validation, ), filter_combinations, ) - write_output(results, filters, args.validation, args.output, column_names) + write_output(results, config.filters, config.validation, config.output, column_names) - best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) - if args.threshold: + best_filter_combination = select_best_filter_combination(results, config.max_time, config.min_percentage) + if config.threshold: if best_filter_combination: write_threshold_file( - filters, + config.filters, filter_combinations[best_filter_combination], - args.threshold, + config.threshold, column_names, molecule_array.shape[1], ) @@ -443,3 +306,7 @@ def main(): "Filter combinations defined are too strict or would take too long to run; no threshold file was written." 
) exit(1) + + +if __name__ == "__main__": + main() diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 50e85967..2d687512 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -10,17 +10,19 @@ class rbhtfinderConfig: output: str threshold: str name: int - filters: list[str] + filters: list[Filter] validation: int header: bool max_time: float min_percentage: float - def __post_init__(self): + def __post_init__(self) -> None: self.filters = self.get_parsed_filters() def get_parsed_filters(self) -> list[Filter]: parsed_filters = [self._parse_filter(filter) for filter in self.filters] + # sort filters by step at which they are applied + parsed_filters.sort(key=lambda n: n["steps"]) return parsed_filters @staticmethod @@ -99,65 +101,40 @@ def get_parser() -> argparse.ArgumentParser: rbhtfinder requires NumPy. Installation of pandas is recommended, but optional; if pandas is not available, loading the input file for calculations will be considerably slower. """ + input_help = "Input from sdreport (tabular separated format)." + output_help = "Output file for report on threshold combinations." + threshold_help = "Threshold file used by rDock as input." + name_help = "Index of column containing the molecule name (0 indexed). Default is 1." + filter_help = "Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given." + validation_help = "Top-scoring N molecules from input to use for validating threshold combinations. Default 500." + header_help = "Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5." 
+ max_time_help = "Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take." + min_perc_help = "Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1." + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument( - "-i", - "--input", - help="Input from sdreport (tabular separated format).", - type=str, - required=True, - ) - parser.add_argument( - "-o", - "--output", - help="Output file for report on threshold combinations.", - type=str, - required=True, - ) - parser.add_argument( - "-t", - "--threshold", - help="Threshold file used by rDock as input.", - type=str, - ) - parser.add_argument( - "-n", - "--name", - type=int, - default=1, # Actually, index of molecule name in input file is 1 by default - help="Index of column containing the molecule name. Default is 2.", - ) - parser.add_argument( - "-f", - "--filter", - nargs="+", - # action="append" removed in favor of simplicity and to avoid redundancy (unnecessary nested structure), - type=str, - help="Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. 
The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", + parser.add_argument("-i", "--input", help=input_help, type=str, required=True) + parser.add_argument("-o", "--output", help=output_help, type=str, required=True) + parser.add_argument("-t", "--threshold", help=threshold_help, type=str) + parser.add_argument("-n", "--name", type=int, default=1, help=name_help) + parser.add_argument("-f", "--filters", nargs="+", type=str, help=filter_help) + parser.add_argument("-v", "--validation", type=int, default=500, help=validation_help) + parser.add_argument("--header", action="store_true", help=header_help) + parser.add_argument("--max-time", type=float, default=0.1, help=max_time_help) + parser.add_argument("--min-perc", type=float, default=1.0, help=min_perc_help) + return parser + + +def get_config(argv: list[str] | None = None) -> rbhtfinderConfig: + parser = get_parser() + args = parser.parse_args(argv) + return rbhtfinderConfig( + input=args.input, + output=args.output, + threshold=args.threshold, + name=args.name, + filters=args.filters, + validation=args.validation, + header=args.header, + max_time=args.max_time, + min_percentage=args.min_perc, ) - parser.add_argument( - "-v", - "--validation", - type=int, - default=500, - help="Top-scoring N molecules from input to use for validating threshold combinations. Default is 500.", - ) - parser.add_argument( - "--header", - action="store_true", - help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5.", - ) - parser.add_argument( - "--max-time", - type=float, - default=0.1, - help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 
10%% of the time exhaustive docking would take.", - ) - parser.add_argument( - "--min-perc", - type=float, - default=1.0, - help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", - ) - - args = parser.parse_args() diff --git a/rdock-utils/rdock_utils/rbhtfinder_original b/rdock-utils/rdock_utils/rbhtfinder_original index 9e2123c5..c5162bb9 100755 --- a/rdock-utils/rdock_utils/rbhtfinder_original +++ b/rdock-utils/rdock_utils/rbhtfinder_original @@ -11,7 +11,6 @@ import argparse import itertools import multiprocessing import os -import sys from collections import Counter from functools import partial from pathlib import Path @@ -37,16 +36,12 @@ def prepare_array(sdreport_array, name_column): sdreport_array, np.where( sdreport_array[:, name_column] - != np.hstack( - (sdreport_array[1:, name_column], sdreport_array[0, name_column]) - ) + != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) )[0] + 1, ) modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] - number_of_poses = modal_shape[0][ - 0 - ] # find modal number of poses per molecule in the array + number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array split_array_clean = sum( [ @@ -85,24 +80,16 @@ def calculate_results_for_filter_combination( for n, threshold in enumerate(filter_combination): if n: # e.g. 
if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - number_of_simulated_poses += len(mols_passed_threshold) * ( - filters[n]["steps"] - filters[n - 1]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) else: - number_of_simulated_poses += ( - len(mols_passed_threshold) * filters[n]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. passed all previous filters n - for n in apply_threshold( - molecule_array, filters[n]["column"], filters[n]["steps"], threshold - ) + for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) if n in mols_passed_threshold ] filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) - number_of_simulated_poses += len(mols_passed_threshold) * ( - molecule_array.shape[1] - filters[-1]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) perc_val = { k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols for k, v in min_score_indices.items() @@ -115,9 +102,7 @@ def calculate_results_for_filter_combination( } -def write_output( - results, filters, number_of_validation_mols, output_file, column_names -): +def write_output(results, filters, number_of_validation_mols, output_file, column_names): """ Print results as a table. The number of columns varies depending how many columns the user picked. 
""" @@ -139,9 +124,7 @@ def write_output( for n in result["perc_val"]: f.write(f"{result['perc_val'][n]*100:.2f}\t") if result["filter_percentages"][-1]: - f.write( - f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t" - ) + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") else: f.write("NaN\t") f.write(f"{result['time']:.4f}\n") @@ -176,17 +159,14 @@ def select_best_filter_combination(results, max_time, min_perc): / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) ] ) - if result["time"] < max_time - and result["filter_percentages"][-1] >= min_perc / 100 + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 else 0 for result in results ] return np.argmax(combination_scores) -def write_threshold_file( - filters, best_filter_combination, threshold_file, column_names, max_number_of_runs -): +def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): with open(threshold_file, "w") as f: # write number of filters to apply f.write(f"{len(filters) + 1}\n") @@ -200,11 +180,7 @@ def write_threshold_file( # write final filters - find strictest filters for all columns and apply them again filters_by_column = { - col: [ - best_filter_combination[n] - for n, filtr in enumerate(filters) - if filtr["column"] == col - ] + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] for col in set([filtr["column"] for filtr in filters]) } # write number of filters (same as number of columns filtered on) @@ -355,15 +331,9 @@ throughput protocol. 
The following steps should be followed: args.name -= 1 # because np arrays need 0-based indices # create filters dictionary from args.filter passed in + filters = [dict([n.split("=") for n in filtr[0].split(",")]) for filtr in args.filter] filters = [ - dict([n.split("=") for n in filtr[0].split(",")]) for filtr in args.filter - ] - filters = [ - { - k: float(v) if k in ["interval", "min", "max"] else int(v) - for k, v in filtr.items() - } - for filtr in filters + {k: float(v) if k in ["interval", "min", "max"] else int(v) for k, v in filtr.items()} for filtr in filters ] for filtr in filters: @@ -394,10 +364,7 @@ throughput protocol. The following steps should be followed: # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps filter_combinations = np.array(filter_combinations) cols = [filtr["column"] for filtr in filters] - indices_per_col = { - col: [n for n, filter_col in enumerate(cols) if col == filter_col] - for col in set(cols) - } + indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} filter_combination_indices_to_keep = range(len(filter_combinations)) for col, indices in indices_per_col.items(): filter_combination_indices_to_keep = [ @@ -414,9 +381,7 @@ throughput protocol. The following steps should be followed: f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." ) else: - print( - "No filter combinations could be calculated - check the thresholds specified." - ) + print("No filter combinations could be calculated - check the thresholds specified.") exit(1) if pd: @@ -446,9 +411,7 @@ throughput protocol. 
The following steps should be followed: min_score_indices = {} for column in set(filtr["column"] for filtr in filters): min_scores = np.min(molecule_array[:, :, column], axis=1) - min_score_indices[column] = np.argpartition(min_scores, args.validation)[ - : args.validation - ] + min_score_indices[column] = np.argpartition(min_scores, args.validation)[: args.validation] results = [] @@ -466,9 +429,7 @@ throughput protocol. The following steps should be followed: write_output(results, filters, args.validation, args.output, column_names) - best_filter_combination = select_best_filter_combination( - results, args.max_time, args.min_perc - ) + best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) if args.threshold: if best_filter_combination: write_threshold_file( diff --git a/rdock-utils/rdock_utils/rbhtfinder_original_copy.py b/rdock-utils/rdock_utils/rbhtfinder_original_copy.py new file mode 100755 index 00000000..07df5f52 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder_original_copy.py @@ -0,0 +1,443 @@ +import numpy as np + +try: + import pandas as pd +except ImportError: + pd = None +import argparse +import itertools +import multiprocessing +import os +from collections import Counter +from functools import partial + +Filter = dict[str, float] + + +def apply_threshold(scored_poses, column, steps, threshold): + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
+    """
+    # minimum score after `steps` per molecule
+    mins = np.min(scored_poses[:, :steps, column], axis=1)
+    # return those molecules where the minimum score is less than the threshold
+    passing_molecules = np.where(mins < threshold)[0]
+    return passing_molecules
+
+
+def prepare_array(sdreport_array: np.ndarray, name_column: int) -> np.ndarray:
+    """
+    Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses
+    """
+    # print(sdreport_array.shape[1])
+    # if name_column >= sdreport_array.shape[1]:
+    #     raise IndexError(
+    #         f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}"
+    #     )
+
+    # find points in the array where the name_column changes (i.e. we are dealing with a new molecule) and split the array
+    split_indices = (
+        np.where(
+            sdreport_array[:, name_column]
+            != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column]))
+        )[0]
+        + 1
+    )
+    split_array = np.split(sdreport_array, split_indices)
+
+    modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0]
+    number_of_poses = modal_shape[0][0]  # find modal number of poses per molecule in the array
+
+    split_array_clean = sum(
+        [
+            np.array_split(n, n.shape[0] / number_of_poses)
+            for n in split_array
+            if not n.shape[0] % number_of_poses and n.shape[0]
+        ],
+        [],
+    )
+
+    if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99:
+        print(
+            f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} molecules have {number_of_poses} poses."
+ ) + + molecule_array = np.array(split_array_clean) + # overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_array[:, :, name_column] = 0 + return np.array(molecule_array, dtype=float) + + +def calculate_results_for_filter_combination( + filter_combination, + molecule_array, + filters, + min_score_indices, + number_of_validation_mols, +): + """ + For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + """ + # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + mols_passed_threshold = list(range(molecule_array.shape[0])) + filter_percentages = [] + number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + for n, threshold in enumerate(filter_combination): + if n: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) + else: + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] + mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. 
passed all previous filters + n + for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) + if n in mols_passed_threshold + ] + filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) + perc_val = { + k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols + for k, v in min_score_indices.items() + } + return { + "filter_combination": filter_combination, + "perc_val": perc_val, + "filter_percentages": filter_percentages, + "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), + } + + +def write_output(results, filters, number_of_validation_mols, output_file, column_names): + """ + Print results as a table. The number of columns varies depending how many columns the user picked. + """ + with open(output_file, "w") as f: + # write header + for n in range(len(results[0]["filter_combination"])): + f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") + for n in results[0]["perc_val"]: + f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") + f.write(f"ENRICH_{column_names[n]}\t") + f.write("TIME\n") + + # write results + for result in results: + for n, threshold in enumerate(result["filter_combination"]): + f.write( + f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" + ) + for n in result["perc_val"]: + f.write(f"{result['perc_val'][n]*100:.2f}\t") + if result["filter_percentages"][-1]: + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") + else: + f.write("NaN\t") + f.write(f"{result['time']:.4f}\n") + return + + +def select_best_filter_combination(results, max_time, min_perc): + """ + Very debatable how to do this... 
+ Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" + (= percentage of validation compounds / percentage of all compounds); we select the + threshold with the highest enrichment factor + """ + min_max_values = {} + for col in results[0]["perc_val"].keys(): + vals = [result["perc_val"][col] for result in results] + min_max_values[col] = {"min": min(vals), "max": max(vals)} + time_vals = [result["time"] for result in results] + min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} + + combination_scores = [ + sum( + [ + ( + (result["perc_val"][col] - min_max_values[col]["min"]) + / (min_max_values[col]["max"] - min_max_values[col]["min"]) + ) + for col in results[0]["perc_val"].keys() + ] + + [ + (min_max_values["time"]["max"] - result["time"]) + / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) + ] + ) + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 + else 0 + for result in results + ] + return np.argmax(combination_scores) + + +def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): + with open(threshold_file, "w") as f: + # write number of filters to apply + f.write(f"{len(filters) + 1}\n") + # write each filter to a separate line + for n, filtr in enumerate(filters): + f.write( + f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' + ) + # write filter to terminate docking when NRUNS reaches the number of runs used in the input file + f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") + + # write final filters - find strictest filters for all columns and apply them again + filters_by_column = { + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] + for col in set([filtr["column"] for filtr in filters]) + } + # write number of filters (same as number of columns filtered on) 
+ f.write(f"{len(filters_by_column)}\n") + # write filter + for col, values in filters_by_column.items(): + f.write(f"- {column_names[col]} {min(values)},\n") + + +def parse_filter(filter_str: str) -> Filter: + parsed_filter = {} + for item in filter_str.split(","): + key, value = item.split("=") + parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) + parsed_filter["column"] -= 1 + return parsed_filter + + +def main(argv: list[str] | None = None): + """ + Parse arguments; read in data; calculate filter combinations and apply them; print results + """ + parser = argparse.ArgumentParser( + description="""Estimate the results and computation time of an rDock high +throughput protocol. The following steps should be followed: +1) exhaustive docking of a small representative part of the entire + library. +2) Store the result of sdreport -t over that exhaustive docking run + in a file which will be the input of this script. +3) Run rbhtfinder, specifying -i and an arbitrary + number of filters specified using the -f option, for example + "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". This example + would simulate the effect of applying thresholds on column 6 after + 5 poses have been generated, for values between 0.5 and 1.0 (i.e. + 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). More than one threshold can be + specified, e.g., "-f column=4,steps=5,min=-12,max=-10,interval=1 + -f column=4,steps=15,min=-16,max=-15,interval=1" will test the + following combinations of thresholds on column 4: + 5 -10 15 -15 + 5 -11 15 -15 + 5 -12 15 -15 + 5 -10 15 -16 + 5 -11 15 -16 + 5 -12 15 -16 + The number of combinations will increase very rapidly, the more + filters are used and the larger the range of values specified for + each. It may be sensible to run rbhtfinder several times to explore + the effects of various filters independently. + + The output of the program consists of the following columns. 
+ FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME + SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 + SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 + The four columns are repeated for each filter specified with the -f + option: name of the column on which the filter is applied + (FILTER1), number of steps at which the threshold is applied + (NSTEPS1), value of the threshold (THR1) and the percentage of + poses which pass this filter (PERC1). Additional filters (FILTER2, + FILTER3 etc.) are listed in the order that they are applied (i.e. + by NSTEPS). + + The final columns provide some overall statistics for the + combination of thresholds specified in a row. TOP500_SCORE.INTER + gives the percentage of the top-scoring 500 poses, measured by + SCORE.INTER, from the whole of which are retained + after the thresholds are applied. This can be contrasted with the + final PERC column. The higher the ratio (the 'enrichment factor'), + the better the combination of thresholds. If thresholds are applied + on multiple columns, this column will be duplicated for each, e.g. + TOP500_SCORE.INTER and TOP500_SCORE.RESTR will give the percentage + of the top-scoring poses retained for both of these scoring + methods. The exact number of poses used for this validation can be + changed from the default 500 using the --validation flag. + ENRICH_SCORE.INTER gives the enrichment factor as a quick + rule-of-thumb to assess the best choice of thresholds. The final + column TIME provides an estimate of the time taken to perform + docking, as a proportion of the time taken for exhaustive docking. + This value should be below 0.1. + + After a combination of thresholds has been selected, they need to + be encoded into a threshold file which rDock can use as an input. + rbhtfinder attempts to help with this task by automatically + selecting a combination and writing a threshold file. 
The + combination chosen is that which provides the highest enrichment + factor, after all options with a TIME value over 0.1 are excluded. + This choice should not be blindly followed, so the threshold file + should be considered a template that the user modifies as needed. + + rbhtfinder requires NumPy. Installation of pandas is recommended, + but optional; if pandas is not available, loading the input file + for calculations will be considerably slower. + + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-i", + "--input", + help="Input from sdreport (tabular separated format).", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output", + help="Output file for report on threshold combinations.", + type=str, + required=True, + ) + parser.add_argument( + "-t", + "--threshold", + help="Threshold file used by rDock as input.", + type=str, + ) + parser.add_argument( + "-n", + "--name", + type=int, + default=1, # Index of molecule name in input file is 1 by default + help="Index of column containing the molecule name (0 indexed). Default is 1.", + ) + parser.add_argument( + "-f", + "--filter", + nargs="+", + type=str, + help="Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", + ) # Removed action 'append' to avoid unnecessary nested structure + parser.add_argument( + "-v", + "--validation", + type=int, + default=500, + help="Top-scoring N molecules from input to use for validating threshold combinations. Default is 500.", + ) + parser.add_argument( + "--header", + action="store_true", + help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. 
COL4, COL5.", + ) + parser.add_argument( + "--max-time", + type=float, + default=0.1, + help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take.", + ) + parser.add_argument( + "--min-perc", + type=float, + default=1.0, + help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", + ) + + args = parser.parse_args(argv) + + # create filters dictionary from args.filter passed in + filters = [parse_filter(filter) for filter in args.filter] + + # sort filters by step at which they are applied + filters.sort(key=lambda n: n["steps"]) + + # generates all possible combinations from filters provided + fils = [(filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) for filtr in filters] + filter_combinations = list(itertools.product(*(np.arange(*n) for n in fils))) + print(f"{len(filter_combinations)} combinations of filters calculated.") + + # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps + filter_combinations = np.array(filter_combinations) + cols = [filtr["column"] for filtr in filters] + indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} + filter_combination_indices_to_keep = range(len(filter_combinations)) + for col, indices in indices_per_col.items(): + filter_combination_indices_to_keep = [ + n + for n, comb in enumerate(filter_combinations[:, indices]) + if list(comb) == sorted(comb, reverse=True) + and len(set(comb)) == comb.shape[0] + and n in filter_combination_indices_to_keep + ] + filter_combinations = filter_combinations[filter_combination_indices_to_keep] + + if len(filter_combinations): + print( + f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." 
+ ) + else: + print("No filter combinations could be calculated - check the thresholds specified.") + exit(1) + + if pd: + # pandas is weird... i.e., skip line 0 if there's a header, else read all lines + header = 0 if args.header else None + sdreport_dataframe = pd.read_csv(args.input, sep="\t", header=header) + if args.header: + column_names = sdreport_dataframe.columns.values + else: + # use index names; add 1 to deal with zero-based numbering + column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] + sdreport_array = sdreport_dataframe.values + print(f"First few rows of the input array:\n{sdreport_array[:5]}") + else: # pd not available + np_array = np.loadtxt(args.input, dtype=str) + if args.header: + column_names = np_array[0] + sdreport_array = np_array[1:] + else: + column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] + sdreport_array = np_array + print("Data read in from input file.") + + # convert to 3D array (molecules x poses x columns) + molecule_array = prepare_array(sdreport_array, args.name) + + # find the top scoring compounds for validation of the filter combinations + min_score_indices = {} + for column in set(filtr["column"] for filtr in filters): + min_scores = np.min(molecule_array[:, :, column], axis=1) + min_score_indices[column] = np.argpartition(min_scores, args.validation)[: args.validation] + + results = [] + + pool = multiprocessing.Pool(os.cpu_count()) + results = pool.map( + partial( + calculate_results_for_filter_combination, + molecule_array=molecule_array, + filters=filters, + min_score_indices=min_score_indices, + number_of_validation_mols=args.validation, + ), + filter_combinations, + ) + + write_output(results, filters, args.validation, args.output, column_names) + + best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) + if args.threshold: + if best_filter_combination: + write_threshold_file( + filters, + filter_combinations[best_filter_combination], + 
args.threshold, + column_names, + molecule_array.shape[1], + ) + else: + print( + "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + ) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/rdock-utils/tests/rbhtfinder/conftest.py b/rdock-utils/tests/rbhtfinder/conftest.py index 2a442e05..c8f4ea98 100644 --- a/rdock-utils/tests/rbhtfinder/conftest.py +++ b/rdock-utils/tests/rbhtfinder/conftest.py @@ -1,5 +1,42 @@ +import pytest + from ..conftest import FIXTURES_FOLDER RBHTFINDER_FIXTURES_FOLDER = FIXTURES_FOLDER / "rbhtfinder" + +INPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "input_tabs.txt") +THRESHOLD_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "threshold.txt") EXPECTED_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "output.txt") -RESULT_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "original_output.txt") + + +@pytest.fixture +def file_path(tmp_path): + output_path = tmp_path / "output.txt" + return output_path + + +@pytest.fixture +def argv(file_path): + return [ + "-i", + INPUT_FILE, + "-o", + str(file_path), + "-t", + THRESHOLD_FILE, + "-f", + "column=4,steps=3,min=-10.0,max=0.0,interval=5.0", + "column=6,steps=5,min=1.0,max=5.0,interval=5.0", + "--max-time", + "1", + "--min-perc", + "1.0", + "-v", + "5", + "--header", + ] + + +def get_file_content(file: str) -> list[str]: + with open(file, "r") as f: + return f.readlines() diff --git a/rdock-utils/tests/rbhtfinder/test_integration.py b/rdock-utils/tests/rbhtfinder/test_integration.py index 5b5466ac..f0f092a7 100644 --- a/rdock-utils/tests/rbhtfinder/test_integration.py +++ b/rdock-utils/tests/rbhtfinder/test_integration.py @@ -1,16 +1,28 @@ import pytest -from rdock_utils.rbhtfinder.main import main +from rdock_utils.rbhtfinder.main import main as rbhtfinder_main +from rdock_utils.rbhtfinder_original_copy import main as rbhtfinder_old_main -from .conftest import EXPECTED_OUTPUT_FILE, RESULT_OUTPUT_FILE +from .conftest import EXPECTED_OUTPUT_FILE, 
get_file_content +parametrize_main = pytest.mark.parametrize( + "main", + [ + pytest.param(rbhtfinder_old_main, id="Original version Python 3"), + pytest.param(rbhtfinder_main, id="Improved version Python 3.12"), + ], +) -def test_do_nothing(): + +@parametrize_main +def test_do_nothing(main): with pytest.raises(SystemExit): main() -def test_integration(): - # result = main() - with open(EXPECTED_OUTPUT_FILE, "r") as expected_file, open(RESULT_OUTPUT_FILE, "r") as result_file: - assert result_file.readlines() == expected_file.readlines() +@parametrize_main +def test_integration(main, file_path, argv): + main(argv) + result = get_file_content(file_path) + expected_result = get_file_content(EXPECTED_OUTPUT_FILE) + assert result == expected_result From c48bf3d0108f8268515ec089c352c860489c4384 Mon Sep 17 00:00:00 2001 From: lpardey Date: Fri, 19 Jul 2024 00:50:01 +0000 Subject: [PATCH 05/18] refactor main logic (WIP) --- rdock-utils/rdock_utils/rbhtfinder/main.py | 87 ++++++++++++-------- rdock-utils/rdock_utils/rbhtfinder/parser.py | 3 + 2 files changed, 54 insertions(+), 36 deletions(-) diff --git a/rdock-utils/rdock_utils/rbhtfinder/main.py b/rdock-utils/rdock_utils/rbhtfinder/main.py index d9876a73..33bd6084 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/main.py +++ b/rdock-utils/rdock_utils/rbhtfinder/main.py @@ -196,50 +196,65 @@ def write_threshold_file(filters, best_filter_combination, threshold_file, colum f.write(f"- {column_names[col]} {min(values)},\n") -def parse_filter(filter_str: str) -> Filter: - parsed_filter = {} - for item in filter_str.split(","): - key, value = item.split("=") - parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) - parsed_filter["column"] -= 1 - return parsed_filter +def generate_all_filter_combinations(filters: list[Filter]) -> list[tuple]: + filter_ranges = ((filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in filters) + combinations = (np.arange(*fr) for fr in 
filter_ranges) + all_filter_combinations = list(itertools.product(*combinations)) + return all_filter_combinations + + +def remove_redundant_combinations(all_combinations: list[tuple], filters: list[Filter]) -> list[tuple]: + all_combinations_array = np.array(all_combinations) + columns = [filter["column"] for filter in filters] + indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} + + # Create a mask to keep only valid combinations + mask = np.ones(len(all_combinations_array), dtype=bool) + + for _, indices in indices_per_col.items(): + col_data = all_combinations_array[:, indices] + sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending + is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original + is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + mask &= is_valid & is_unique + + filtered_combinations = all_combinations_array[mask] + return filtered_combinations def main(argv: list[str] | None = None) -> None: config = get_config(argv) - print(f"args.filter: {config.filters}") - # create filters dictionary from args.filter passed in - # filters = [parse_filter(filter) for filter in args.filter] - # print(f"parsed filters: {filters}") # generates all possible combinations from filters provided - fils = [ - (filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) - for filtr in config.filters - ] - print(f"fils: {fils}") - filter_combinations = list(itertools.product(*(np.arange(*n) for n in fils))) - print(f"combinations {filter_combinations}") - print(f"{len(filter_combinations)} combinations of filters calculated.") + # filter_ranges = ( + # (filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in config.filters + # ) + # combinations = (np.arange(*fr) for fr in filter_ranges) + # filter_combinations = list(itertools.product(*combinations)) + all_filter_combinations = 
generate_all_filter_combinations(config.filters) + print(f"{len(all_filter_combinations)} combinations of filters calculated.") # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps - filter_combinations = np.array(filter_combinations) - cols = [filtr["column"] for filtr in config.filters] - indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} - filter_combination_indices_to_keep = range(len(filter_combinations)) - for col, indices in indices_per_col.items(): - filter_combination_indices_to_keep = [ - n - for n, comb in enumerate(filter_combinations[:, indices]) - if list(comb) == sorted(comb, reverse=True) - and len(set(comb)) == comb.shape[0] - and n in filter_combination_indices_to_keep - ] - filter_combinations = filter_combinations[filter_combination_indices_to_keep] + # filter_combinations_array = np.array(all_filter_combinations) + # columns = [filter["column"] for filter in config.filters] + # indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} + # # Create a mask to keep only valid combinations + # mask = np.ones(len(filter_combinations_array), dtype=bool) + + # for _, indices in indices_per_col.items(): + # col_data = filter_combinations_array[:, indices] + # sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending + # is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original + # is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + # mask &= is_valid & is_unique + + # cleaned_filter_combinations = filter_combinations_array[mask] + + distinct_filter_combinations = remove_redundant_combinations(all_filter_combinations, config.filters) - if len(filter_combinations): + if len(distinct_filter_combinations): print( - f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." 
+ f"{len(distinct_filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." ) else: print("No filter combinations could be calculated - check the thresholds specified.") @@ -286,7 +301,7 @@ def main(argv: list[str] | None = None) -> None: min_score_indices=min_score_indices, number_of_validation_mols=config.validation, ), - filter_combinations, + distinct_filter_combinations, ) write_output(results, config.filters, config.validation, config.output, column_names) @@ -296,7 +311,7 @@ def main(argv: list[str] | None = None) -> None: if best_filter_combination: write_threshold_file( config.filters, - filter_combinations[best_filter_combination], + distinct_filter_combinations[best_filter_combination], config.threshold, column_names, molecule_array.shape[1], diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 2d687512..3dba4508 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -35,6 +35,9 @@ def _parse_filter(filter_str: str) -> Filter: # User inputs with 1-based numbering whereas python uses 0-based parsed_filter["column"] -= 1 + if "interval" not in parsed_filter: + parsed_filter["interval"] = 1.0 + return parsed_filter From 07e0eb972dc4af865b6e6a0df44bf4b3851b267b Mon Sep 17 00:00:00 2001 From: lpardey Date: Fri, 19 Jul 2024 15:52:04 +0000 Subject: [PATCH 06/18] refactor fixture files --- .../tests/fixtures/rbhtfinder/input.txt | 202 +++++++++--------- .../tests/fixtures/rbhtfinder/input_tabs.txt | 101 --------- .../fixtures/rbhtfinder/original_output.txt | 7 - rdock-utils/tests/rbhtfinder/conftest.py | 12 +- .../tests/rbhtfinder/test_integration.py | 7 +- 5 files changed, 113 insertions(+), 216 deletions(-) delete mode 100644 rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt delete mode 100644 rdock-utils/tests/fixtures/rbhtfinder/original_output.txt diff --git 
a/rdock-utils/tests/fixtures/rbhtfinder/input.txt b/rdock-utils/tests/fixtures/rbhtfinder/input.txt index 0d9277c6..3b261f8f 100644 --- a/rdock-utils/tests/fixtures/rbhtfinder/input.txt +++ b/rdock-utils/tests/fixtures/rbhtfinder/input.txt @@ -1,101 +1,101 @@ -REC _TITLE1 TOTAL INTER INTRA RESTR VDW -001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 -002 mol00 2.595 -0.601 -1.152 4.347 -11.001 -003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 -004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 -005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 -006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 -007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 -008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 -009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 -010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 -011 mol01 -14.764 -12.244 -3.069 0.550 -16.362 -012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 -013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 -014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 -015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 -016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 -017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 -018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 -019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 -020 mol01 -15.504 -12.806 -3.140 0.442 -12.497 -021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 -022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 -023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 -024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 -025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 -026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 -027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 -028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 -029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 -030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 -031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 -032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 -033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 -034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 -035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 -036 mol03 -16.844 -13.744 -3.429 0.329 
-14.531 -037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 -038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 -039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 -040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 -041 mol04 -7.776 -6.296 -2.270 0.790 -16.535 -042 mol04 6.644 5.519 -0.566 1.691 -0.734 -043 mol04 -3.363 -7.773 0.964 3.446 -13.299 -044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 -045 mol04 -2.875 -5.317 0.643 1.799 -13.852 -046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 -047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 -048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 -049 mol04 -8.574 -9.947 1.073 0.301 -18.351 -050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 -051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 -052 mol05 -5.265 -9.689 0.036 4.387 -16.474 -053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 -054 mol05 -3.375 -5.926 -1.281 3.832 -14.547 -055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 -056 mol05 -12.771 -15.095 1.703 0.621 -17.161 -057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 -058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 -059 mol05 -3.387 -7.638 1.574 2.678 -16.308 -060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 -061 mol06 -15.764 -17.717 0.853 1.101 -21.131 -062 mol06 -2.956 -7.275 0.313 4.006 -14.833 -063 mol06 -6.103 -12.909 2.281 4.526 -17.262 -064 mol06 1.370 -1.589 -0.619 3.579 -9.989 -065 mol06 0.980 -14.709 0.605 15.084 -20.358 -066 mol06 3.784 -6.808 8.337 2.255 -14.995 -067 mol06 -5.845 -12.679 2.130 4.704 -17.065 -068 mol06 -5.255 -12.309 4.456 2.598 -17.557 -069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 -070 mol06 -8.737 -13.409 3.272 1.400 -17.974 -071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 -072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 -073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 -074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 -075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 -076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 -077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 -078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 -079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 
-080 mol07 -2.593 -7.660 3.906 1.162 -10.076 -081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 -082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 -083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 -084 mol08 -24.337 -1.398 -32.330 9.391 -13.655 -085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 -086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 -087 mol08 -10.119 5.962 -25.259 9.178 -7.399 -088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 -089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 -090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 -091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 -092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 -093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 -094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 -095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 -096 mol09 -29.025 -10.924 -25.192 7.091 -17.062 -097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 -098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 -099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 -100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 +REC _TITLE1 TOTAL INTER INTRA RESTR VDW +001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 +002 mol00 2.595 -0.601 -1.152 4.347 -11.001 +003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 +004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 +005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 +006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 +007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 +008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 +009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 +010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 +011 mol01 -14.764 -12.244 -3.069 0.550 -16.362 +012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 +013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 +014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 +015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 +016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 +017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 +018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 +019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 +020 mol01 -15.504 -12.806 -3.140 0.442 
-12.497 +021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 +022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 +023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 +024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 +025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 +026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 +027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 +028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 +029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 +030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 +031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 +032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 +033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 +034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 +035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 +036 mol03 -16.844 -13.744 -3.429 0.329 -14.531 +037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 +038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 +039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 +040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 +041 mol04 -7.776 -6.296 -2.270 0.790 -16.535 +042 mol04 6.644 5.519 -0.566 1.691 -0.734 +043 mol04 -3.363 -7.773 0.964 3.446 -13.299 +044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 +045 mol04 -2.875 -5.317 0.643 1.799 -13.852 +046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 +047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 +048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 +049 mol04 -8.574 -9.947 1.073 0.301 -18.351 +050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 +051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 +052 mol05 -5.265 -9.689 0.036 4.387 -16.474 +053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 +054 mol05 -3.375 -5.926 -1.281 3.832 -14.547 +055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 +056 mol05 -12.771 -15.095 1.703 0.621 -17.161 +057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 +058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 +059 mol05 -3.387 -7.638 1.574 2.678 -16.308 +060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 +061 mol06 -15.764 -17.717 0.853 1.101 -21.131 +062 mol06 -2.956 -7.275 0.313 4.006 -14.833 +063 mol06 -6.103 -12.909 
2.281 4.526 -17.262 +064 mol06 1.370 -1.589 -0.619 3.579 -9.989 +065 mol06 0.980 -14.709 0.605 15.084 -20.358 +066 mol06 3.784 -6.808 8.337 2.255 -14.995 +067 mol06 -5.845 -12.679 2.130 4.704 -17.065 +068 mol06 -5.255 -12.309 4.456 2.598 -17.557 +069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 +070 mol06 -8.737 -13.409 3.272 1.400 -17.974 +071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 +072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 +073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 +074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 +075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 +076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 +077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 +078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 +079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 +080 mol07 -2.593 -7.660 3.906 1.162 -10.076 +081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 +082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 +083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 +084 mol08 -24.337 -1.398 -32.330 9.391 -13.655 +085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 +086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 +087 mol08 -10.119 5.962 -25.259 9.178 -7.399 +088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 +089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 +090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 +091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 +092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 +093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 +094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 +095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 +096 mol09 -29.025 -10.924 -25.192 7.091 -17.062 +097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 +098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 +099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 +100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 diff --git a/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt b/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt deleted file mode 100644 index 3b261f8f..00000000 --- a/rdock-utils/tests/fixtures/rbhtfinder/input_tabs.txt 
+++ /dev/null @@ -1,101 +0,0 @@ -REC _TITLE1 TOTAL INTER INTRA RESTR VDW -001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 -002 mol00 2.595 -0.601 -1.152 4.347 -11.001 -003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 -004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 -005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 -006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 -007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 -008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 -009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 -010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 -011 mol01 -14.764 -12.244 -3.069 0.550 -16.362 -012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 -013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 -014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 -015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 -016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 -017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 -018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 -019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 -020 mol01 -15.504 -12.806 -3.140 0.442 -12.497 -021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 -022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 -023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 -024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 -025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 -026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 -027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 -028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 -029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 -030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 -031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 -032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 -033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 -034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 -035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 -036 mol03 -16.844 -13.744 -3.429 0.329 -14.531 -037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 -038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 -039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 -040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 -041 mol04 -7.776 -6.296 
-2.270 0.790 -16.535 -042 mol04 6.644 5.519 -0.566 1.691 -0.734 -043 mol04 -3.363 -7.773 0.964 3.446 -13.299 -044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 -045 mol04 -2.875 -5.317 0.643 1.799 -13.852 -046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 -047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 -048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 -049 mol04 -8.574 -9.947 1.073 0.301 -18.351 -050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 -051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 -052 mol05 -5.265 -9.689 0.036 4.387 -16.474 -053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 -054 mol05 -3.375 -5.926 -1.281 3.832 -14.547 -055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 -056 mol05 -12.771 -15.095 1.703 0.621 -17.161 -057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 -058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 -059 mol05 -3.387 -7.638 1.574 2.678 -16.308 -060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 -061 mol06 -15.764 -17.717 0.853 1.101 -21.131 -062 mol06 -2.956 -7.275 0.313 4.006 -14.833 -063 mol06 -6.103 -12.909 2.281 4.526 -17.262 -064 mol06 1.370 -1.589 -0.619 3.579 -9.989 -065 mol06 0.980 -14.709 0.605 15.084 -20.358 -066 mol06 3.784 -6.808 8.337 2.255 -14.995 -067 mol06 -5.845 -12.679 2.130 4.704 -17.065 -068 mol06 -5.255 -12.309 4.456 2.598 -17.557 -069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 -070 mol06 -8.737 -13.409 3.272 1.400 -17.974 -071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 -072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 -073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 -074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 -075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 -076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 -077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 -078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 -079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 -080 mol07 -2.593 -7.660 3.906 1.162 -10.076 -081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 -082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 -083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 -084 mol08 -24.337 -1.398 -32.330 
9.391 -13.655 -085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 -086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 -087 mol08 -10.119 5.962 -25.259 9.178 -7.399 -088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 -089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 -090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 -091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 -092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 -093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 -094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 -095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 -096 mol09 -29.025 -10.924 -25.192 7.091 -17.062 -097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 -098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 -099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 -100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 diff --git a/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt b/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt deleted file mode 100644 index eba0b51b..00000000 --- a/rdock-utils/tests/fixtures/rbhtfinder/original_output.txt +++ /dev/null @@ -1,7 +0,0 @@ -FILTER1 NSTEPS1 THR1 PERC1 FILTER2 NSTEPS2 THR2 PERC2 TOP5_INTER ENRICH_INTER TOP5_RESTR ENRICH_RESTR TIME -INTER 3 -10.00 90.00 RESTR 5 1.00 60.00 40.00 0.67 80.00 1.33 0.7800 -INTER 3 -10.00 90.00 RESTR 5 6.00 90.00 100.00 1.11 80.00 0.89 0.9300 -INTER 3 -5.00 100.00 RESTR 5 1.00 70.00 40.00 0.57 100.00 1.43 0.8500 -INTER 3 -5.00 100.00 RESTR 5 6.00 100.00 100.00 1.00 100.00 1.00 1.0000 -INTER 3 0.00 100.00 RESTR 5 1.00 70.00 40.00 0.57 100.00 1.43 0.8500 -INTER 3 0.00 100.00 RESTR 5 6.00 100.00 100.00 1.00 100.00 1.00 1.0000 diff --git a/rdock-utils/tests/rbhtfinder/conftest.py b/rdock-utils/tests/rbhtfinder/conftest.py index c8f4ea98..381957db 100644 --- a/rdock-utils/tests/rbhtfinder/conftest.py +++ b/rdock-utils/tests/rbhtfinder/conftest.py @@ -1,22 +1,24 @@ +from pathlib import Path + import pytest from ..conftest import FIXTURES_FOLDER RBHTFINDER_FIXTURES_FOLDER = FIXTURES_FOLDER / "rbhtfinder" -INPUT_FILE = 
str(RBHTFINDER_FIXTURES_FOLDER / "input_tabs.txt") +INPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "input.txt") THRESHOLD_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "threshold.txt") EXPECTED_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "output.txt") @pytest.fixture -def file_path(tmp_path): +def file_path(tmp_path: Path) -> Path: output_path = tmp_path / "output.txt" return output_path @pytest.fixture -def argv(file_path): +def argv(file_path: Path) -> list[str]: return [ "-i", INPUT_FILE, @@ -37,6 +39,6 @@ def argv(file_path): ] -def get_file_content(file: str) -> list[str]: +def get_file_content(file: str | Path) -> str: with open(file, "r") as f: - return f.readlines() + return f.read() diff --git a/rdock-utils/tests/rbhtfinder/test_integration.py b/rdock-utils/tests/rbhtfinder/test_integration.py index f0f092a7..07045efb 100644 --- a/rdock-utils/tests/rbhtfinder/test_integration.py +++ b/rdock-utils/tests/rbhtfinder/test_integration.py @@ -1,3 +1,6 @@ +from pathlib import Path +from typing import Callable + import pytest from rdock_utils.rbhtfinder.main import main as rbhtfinder_main @@ -15,13 +18,13 @@ @parametrize_main -def test_do_nothing(main): +def test_do_nothing(main: Callable[[list[str]], None]): with pytest.raises(SystemExit): main() @parametrize_main -def test_integration(main, file_path, argv): +def test_integration(main: Callable[[list[str]], None], file_path: Path, argv: list[str]): main(argv) result = get_file_content(file_path) expected_result = get_file_content(EXPECTED_OUTPUT_FILE) From 446461b2815f2fca205341d3a31c2e337e6bfed6 Mon Sep 17 00:00:00 2001 From: lpardey Date: Fri, 19 Jul 2024 17:06:07 +0000 Subject: [PATCH 07/18] add main module add rbhtfinder class --- rdock-utils/rdock_utils/rbhtfinder/main.py | 321 +----------------- rdock-utils/rdock_utils/rbhtfinder/parser.py | 8 +- .../rdock_utils/rbhtfinder/rbhtfinder.py | 313 +++++++++++++++++ 3 files changed, 320 insertions(+), 322 deletions(-) diff --git 
a/rdock-utils/rdock_utils/rbhtfinder/main.py b/rdock-utils/rdock_utils/rbhtfinder/main.py index 33bd6084..68264f20 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/main.py +++ b/rdock-utils/rdock_utils/rbhtfinder/main.py @@ -1,326 +1,11 @@ -import numpy as np - -try: - import pandas as pd -except ImportError: - pd = None -import itertools -import multiprocessing -import os -from collections import Counter -from functools import partial - from .parser import get_config - -Filter = dict[str, float] - - -def apply_threshold(scored_poses, column, steps, threshold): - """ - Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. - """ - # minimum score after `steps` per molecule - mins = np.min(scored_poses[:, :steps, column], axis=1) - # return those molecules where the minimum score is less than the threshold - passing_molecules = np.where(mins < threshold)[0] - return passing_molecules - - -def prepare_array(sdreport_array: np.ndarray, name_column: int) -> np.ndarray: - """ - Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses - """ - # print(sdreport_array.shape[1]) - # if name_column >= sdreport_array.shape[1]: - # raise IndexError( - # f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}" - # ) - - # find points in the array where the name_column changes (i.e. 
we are dealing with a new molecule) and split the array - split_indices = ( - np.where( - sdreport_array[:, name_column] - != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) - )[0] - + 1 - ) - split_array = np.split(sdreport_array, split_indices) - - modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] - number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array - - split_array_clean = sum( - [ - np.array_split(n, n.shape[0] / number_of_poses) - for n in split_array - if not n.shape[0] % number_of_poses and n.shape[0] - ], - [], - ) - - if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99: - print( - f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." - ) - - molecule_array = np.array(split_array_clean) - # overwrite the name column (should be the only one with dtype=str) so we can force everything to float - molecule_array[:, :, name_column] = 0 - return np.array(molecule_array, dtype=float) - - -def calculate_results_for_filter_combination( - filter_combination, - molecule_array, - filters, - min_score_indices, - number_of_validation_mols, -): - """ - For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking - """ - # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing - mols_passed_threshold = list(range(molecule_array.shape[0])) - filter_percentages = [] - number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output - for n, threshold in enumerate(filter_combination): - if n: - # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) - else: - number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] - mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. passed all previous filters - n - for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) - if n in mols_passed_threshold - ] - filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) - number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) - perc_val = { - k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols - for k, v in min_score_indices.items() - } - return { - "filter_combination": filter_combination, - "perc_val": perc_val, - "filter_percentages": filter_percentages, - "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), - } - - -def write_output(results, filters, number_of_validation_mols, output_file, column_names): - """ - Print results as a table. The number of columns varies depending how many columns the user picked. 
- """ - with open(output_file, "w") as f: - # write header - for n in range(len(results[0]["filter_combination"])): - f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") - for n in results[0]["perc_val"]: - f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") - f.write(f"ENRICH_{column_names[n]}\t") - f.write("TIME\n") - - # write results - for result in results: - for n, threshold in enumerate(result["filter_combination"]): - f.write( - f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" - ) - for n in result["perc_val"]: - f.write(f"{result['perc_val'][n]*100:.2f}\t") - if result["filter_percentages"][-1]: - f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") - else: - f.write("NaN\t") - f.write(f"{result['time']:.4f}\n") - return - - -def select_best_filter_combination(results, max_time, min_perc): - """ - Very debatable how to do this... - Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" - (= percentage of validation compounds / percentage of all compounds); we select the - threshold with the highest enrichment factor - """ - min_max_values = {} - for col in results[0]["perc_val"].keys(): - vals = [result["perc_val"][col] for result in results] - min_max_values[col] = {"min": min(vals), "max": max(vals)} - time_vals = [result["time"] for result in results] - min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} - - combination_scores = [ - sum( - [ - ( - (result["perc_val"][col] - min_max_values[col]["min"]) - / (min_max_values[col]["max"] - min_max_values[col]["min"]) - ) - for col in results[0]["perc_val"].keys() - ] - + [ - (min_max_values["time"]["max"] - result["time"]) - / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) - ] - ) - if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 - else 0 - for result in results - ] - return 
np.argmax(combination_scores) - - -def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): - with open(threshold_file, "w") as f: - # write number of filters to apply - f.write(f"{len(filters) + 1}\n") - # write each filter to a separate line - for n, filtr in enumerate(filters): - f.write( - f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' - ) - # write filter to terminate docking when NRUNS reaches the number of runs used in the input file - f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") - - # write final filters - find strictest filters for all columns and apply them again - filters_by_column = { - col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] - for col in set([filtr["column"] for filtr in filters]) - } - # write number of filters (same as number of columns filtered on) - f.write(f"{len(filters_by_column)}\n") - # write filter - for col, values in filters_by_column.items(): - f.write(f"- {column_names[col]} {min(values)},\n") - - -def generate_all_filter_combinations(filters: list[Filter]) -> list[tuple]: - filter_ranges = ((filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in filters) - combinations = (np.arange(*fr) for fr in filter_ranges) - all_filter_combinations = list(itertools.product(*combinations)) - return all_filter_combinations - - -def remove_redundant_combinations(all_combinations: list[tuple], filters: list[Filter]) -> list[tuple]: - all_combinations_array = np.array(all_combinations) - columns = [filter["column"] for filter in filters] - indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} - - # Create a mask to keep only valid combinations - mask = np.ones(len(all_combinations_array), dtype=bool) - - for _, indices in indices_per_col.items(): - col_data = 
all_combinations_array[:, indices] - sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending - is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original - is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) - mask &= is_valid & is_unique - - filtered_combinations = all_combinations_array[mask] - return filtered_combinations +from .rbhtfinder import RBHTFinder def main(argv: list[str] | None = None) -> None: config = get_config(argv) - - # generates all possible combinations from filters provided - # filter_ranges = ( - # (filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in config.filters - # ) - # combinations = (np.arange(*fr) for fr in filter_ranges) - # filter_combinations = list(itertools.product(*combinations)) - all_filter_combinations = generate_all_filter_combinations(config.filters) - print(f"{len(all_filter_combinations)} combinations of filters calculated.") - - # remove redundant combinations, i.e. 
where filters for later steps are less or equally strict to earlier steps - # filter_combinations_array = np.array(all_filter_combinations) - # columns = [filter["column"] for filter in config.filters] - # indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} - # # Create a mask to keep only valid combinations - # mask = np.ones(len(filter_combinations_array), dtype=bool) - - # for _, indices in indices_per_col.items(): - # col_data = filter_combinations_array[:, indices] - # sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending - # is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original - # is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) - # mask &= is_valid & is_unique - - # cleaned_filter_combinations = filter_combinations_array[mask] - - distinct_filter_combinations = remove_redundant_combinations(all_filter_combinations, config.filters) - - if len(distinct_filter_combinations): - print( - f"{len(distinct_filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." - ) - else: - print("No filter combinations could be calculated - check the thresholds specified.") - exit(1) - - if pd: - # pandas is weird... 
i.e., skip line 0 if there's a header, else read all lines - header = 0 if config.header else None - sdreport_dataframe = pd.read_csv(config.input, sep="\t", header=header) - if config.header: - column_names = sdreport_dataframe.columns.values - else: - # use index names; add 1 to deal with zero-based numbering - column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] - sdreport_array = sdreport_dataframe.values - print(f"First few rows of the input array:\n{sdreport_array[:5]}") - else: # pd not available - np_array = np.loadtxt(config.input, dtype=str) - if config.header: - column_names = np_array[0] - sdreport_array = np_array[1:] - else: - column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] - sdreport_array = np_array - print("Data read in from input file.") - - # convert to 3D array (molecules x poses x columns) - molecule_array = prepare_array(sdreport_array, config.name) - - # find the top scoring compounds for validation of the filter combinations - min_score_indices = {} - for column in set(filtr["column"] for filtr in config.filters): - min_scores = np.min(molecule_array[:, :, column], axis=1) - min_score_indices[column] = np.argpartition(min_scores, config.validation)[: config.validation] - - results = [] - - pool = multiprocessing.Pool(os.cpu_count()) - results = pool.map( - partial( - calculate_results_for_filter_combination, - molecule_array=molecule_array, - filters=config.filters, - min_score_indices=min_score_indices, - number_of_validation_mols=config.validation, - ), - distinct_filter_combinations, - ) - - write_output(results, config.filters, config.validation, config.output, column_names) - - best_filter_combination = select_best_filter_combination(results, config.max_time, config.min_percentage) - if config.threshold: - if best_filter_combination: - write_threshold_file( - config.filters, - distinct_filter_combinations[best_filter_combination], - config.threshold, - column_names, - molecule_array.shape[1], - ) - 
else: - print( - "Filter combinations defined are too strict or would take too long to run; no threshold file was written." - ) - exit(1) + rbhtfinder = RBHTFinder(config) + rbhtfinder.run() if __name__ == "__main__": diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 3dba4508..946fcf95 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -5,7 +5,7 @@ @dataclass -class rbhtfinderConfig: +class RBHTFinderConfig: input: str output: str threshold: str @@ -31,7 +31,7 @@ def _parse_filter(filter_str: str) -> Filter: for item in filter_str.split(","): key, value = item.split("=") - parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) + parsed_filter[key] = float(value) if key in ("interval", "min", "max") else int(value) # User inputs with 1-based numbering whereas python uses 0-based parsed_filter["column"] -= 1 @@ -127,10 +127,10 @@ def get_parser() -> argparse.ArgumentParser: return parser -def get_config(argv: list[str] | None = None) -> rbhtfinderConfig: +def get_config(argv: list[str] | None = None) -> RBHTFinderConfig: parser = get_parser() args = parser.parse_args(argv) - return rbhtfinderConfig( + return RBHTFinderConfig( input=args.input, output=args.output, threshold=args.threshold, diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index e69de29b..f9699abf 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -0,0 +1,313 @@ +import itertools +import logging +import multiprocessing +import os +from collections import Counter +from functools import partial + +import numpy as np +import pandas as pd + +from .parser import Filter, RBHTFinderConfig + +logger = logging.getLogger("RBHTFinder") + +InputData = tuple[np.ndarray, list[str]] + + +class RBHTFinder: + def __init__(self, config: 
RBHTFinderConfig) -> None: + self.config = config + + def run(self) -> None: + filters_combinations = self.generate_filters_combinations(self.config.filters) + print(f"{len(filters_combinations)} combinations of filters calculated.") + distinct_combinations = self.remove_redundant_combinations(filters_combinations, self.config.filters) + + if len(distinct_combinations) == 0: + raise RuntimeError("No filter combinations could be calculated - check the thresholds specified.") + + print( + f"{len(distinct_combinations)} combinations of filters remain after removal of redundant combinations. " + "Starting calculations..." + ) + sdreport_array, column_names = self.read_data() + print(f"First few rows of the input array:\n{sdreport_array[:5]}") + print("Data read in from input file.") + + # convert to 3D array (molecules x poses x columns) + molecule_array = self.prepare_array(sdreport_array, self.config.name) + + # find the top scoring compounds for validation of the filter combinations + min_score_indices = {} + for column in set(filtr["column"] for filtr in self.config.filters): + min_scores = np.min(molecule_array[:, :, column], axis=1) + min_score_indices[column] = np.argpartition(min_scores, self.config.validation)[: self.config.validation] + + results = [] + + pool = multiprocessing.Pool(os.cpu_count()) + results = pool.map( + partial( + self.calculate_results_for_filter_combination, + molecule_array=molecule_array, + filters=self.config.filters, + min_score_indices=min_score_indices, + number_of_validation_mols=self.config.validation, + ), + distinct_combinations, + ) + + self.write_output(results, self.config.filters, self.config.validation, self.config.output, column_names) + + best_filter_combination = self.select_best_filter_combination( + results, self.config.max_time, self.config.min_percentage + ) + if self.config.threshold: + if best_filter_combination: + self.write_threshold_file( + self.config.filters, + distinct_combinations[best_filter_combination], + 
self.config.threshold, + column_names, + molecule_array.shape[1], + ) + else: + print( + "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + ) + exit(1) + + def read_data(self) -> InputData: + try: + data_array, column_names = self.read_data_using_pandas() + except Exception as e: + logging.error(f"Error reading data with pandas: {e}") + data_array, column_names = self.read_data_using_numpy() + return data_array, column_names + + def read_data_using_pandas(self) -> InputData: + sdreport_dataframe = pd.read_csv(self.config.input, sep="\t", header=0 if self.config.header else None) + + if self.config.header: + column_names = sdreport_dataframe.columns.values + else: + # use index names; add 1 to deal with zero-based numbering + column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] + + sdreport_array = sdreport_dataframe.values + return sdreport_array, column_names + + def read_data_using_numpy(self) -> InputData: + np_array = np.loadtxt(self.config.input, dtype=str) + + if self.config.header: + column_names = np_array[0] + sdreport_array = np_array[1:] + else: + column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] + sdreport_array = np_array + + return sdreport_array, column_names + + def apply_threshold(self, scored_poses, column, steps, threshold): + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
+ """ + # minimum score after `steps` per molecule + mins = np.min(scored_poses[:, :steps, column], axis=1) + # return those molecules where the minimum score is less than the threshold + passing_molecules = np.where(mins < threshold)[0] + return passing_molecules + + def prepare_array(self, sdreport_array: np.ndarray, name_column: int) -> np.ndarray: + """ + Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses + """ + # print(sdreport_array.shape[1]) + # if name_column >= sdreport_array.shape[1]: + # raise IndexError( + # f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}" + # ) + + # find points in the array where the name_column changes (i.e. we are dealing with a new molecule) and split the array + split_indices = ( + np.where( + sdreport_array[:, name_column] + != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) + )[0] + + 1 + ) + split_array = np.split(sdreport_array, split_indices) + + modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] + number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array + + split_array_clean = sum( + [ + np.array_split(n, n.shape[0] / number_of_poses) + for n in split_array + if not n.shape[0] % number_of_poses and n.shape[0] + ], + [], + ) + + if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99: + print( + f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." 
+ ) + + molecule_array = np.array(split_array_clean) + # overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_array[:, :, name_column] = 0 + return np.array(molecule_array, dtype=float) + + def calculate_results_for_filter_combination( + self, + filter_combination, + molecule_array, + filters, + min_score_indices, + number_of_validation_mols, + ): + """ + For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + """ + # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + mols_passed_threshold = list(range(molecule_array.shape[0])) + filter_percentages = [] + number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + for n, threshold in enumerate(filter_combination): + if n: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + number_of_simulated_poses += len(mols_passed_threshold) * ( + filters[n]["steps"] - filters[n - 1]["steps"] + ) + else: + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] + mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. 
passed all previous filters + n + for n in self.apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) + if n in mols_passed_threshold + ] + filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) + perc_val = { + k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols + for k, v in min_score_indices.items() + } + return { + "filter_combination": filter_combination, + "perc_val": perc_val, + "filter_percentages": filter_percentages, + "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), + } + + def write_output(self, results, filters, number_of_validation_mols, output_file, column_names): + """ + Print results as a table. The number of columns varies depending how many columns the user picked. + """ + with open(output_file, "w") as f: + # write header + for n in range(len(results[0]["filter_combination"])): + f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") + for n in results[0]["perc_val"]: + f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") + f.write(f"ENRICH_{column_names[n]}\t") + f.write("TIME\n") + + # write results + for result in results: + for n, threshold in enumerate(result["filter_combination"]): + f.write( + f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" + ) + for n in result["perc_val"]: + f.write(f"{result['perc_val'][n]*100:.2f}\t") + if result["filter_percentages"][-1]: + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") + else: + f.write("NaN\t") + f.write(f"{result['time']:.4f}\n") + return + + def select_best_filter_combination(self, results, max_time, min_perc): + """ + Very debatable how to do this... 
+ Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" + (= percentage of validation compounds / percentage of all compounds); we select the + threshold with the highest enrichment factor + """ + min_max_values = {} + for col in results[0]["perc_val"].keys(): + vals = [result["perc_val"][col] for result in results] + min_max_values[col] = {"min": min(vals), "max": max(vals)} + time_vals = [result["time"] for result in results] + min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} + + combination_scores = [ + sum( + [ + ( + (result["perc_val"][col] - min_max_values[col]["min"]) + / (min_max_values[col]["max"] - min_max_values[col]["min"]) + ) + for col in results[0]["perc_val"].keys() + ] + + [ + (min_max_values["time"]["max"] - result["time"]) + / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) + ] + ) + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 + else 0 + for result in results + ] + return np.argmax(combination_scores) + + def write_threshold_file(self, filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): + with open(threshold_file, "w") as f: + # write number of filters to apply + f.write(f"{len(filters) + 1}\n") + # write each filter to a separate line + for n, filtr in enumerate(filters): + f.write( + f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' + ) + # write filter to terminate docking when NRUNS reaches the number of runs used in the input file + f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") + + # write final filters - find strictest filters for all columns and apply them again + filters_by_column = { + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] + for col in set([filtr["column"] for filtr in filters]) + } + # write number of filters (same as number of columns filtered 
on) + f.write(f"{len(filters_by_column)}\n") + # write filter + for col, values in filters_by_column.items(): + f.write(f"- {column_names[col]} {min(values)},\n") + + def generate_filters_combinations(self, filters: list[Filter]) -> list[tuple]: + filter_ranges = ((filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in filters) + combinations = (np.arange(*range) for range in filter_ranges) + filters_combinations = list(itertools.product(*combinations)) + return filters_combinations + + def remove_redundant_combinations(self, all_combinations: list[tuple], filters: list[Filter]) -> list[tuple]: + all_combinations_array = np.array(all_combinations) + columns = [filter["column"] for filter in filters] + indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} + + # Create a mask to keep only valid combinations + mask = np.ones(len(all_combinations_array), dtype=bool) + + for _, indices in indices_per_col.items(): + col_data = all_combinations_array[:, indices] + sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending + is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original + is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + mask &= is_valid & is_unique + + valid_combinations = all_combinations_array[mask] + return valid_combinations From fa660ae53098b694cf178578692a44f865afcf6e Mon Sep 17 00:00:00 2001 From: lpardey Date: Mon, 22 Jul 2024 17:27:17 +0000 Subject: [PATCH 08/18] refactor (WIP) --- rdock-utils/pyproject.toml | 20 ++- .../rdock_utils/rbhtfinder/rbhtfinder.py | 124 +++++++++--------- rdock-utils/requirements-dev.txt | 2 +- 3 files changed, 76 insertions(+), 70 deletions(-) diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index a3d96bde..980531a8 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -6,8 +6,8 @@ description = "Utilities for working with RDock and operating on 
SD files" requires-python = ">=3.10.0" [tool.setuptools.dynamic] -dependencies = {file = ["requirements.txt"]} -optional-dependencies = { dev = {file = ["requirements-dev.txt"]} } +dependencies = { file = ["requirements.txt"] } +optional-dependencies = { dev = { file = ["requirements-dev.txt"] } } [project.scripts] sdfield = "rdock_utils.sdfield:main" @@ -26,11 +26,16 @@ Repository = "https://github.com/CBDD/rDock.git" [tool.ruff] line-length = 119 target-version = "py312" -exclude = [".git", "__pycache__", "rdock_utils/sdrmsd_original.py", "rdock_utils/sdtether_original.py"] +exclude = [ + ".git", + "__pycache__", + "rdock_utils/sdrmsd_original.py", + "rdock_utils/sdtether_original.py", +] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I"] -ignore = ["E231","E501","E203"] +ignore = ["E231", "E501", "E203"] [tool.ruff.format] quote-style = "double" @@ -67,4 +72,9 @@ no_implicit_reexport = false strict_equality = true -exclude = ["build/*", "rdock_utils/sdrmsd_original.py", "tests/", "rdock_utils/sdtether_original.py"] +exclude = [ + "build/*", + "rdock_utils/sdrmsd_original.py", + "tests/", + "rdock_utils/sdtether_original.py", +] diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index f9699abf..e415e116 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -3,6 +3,7 @@ import multiprocessing import os from collections import Counter +from dataclasses import dataclass from functools import partial import numpy as np @@ -15,6 +16,14 @@ InputData = tuple[np.ndarray, list[str]] +@dataclass +class RBHTResult: + filter_combination: np.ndarray + perc_val: dict[str, int] + filter_percentages: list[int] + time: float + + class RBHTFinder: def __init__(self, config: RBHTFinderConfig) -> None: self.config = config @@ -39,21 +48,17 @@ def run(self) -> None: molecule_array = self.prepare_array(sdreport_array, self.config.name) # find the top 
scoring compounds for validation of the filter combinations - min_score_indices = {} - for column in set(filtr["column"] for filtr in self.config.filters): + min_score_indices: dict[float, np.ndarray] = {} + for column in set(filter["column"] for filter in self.config.filters): min_scores = np.min(molecule_array[:, :, column], axis=1) min_score_indices[column] = np.argpartition(min_scores, self.config.validation)[: self.config.validation] - results = [] - pool = multiprocessing.Pool(os.cpu_count()) results = pool.map( partial( self.calculate_results_for_filter_combination, molecule_array=molecule_array, - filters=self.config.filters, min_score_indices=min_score_indices, - number_of_validation_mols=self.config.validation, ), distinct_combinations, ) @@ -120,29 +125,19 @@ def apply_threshold(self, scored_poses, column, steps, threshold): passing_molecules = np.where(mins < threshold)[0] return passing_molecules - def prepare_array(self, sdreport_array: np.ndarray, name_column: int) -> np.ndarray: + def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: """ Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses """ - # print(sdreport_array.shape[1]) - # if name_column >= sdreport_array.shape[1]: - # raise IndexError( - # f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}" - # ) - - # find points in the array where the name_column changes (i.e. 
we are dealing with a new molecule) and split the array split_indices = ( np.where( - sdreport_array[:, name_column] - != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) + data_array[:, name_column] != np.hstack((data_array[1:, name_column], data_array[0, name_column])) )[0] + 1 ) - split_array = np.split(sdreport_array, split_indices) - + split_array = np.split(data_array, split_indices) modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] - number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array - + number_of_poses = modal_shape[0][0] # Find modal number of poses per molecule in the array split_array_clean = sum( [ np.array_split(n, n.shape[0] / number_of_poses) @@ -152,24 +147,22 @@ def prepare_array(self, sdreport_array: np.ndarray, name_column: int) -> np.ndar [], ) - if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99: - print( - f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." + if len(split_array_clean) * number_of_poses < data_array.shape[0] * 0.99: + message = ( + "WARNING: The number of poses provided per molecule is inconsistent. " + f"Only {len(split_array_clean)} of {int(data_array.shape[0] / number_of_poses)} molecules have {number_of_poses} poses." 
) + logger.warning(message) - molecule_array = np.array(split_array_clean) - # overwrite the name column (should be the only one with dtype=str) so we can force everything to float - molecule_array[:, :, name_column] = 0 - return np.array(molecule_array, dtype=float) + molecule_3d_array = np.array(split_array_clean) + # Overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_3d_array[:, :, name_column] = 0 + final_array = molecule_3d_array.astype(float) + return final_array def calculate_results_for_filter_combination( - self, - filter_combination, - molecule_array, - filters, - min_score_indices, - number_of_validation_mols, - ): + self, filter_combination: np.ndarray, molecule_array: np.ndarray, min_score_indices: dict[float, np.ndarray] + ) -> RBHTResult: """ For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking """ @@ -181,57 +174,60 @@ def calculate_results_for_filter_combination( if n: # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses number_of_simulated_poses += len(mols_passed_threshold) * ( - filters[n]["steps"] - filters[n - 1]["steps"] + self.config.filters[n]["steps"] - self.config.filters[n - 1]["steps"] ) else: - number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] + number_of_simulated_poses += len(mols_passed_threshold) * self.config.filters[n]["steps"] mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. 
passed all previous filters n - for n in self.apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) + for n in self.apply_threshold( + molecule_array, self.config.filters[n]["column"], self.config.filters[n]["steps"], threshold + ) if n in mols_passed_threshold ] filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) - number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) + number_of_simulated_poses += len(mols_passed_threshold) * ( + molecule_array.shape[1] - self.config.filters[-1]["steps"] + ) perc_val = { - k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols + k: len([n for n in v if n in mols_passed_threshold]) / self.config.validation for k, v in min_score_indices.items() } - return { - "filter_combination": filter_combination, - "perc_val": perc_val, - "filter_percentages": filter_percentages, - "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), - } + time = number_of_simulated_poses / np.product(molecule_array.shape[:2]) + rbhtresult = RBHTResult( + filter_combination=filter_combination, perc_val=perc_val, filter_percentages=filter_percentages, time=time + ) + return rbhtresult - def write_output(self, results, filters, number_of_validation_mols, output_file, column_names): + def write_output(self, results: list[RBHTResult], filters, number_of_validation_mols, output_file, column_names): """ Print results as a table. The number of columns varies depending how many columns the user picked. 
""" with open(output_file, "w") as f: # write header - for n in range(len(results[0]["filter_combination"])): + for n in range(len(results[0].filter_combination)): f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") - for n in results[0]["perc_val"]: + for n in results[0].perc_val: f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") f.write(f"ENRICH_{column_names[n]}\t") f.write("TIME\n") # write results for result in results: - for n, threshold in enumerate(result["filter_combination"]): + for n, threshold in enumerate(result.filter_combination): f.write( - f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" + f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result.filter_percentages[n]*100:.2f}\t" ) - for n in result["perc_val"]: - f.write(f"{result['perc_val'][n]*100:.2f}\t") - if result["filter_percentages"][-1]: - f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") + for n in result.perc_val: + f.write(f"{result.perc_val[n]*100:.2f}\t") + if result.filter_percentages[-1]: + f.write(f"{result.perc_val[n]/result.filter_percentages[-1]:.2f}\t") else: f.write("NaN\t") - f.write(f"{result['time']:.4f}\n") + f.write(f"{result.time:.4f}\n") return - def select_best_filter_combination(self, results, max_time, min_perc): + def select_best_filter_combination(self, results: list[RBHTResult], max_time, min_perc): """ Very debatable how to do this... 
Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" @@ -239,27 +235,27 @@ def select_best_filter_combination(self, results, max_time, min_perc): threshold with the highest enrichment factor """ min_max_values = {} - for col in results[0]["perc_val"].keys(): - vals = [result["perc_val"][col] for result in results] + for col in results[0].perc_val.keys(): + vals = [result.perc_val[col] for result in results] min_max_values[col] = {"min": min(vals), "max": max(vals)} - time_vals = [result["time"] for result in results] + time_vals = [result.time for result in results] min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} combination_scores = [ sum( [ ( - (result["perc_val"][col] - min_max_values[col]["min"]) + (result.perc_val[col] - min_max_values[col]["min"]) / (min_max_values[col]["max"] - min_max_values[col]["min"]) ) - for col in results[0]["perc_val"].keys() + for col in results[0].perc_val.keys() ] + [ - (min_max_values["time"]["max"] - result["time"]) + (min_max_values["time"]["max"] - result.time) / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) ] ) - if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 + if result.time < max_time and result.filter_percentages[-1] >= min_perc / 100 else 0 for result in results ] @@ -294,7 +290,7 @@ def generate_filters_combinations(self, filters: list[Filter]) -> list[tuple]: filters_combinations = list(itertools.product(*combinations)) return filters_combinations - def remove_redundant_combinations(self, all_combinations: list[tuple], filters: list[Filter]) -> list[tuple]: + def remove_redundant_combinations(self, all_combinations: list[tuple], filters: list[Filter]) -> np.ndarray: all_combinations_array = np.array(all_combinations) columns = [filter["column"] for filter in filters] indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} diff --git a/rdock-utils/requirements-dev.txt 
b/rdock-utils/requirements-dev.txt index 16263576..531ad766 100644 --- a/rdock-utils/requirements-dev.txt +++ b/rdock-utils/requirements-dev.txt @@ -1,3 +1,3 @@ mypy==1.8.0 pytest==7.4.4 -ruff==0.1.14 +ruff==0.5.4 From 6e9dc572445b38d2a9759f07b3a13c53e7792036 Mon Sep 17 00:00:00 2001 From: lpardey Date: Tue, 23 Jul 2024 00:30:45 +0000 Subject: [PATCH 09/18] refactor (WIP) --- .../rdock_utils/rbhtfinder/rbhtfinder.py | 154 ++++++++++-------- 1 file changed, 88 insertions(+), 66 deletions(-) diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index e415e116..26f58bcc 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -17,10 +17,10 @@ @dataclass -class RBHTResult: - filter_combination: np.ndarray +class FilterCombinationResult: + combination: np.ndarray perc_val: dict[str, int] - filter_percentages: list[int] + percentages: list[int] time: float @@ -44,10 +44,10 @@ def run(self) -> None: print(f"First few rows of the input array:\n{sdreport_array[:5]}") print("Data read in from input file.") - # convert to 3D array (molecules x poses x columns) + # Convert to 3D array (molecules x poses x columns) molecule_array = self.prepare_array(sdreport_array, self.config.name) - # find the top scoring compounds for validation of the filter combinations + # Find the top scoring compounds for validation of the filter combinations min_score_indices: dict[float, np.ndarray] = {} for column in set(filter["column"] for filter in self.config.filters): min_scores = np.min(molecule_array[:, :, column], axis=1) @@ -63,7 +63,7 @@ def run(self) -> None: distinct_combinations, ) - self.write_output(results, self.config.filters, self.config.validation, self.config.output, column_names) + self.write_output(results, column_names) best_filter_combination = self.select_best_filter_combination( results, self.config.max_time, self.config.min_percentage @@ -97,7 +97,7 @@ 
def read_data_using_pandas(self) -> InputData: if self.config.header: column_names = sdreport_dataframe.columns.values else: - # use index names; add 1 to deal with zero-based numbering + # Use index names; add 1 to deal with zero-based numbering column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] sdreport_array = sdreport_dataframe.values @@ -115,13 +115,13 @@ def read_data_using_numpy(self) -> InputData: return sdreport_array, column_names - def apply_threshold(self, scored_poses, column, steps, threshold): + def apply_threshold(self, scored_poses: np.ndarray, column: int, steps: int, threshold: float) -> np.ndarray: """ Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. """ - # minimum score after `steps` per molecule + # Minimum score after `steps` per molecule mins = np.min(scored_poses[:, :steps, column], axis=1) - # return those molecules where the minimum score is less than the threshold + # Return those molecules where the minimum score is less than the threshold passing_molecules = np.where(mins < threshold)[0] return passing_molecules @@ -162,72 +162,94 @@ def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: def calculate_results_for_filter_combination( self, filter_combination: np.ndarray, molecule_array: np.ndarray, min_score_indices: dict[float, np.ndarray] - ) -> RBHTResult: + ) -> FilterCombinationResult: """ For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking """ - # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing - mols_passed_threshold = list(range(molecule_array.shape[0])) + num_molecules = molecule_array.shape[0] + num_steps = molecule_array.shape[1] + + # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + passing_molecule_indices = np.arange(num_molecules) filter_percentages = [] - number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output - for n, threshold in enumerate(filter_combination): - if n: - # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - number_of_simulated_poses += len(mols_passed_threshold) * ( - self.config.filters[n]["steps"] - self.config.filters[n - 1]["steps"] - ) - else: - number_of_simulated_poses += len(mols_passed_threshold) * self.config.filters[n]["steps"] - mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. 
passed all previous filters - n - for n in self.apply_threshold( - molecule_array, self.config.filters[n]["column"], self.config.filters[n]["steps"], threshold - ) - if n in mols_passed_threshold - ] - filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) - number_of_simulated_poses += len(mols_passed_threshold) * ( - molecule_array.shape[1] - self.config.filters[-1]["steps"] - ) + number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + + for i, threshold in enumerate(filter_combination): + number_of_simulated_poses += self.calculate_simulated_poses_increment(i, passing_molecule_indices) + passing_indices = self.apply_threshold( + molecule_array, self.config.filters[i]["column"], self.config.filters[i]["steps"], threshold + ) + # All mols which pass the threshold and which were already in passing_molecule_indices, i.e. passed all previous filters + passing_molecule_indices = np.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) + filter_percentages.append(len(passing_molecule_indices) / num_molecules) + + number_of_simulated_poses += len(passing_molecule_indices) * (num_steps - self.config.filters[-1]["steps"]) perc_val = { - k: len([n for n in v if n in mols_passed_threshold]) / self.config.validation + k: len(np.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation for k, v in min_score_indices.items() } - time = number_of_simulated_poses / np.product(molecule_array.shape[:2]) - rbhtresult = RBHTResult( - filter_combination=filter_combination, perc_val=perc_val, filter_percentages=filter_percentages, time=time + time = number_of_simulated_poses / np.prod(molecule_array.shape[:2]) + result = FilterCombinationResult( + combination=filter_combination, + perc_val=perc_val, + percentages=filter_percentages, + time=time, ) - return rbhtresult + return result + + def 
calculate_simulated_poses_increment(self, index: int, passing_molecule_indices: np.ndarray) -> int: + if index: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + increment = len(passing_molecule_indices) * ( + self.config.filters[index]["steps"] - self.config.filters[index - 1]["steps"] + ) + else: + increment = len(passing_molecule_indices) * self.config.filters[index]["steps"] + return increment - def write_output(self, results: list[RBHTResult], filters, number_of_validation_mols, output_file, column_names): + def write_output(self, results: list[FilterCombinationResult], column_names: list[str]) -> None: """ Print results as a table. The number of columns varies depending how many columns the user picked. """ - with open(output_file, "w") as f: - # write header - for n in range(len(results[0].filter_combination)): - f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") - for n in results[0].perc_val: - f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") - f.write(f"ENRICH_{column_names[n]}\t") - f.write("TIME\n") - - # write results - for result in results: - for n, threshold in enumerate(result.filter_combination): - f.write( - f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result.filter_percentages[n]*100:.2f}\t" - ) - for n in result.perc_val: - f.write(f"{result.perc_val[n]*100:.2f}\t") - if result.filter_percentages[-1]: - f.write(f"{result.perc_val[n]/result.filter_percentages[-1]:.2f}\t") - else: - f.write("NaN\t") - f.write(f"{result.time:.4f}\n") - return - - def select_best_filter_combination(self, results: list[RBHTResult], max_time, min_perc): + with open(self.config.output, "w") as f: + header = self.get_output_header(results[0], column_names) + f.write("\t".join(header) + "\n") + content_lines = ["\t".join(self.get_output_content(result, column_names)) + "\n" for result in results] + 
f.writelines(content_lines) + + def get_output_header(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + header = [] + for i in range(len(result.combination)): + header.extend([f"FILTER{i + 1}", f"NSTEPS{i + 1}", f"THR{i + 1}", f"PERC{i + 1}"]) + + for col_index in result.perc_val.keys(): + header.append(f"TOP{self.config.validation}_{column_names[col_index]}") + header.append(f"ENRICH_{column_names[col_index]}") + + header.append("TIME") + + return header + + def get_output_content(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + content = [] + + for index, threshold in enumerate(result.combination): + column_name = column_names[self.config.filters[index]["column"]] + steps = self.config.filters[index]["steps"] + filter_percentage = result.percentages[index] * 100 + content.extend([f"{column_name}", f"{steps}", f"{threshold:.2f}", f"{filter_percentage:.2f}"]) + + for _, perc_val in result.perc_val.items(): + perc_val_percent = perc_val * 100 + enrichment = perc_val / result.percentages[-1] if result.percentages[-1] else float("nan") + content.append(f"{perc_val_percent:.2f}") + content.append(f"{enrichment:.2f}") + + content.append(f"{result.time:.4f}") + + return content + + def select_best_filter_combination(self, results: list[FilterCombinationResult], max_time, min_perc): """ Very debatable how to do this... 
Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" @@ -255,7 +277,7 @@ def select_best_filter_combination(self, results: list[RBHTResult], max_time, mi / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) ] ) - if result.time < max_time and result.filter_percentages[-1] >= min_perc / 100 + if result.time < max_time and result.percentages[-1] >= min_perc / 100 else 0 for result in results ] @@ -300,7 +322,7 @@ def remove_redundant_combinations(self, all_combinations: list[tuple], filters: for _, indices in indices_per_col.items(): col_data = all_combinations_array[:, indices] - sorted_data = np.sort(col_data, axis=1)[:, ::-1] # sort descending + sorted_data = np.sort(col_data, axis=1)[:, ::-1] # Sort descending is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) mask &= is_valid & is_unique From 3fb643f96494d6fcb436aee06bae9d7c95cf6a1a Mon Sep 17 00:00:00 2001 From: lpardey Date: Tue, 23 Jul 2024 17:33:22 +0000 Subject: [PATCH 10/18] threshold value for the rbhtfinderconfig could be None refactor (WIP) --- rdock-utils/rdock_utils/rbhtfinder/parser.py | 6 +- .../rdock_utils/rbhtfinder/rbhtfinder.py | 89 +++++++++---------- 2 files changed, 45 insertions(+), 50 deletions(-) diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 946fcf95..21b9390a 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -8,7 +8,7 @@ class RBHTFinderConfig: input: str output: str - threshold: str + threshold: str | None name: int filters: list[Filter] validation: int @@ -101,7 +101,7 @@ def get_parser() -> argparse.ArgumentParser: considered a template that the user modifies as needed. Requirements: - rbhtfinder requires NumPy. 
Installation of pandas is recommended, but optional; if pandas is + rbhtfinder requires NumPy. Installation of Pandas is recommended, but optional; if Pandas is not available, loading the input file for calculations will be considerably slower. """ input_help = "Input from sdreport (tabular separated format)." @@ -119,7 +119,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("-o", "--output", help=output_help, type=str, required=True) parser.add_argument("-t", "--threshold", help=threshold_help, type=str) parser.add_argument("-n", "--name", type=int, default=1, help=name_help) - parser.add_argument("-f", "--filters", nargs="+", type=str, help=filter_help) + parser.add_argument("-f", "--filters", nargs="+", type=str, help=filter_help, required=True) # Review 'required' parser.add_argument("-v", "--validation", type=int, default=500, help=validation_help) parser.add_argument("--header", action="store_true", help=header_help) parser.add_argument("--max-time", type=float, default=0.1, help=max_time_help) diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index 26f58bcc..32caab25 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -65,9 +65,8 @@ def run(self) -> None: self.write_output(results, column_names) - best_filter_combination = self.select_best_filter_combination( - results, self.config.max_time, self.config.min_percentage - ) + best_filter_combination = self.select_best_filter_combination(results) + if self.config.threshold: if best_filter_combination: self.write_threshold_file( @@ -78,10 +77,8 @@ def run(self) -> None: molecule_array.shape[1], ) else: - print( - "Filter combinations defined are too strict or would take too long to run; no threshold file was written." - ) - exit(1) + message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." 
+ logger.warning(message) def read_data(self) -> InputData: try: @@ -136,13 +133,13 @@ def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: + 1 ) split_array = np.split(data_array, split_indices) - modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] + modal_shape = Counter([array.shape for array in split_array]).most_common(1)[0] number_of_poses = modal_shape[0][0] # Find modal number of poses per molecule in the array split_array_clean = sum( [ - np.array_split(n, n.shape[0] / number_of_poses) - for n in split_array - if not n.shape[0] % number_of_poses and n.shape[0] + np.array_split(array, array.shape[0] / number_of_poses) + for array in split_array + if not array.shape[0] % number_of_poses and array.shape[0] ], [], ) @@ -168,7 +165,6 @@ def calculate_results_for_filter_combination( """ num_molecules = molecule_array.shape[0] num_steps = molecule_array.shape[1] - # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing passing_molecule_indices = np.arange(num_molecules) filter_percentages = [] @@ -207,14 +203,14 @@ def calculate_simulated_poses_increment(self, index: int, passing_molecule_indic increment = len(passing_molecule_indices) * self.config.filters[index]["steps"] return increment - def write_output(self, results: list[FilterCombinationResult], column_names: list[str]) -> None: + def write_output(self, results: list[FilterCombinationResult], column_names: list[str], sep: str = "\t") -> None: """ Print results as a table. The number of columns varies depending how many columns the user picked. 
""" with open(self.config.output, "w") as f: header = self.get_output_header(results[0], column_names) - f.write("\t".join(header) + "\n") - content_lines = ["\t".join(self.get_output_content(result, column_names)) + "\n" for result in results] + f.write(sep.join(header) + "\n") + content_lines = [sep.join(self.get_output_content(result, column_names)) + "\n" for result in results] f.writelines(content_lines) def get_output_header(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: @@ -233,15 +229,15 @@ def get_output_header(self, result: FilterCombinationResult, column_names: list[ def get_output_content(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: content = [] - for index, threshold in enumerate(result.combination): - column_name = column_names[self.config.filters[index]["column"]] - steps = self.config.filters[index]["steps"] - filter_percentage = result.percentages[index] * 100 + for i, threshold in enumerate(result.combination): + column_name = column_names[self.config.filters[i]["column"]] + steps = self.config.filters[i]["steps"] + filter_percentage = result.percentages[i] * 100 content.extend([f"{column_name}", f"{steps}", f"{threshold:.2f}", f"{filter_percentage:.2f}"]) - for _, perc_val in result.perc_val.items(): - perc_val_percent = perc_val * 100 - enrichment = perc_val / result.percentages[-1] if result.percentages[-1] else float("nan") + for value in result.perc_val.values(): + perc_val_percent = value * 100 + enrichment = value / result.percentages[-1] if result.percentages[-1] else float("nan") content.append(f"{perc_val_percent:.2f}") content.append(f"{enrichment:.2f}") @@ -249,7 +245,7 @@ def get_output_content(self, result: FilterCombinationResult, column_names: list return content - def select_best_filter_combination(self, results: list[FilterCombinationResult], max_time, min_perc): + def select_best_filter_combination(self, results: list[FilterCombinationResult]) -> float: """ Very 
debatable how to do this... Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" @@ -257,31 +253,31 @@ def select_best_filter_combination(self, results: list[FilterCombinationResult], threshold with the highest enrichment factor """ min_max_values = {} - for col in results[0].perc_val.keys(): - vals = [result.perc_val[col] for result in results] - min_max_values[col] = {"min": min(vals), "max": max(vals)} + # Transpose the `perc_val` data to get columns + perc_vals = {col: [result.perc_val[col] for result in results] for col in results[0].perc_val} + min_max_values.update({col: {"min": min(vals), "max": max(vals)} for col, vals in perc_vals.items()}) time_vals = [result.time for result in results] min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} - - combination_scores = [ - sum( - [ - ( - (result.perc_val[col] - min_max_values[col]["min"]) - / (min_max_values[col]["max"] - min_max_values[col]["min"]) - ) - for col in results[0].perc_val.keys() - ] - + [ - (min_max_values["time"]["max"] - result.time) - / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) - ] + combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] + best_combination = np.argmax(combination_scores) + return best_combination + + def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: dict[str, int]) -> float: + if result.time < self.config.max_time and result.percentages[-1] >= self.config.min_percentage / 100: + col_scores = [ + (result.perc_val[col] - min_max_values[col]["min"]) + / (min_max_values[col]["max"] - min_max_values[col]["min"]) + for col in min_max_values + if col != "time" + ] + + time_score = (min_max_values["time"]["max"] - result.time) / ( + min_max_values["time"]["max"] - min_max_values["time"]["min"] ) - if result.time < max_time and result.percentages[-1] >= min_perc / 100 - else 0 - for result in results - ] - return 
np.argmax(combination_scores) + + return sum(col_scores) + time_score + + return 0 def write_threshold_file(self, filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): with open(threshold_file, "w") as f: @@ -316,11 +312,10 @@ def remove_redundant_combinations(self, all_combinations: list[tuple], filters: all_combinations_array = np.array(all_combinations) columns = [filter["column"] for filter in filters] indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} - # Create a mask to keep only valid combinations mask = np.ones(len(all_combinations_array), dtype=bool) - for _, indices in indices_per_col.items(): + for indices in indices_per_col.values(): col_data = all_combinations_array[:, indices] sorted_data = np.sort(col_data, axis=1)[:, ::-1] # Sort descending is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original From 47bdc4528e88fc95f1b4f37e5ce2ed9dc10e3148 Mon Sep 17 00:00:00 2001 From: lpardey Date: Tue, 23 Jul 2024 21:33:20 +0000 Subject: [PATCH 11/18] update tests to assert expected threshold file --- rdock-utils/tests/rbhtfinder/conftest.py | 16 +++++++++++----- rdock-utils/tests/rbhtfinder/test_integration.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/rdock-utils/tests/rbhtfinder/conftest.py b/rdock-utils/tests/rbhtfinder/conftest.py index 381957db..9d688bb3 100644 --- a/rdock-utils/tests/rbhtfinder/conftest.py +++ b/rdock-utils/tests/rbhtfinder/conftest.py @@ -7,25 +7,31 @@ RBHTFINDER_FIXTURES_FOLDER = FIXTURES_FOLDER / "rbhtfinder" INPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "input.txt") -THRESHOLD_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "threshold.txt") +EXPECTED_THRESHOLD_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "threshold.txt") EXPECTED_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "output.txt") @pytest.fixture -def file_path(tmp_path: Path) -> Path: +def output_temp(tmp_path: Path) -> Path: output_path = 
tmp_path / "output.txt" return output_path @pytest.fixture -def argv(file_path: Path) -> list[str]: +def threshold_temp(tmp_path: Path) -> Path: + threshold_path = tmp_path / "threshold.txt" + return threshold_path + + +@pytest.fixture +def argv(output_temp: Path, threshold_temp: Path) -> list[str]: return [ "-i", INPUT_FILE, "-o", - str(file_path), + str(output_temp), "-t", - THRESHOLD_FILE, + str(threshold_temp), "-f", "column=4,steps=3,min=-10.0,max=0.0,interval=5.0", "column=6,steps=5,min=1.0,max=5.0,interval=5.0", diff --git a/rdock-utils/tests/rbhtfinder/test_integration.py b/rdock-utils/tests/rbhtfinder/test_integration.py index 07045efb..9e4d637a 100644 --- a/rdock-utils/tests/rbhtfinder/test_integration.py +++ b/rdock-utils/tests/rbhtfinder/test_integration.py @@ -6,7 +6,7 @@ from rdock_utils.rbhtfinder.main import main as rbhtfinder_main from rdock_utils.rbhtfinder_original_copy import main as rbhtfinder_old_main -from .conftest import EXPECTED_OUTPUT_FILE, get_file_content +from .conftest import EXPECTED_OUTPUT_FILE, EXPECTED_THRESHOLD_FILE, get_file_content parametrize_main = pytest.mark.parametrize( "main", @@ -24,8 +24,11 @@ def test_do_nothing(main: Callable[[list[str]], None]): @parametrize_main -def test_integration(main: Callable[[list[str]], None], file_path: Path, argv: list[str]): +def test_integration(main: Callable[[list[str]], None], output_temp: Path, threshold_temp: Path, argv: list[str]): main(argv) - result = get_file_content(file_path) - expected_result = get_file_content(EXPECTED_OUTPUT_FILE) - assert result == expected_result + output = get_file_content(output_temp) + threshold = get_file_content(threshold_temp) + expected_output = get_file_content(EXPECTED_OUTPUT_FILE) + expected_threshold = get_file_content(EXPECTED_THRESHOLD_FILE) + assert output == expected_output + assert threshold == expected_threshold From 24050552c1a9d6fdebe4a72c8a954f609f3bfe2f Mon Sep 17 00:00:00 2001 From: lpardey Date: Wed, 24 Jul 2024 00:49:22 +0000 
Subject: [PATCH 12/18] refactor (WIP) --- rdock-utils/rdock_utils/rbhtfinder/parser.py | 4 +- .../rdock_utils/rbhtfinder/rbhtfinder.py | 196 ++++++++++-------- 2 files changed, 115 insertions(+), 85 deletions(-) diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 21b9390a..f48049a0 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -1,7 +1,7 @@ import argparse from dataclasses import dataclass -Filter = dict[str, float] +Filter = dict[str, float | int] @dataclass @@ -20,7 +20,7 @@ def __post_init__(self) -> None: self.filters = self.get_parsed_filters() def get_parsed_filters(self) -> list[Filter]: - parsed_filters = [self._parse_filter(filter) for filter in self.filters] + parsed_filters = [self._parse_filter(filter) for filter in self.filters] # type: ignore # sort filters by step at which they are applied parsed_filters.sort(key=lambda n: n["steps"]) return parsed_filters diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index 32caab25..4d4fcdb9 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -2,9 +2,10 @@ import logging import multiprocessing import os -from collections import Counter +from collections import Counter, defaultdict from dataclasses import dataclass from functools import partial +from typing import Any import numpy as np import pandas as pd @@ -19,7 +20,7 @@ @dataclass class FilterCombinationResult: combination: np.ndarray - perc_val: dict[str, int] + perc_val: dict[float, float] percentages: list[int] time: float @@ -43,42 +44,46 @@ def run(self) -> None: sdreport_array, column_names = self.read_data() print(f"First few rows of the input array:\n{sdreport_array[:5]}") print("Data read in from input file.") - # Convert to 3D array (molecules x poses x columns) molecule_array = 
self.prepare_array(sdreport_array, self.config.name) - # Find the top scoring compounds for validation of the filter combinations - min_score_indices: dict[float, np.ndarray] = {} - for column in set(filter["column"] for filter in self.config.filters): - min_scores = np.min(molecule_array[:, :, column], axis=1) - min_score_indices[column] = np.argpartition(min_scores, self.config.validation)[: self.config.validation] - - pool = multiprocessing.Pool(os.cpu_count()) - results = pool.map( - partial( + columns = set(filter["column"] for filter in self.config.filters) + min_score_indices = { + column: np.argpartition(np.min(molecule_array[:, :, column], axis=1), self.config.validation)[ + : self.config.validation + ] + for column in columns + } + results = self.process_filter_combinations(molecule_array, min_score_indices, distinct_combinations) + self.write_output(results, column_names) + best_filter_combination_index = self.get_best_filter_combination_index(results) + if self.config.threshold: + self.handle_threshold( + best_filter_combination_index, distinct_combinations, column_names, molecule_array.shape[1] + ) + + def process_filter_combinations( + self, molecule_array: np.ndarray, min_score_indices: dict[float, np.ndarray], distinct_combinations: np.ndarray + ) -> list[FilterCombinationResult]: + num_cpus = os.cpu_count() or 1 + with multiprocessing.Pool(num_cpus) as pool: + function_to_apply = partial( self.calculate_results_for_filter_combination, molecule_array=molecule_array, min_score_indices=min_score_indices, - ), - distinct_combinations, - ) - - self.write_output(results, column_names) - - best_filter_combination = self.select_best_filter_combination(results) - - if self.config.threshold: - if best_filter_combination: - self.write_threshold_file( - self.config.filters, - distinct_combinations[best_filter_combination], - self.config.threshold, - column_names, - molecule_array.shape[1], - ) - else: - message = "Filter combinations defined are too strict or 
would take too long to run; no threshold file was written." - logger.warning(message) + ) + results = pool.map(function_to_apply, distinct_combinations) + return results + + def handle_threshold( + self, combination_index: int, distinct_combinations: np.ndarray, column_names: list[str], num_poses: int + ) -> None: + if combination_index: + best_filter_combination = distinct_combinations[combination_index] + self.write_threshold(best_filter_combination, column_names, num_poses) + else: + message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + logger.warning(message) def read_data(self) -> InputData: try: @@ -112,7 +117,9 @@ def read_data_using_numpy(self) -> InputData: return sdreport_array, column_names - def apply_threshold(self, scored_poses: np.ndarray, column: int, steps: int, threshold: float) -> np.ndarray: + def apply_threshold( + self, scored_poses: np.ndarray, column: float | int, steps: float | int, threshold: float + ) -> np.ndarray: """ Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
""" @@ -135,23 +142,21 @@ def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: split_array = np.split(data_array, split_indices) modal_shape = Counter([array.shape for array in split_array]).most_common(1)[0] number_of_poses = modal_shape[0][0] # Find modal number of poses per molecule in the array - split_array_clean = sum( - [ - np.array_split(array, array.shape[0] / number_of_poses) - for array in split_array - if not array.shape[0] % number_of_poses and array.shape[0] - ], - [], - ) - - if len(split_array_clean) * number_of_poses < data_array.shape[0] * 0.99: + valid_split_arrays = [ + np.array_split(array, array.shape[0] / number_of_poses) # type: ignore + for array in split_array + if not array.shape[0] % number_of_poses and array.shape[0] + ] + flattened_split_array = np.concatenate(valid_split_arrays) + + if len(flattened_split_array) * number_of_poses < data_array.shape[0] * 0.99: message = ( "WARNING: The number of poses provided per molecule is inconsistent. " - f"Only {len(split_array_clean)} of {int(data_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." + f"Only {len(flattened_split_array)} of {int(data_array.shape[0] / number_of_poses)} molecules have {number_of_poses} poses." ) logger.warning(message) - molecule_3d_array = np.array(split_array_clean) + molecule_3d_array = np.array(flattened_split_array) # Overwrite the name column (should be the only one with dtype=str) so we can force everything to float molecule_3d_array[:, :, name_column] = 0 final_array = molecule_3d_array.astype(float) @@ -168,7 +173,9 @@ def calculate_results_for_filter_combination( # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing passing_molecule_indices = np.arange(num_molecules) filter_percentages = [] - number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + number_of_simulated_poses: float | int = ( + 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + ) for i, threshold in enumerate(filter_combination): number_of_simulated_poses += self.calculate_simulated_poses_increment(i, passing_molecule_indices) @@ -203,17 +210,23 @@ def calculate_simulated_poses_increment(self, index: int, passing_molecule_indic increment = len(passing_molecule_indices) * self.config.filters[index]["steps"] return increment - def write_output(self, results: list[FilterCombinationResult], column_names: list[str], sep: str = "\t") -> None: + def write_output( + self, + results: list[FilterCombinationResult], + column_names: list[str], + sep: str = "\t", + end: str = "\n", + ) -> None: """ Print results as a table. The number of columns varies depending how many columns the user picked. 
""" with open(self.config.output, "w") as f: - header = self.get_output_header(results[0], column_names) - f.write(sep.join(header) + "\n") - content_lines = [sep.join(self.get_output_content(result, column_names)) + "\n" for result in results] + header = self._get_output_header(results[0], column_names) + f.write(sep.join(header) + end) + content_lines = [sep.join(self._get_output_content(result, column_names)) + end for result in results] f.writelines(content_lines) - def get_output_header(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + def _get_output_header(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: header = [] for i in range(len(result.combination)): header.extend([f"FILTER{i + 1}", f"NSTEPS{i + 1}", f"THR{i + 1}", f"PERC{i + 1}"]) @@ -226,7 +239,7 @@ def get_output_header(self, result: FilterCombinationResult, column_names: list[ return header - def get_output_content(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + def _get_output_content(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: content = [] for i, threshold in enumerate(result.combination): @@ -245,7 +258,7 @@ def get_output_content(self, result: FilterCombinationResult, column_names: list return content - def select_best_filter_combination(self, results: list[FilterCombinationResult]) -> float: + def get_best_filter_combination_index(self, results: list[FilterCombinationResult]) -> int: """ Very debatable how to do this... 
Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" @@ -259,8 +272,8 @@ def select_best_filter_combination(self, results: list[FilterCombinationResult]) time_vals = [result.time for result in results] min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] - best_combination = np.argmax(combination_scores) - return best_combination + index = np.argmax(combination_scores) + return index def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: dict[str, int]) -> float: if result.time < self.config.max_time and result.percentages[-1] >= self.config.min_percentage / 100: @@ -270,37 +283,54 @@ def calculate_combination_score(self, result: FilterCombinationResult, min_max_v for col in min_max_values if col != "time" ] - time_score = (min_max_values["time"]["max"] - result.time) / ( min_max_values["time"]["max"] - min_max_values["time"]["min"] ) - - return sum(col_scores) + time_score - - return 0 - - def write_threshold_file(self, filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): - with open(threshold_file, "w") as f: - # write number of filters to apply - f.write(f"{len(filters) + 1}\n") - # write each filter to a separate line - for n, filtr in enumerate(filters): - f.write( - f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' - ) - # write filter to terminate docking when NRUNS reaches the number of runs used in the input file - f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") - - # write final filters - find strictest filters for all columns and apply them again - filters_by_column = { - col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] - for col in set([filtr["column"] for filtr in filters]) - } - # write number of 
filters (same as number of columns filtered on) - f.write(f"{len(filters_by_column)}\n") - # write filter - for col, values in filters_by_column.items(): - f.write(f"- {column_names[col]} {min(values)},\n") + score = sum(col_scores) + time_score + else: + score = 0 + + return score + + def write_threshold( + self, + best_filter_combination: np.ndarray, + column_names: list[str], + max_number_of_runs: int, + sep: str = "\n", + end: str = "\n", + ) -> None: + path: str = self.config.threshold + with open(path, "w") as f: + content = self._get_threshold_content(best_filter_combination, column_names, max_number_of_runs) + f.write(sep.join(content) + end) + + def _get_threshold_content( + self, best_filter_combination: np.ndarray, column_names: list[str], max_number_of_runs: int + ) -> list[str]: + content = [] + # Number of filters to apply + content.append(f"{len(self.config.filters) + 1}") + # Get each filter to a separate line + filter_lines = [ + f'if - {best_filter_combination[i]:.2f} {column_names[filter["column"]]} 1.0 ' + f'if - SCORE.NRUNS {filter["steps"]} 0.0 -1.0,' + for i, filter in enumerate(self.config.filters) + ] + content.extend(filter_lines) + # Filter to terminate docking when NRUNS reaches the number of runs used in the input file + content.append(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0") + # Find strictest filters for all columns and apply them again + filters_by_column = defaultdict(list) + for i, filter in enumerate(self.config.filters): + col = filter["column"] + filters_by_column[col].append(best_filter_combination[i]) + # Number of filters (same as number of columns filtered on) + content.append(f"{len(filters_by_column)}") + # Filter + filter_min_values = [f"- {column_names[col]} {min(values)}," for col, values in filters_by_column.items()] + content.extend(filter_min_values) + return content def generate_filters_combinations(self, filters: list[Filter]) -> list[tuple]: filter_ranges = ((filter["min"], filter["max"] + 
filter["interval"], filter["interval"]) for filter in filters) From 30ccbde506dec918e699f3e5966d2e65e412990b Mon Sep 17 00:00:00 2001 From: lpardey Date: Wed, 24 Jul 2024 22:50:17 +0000 Subject: [PATCH 13/18] add mypy plugin for numpy to pyprojectoml mypy fix add rbhtfinder types to common module add models module --- rdock-utils/pyproject.toml | 2 + rdock-utils/rdock_utils/common/__init__.py | 26 ++++- rdock-utils/rdock_utils/common/types.py | 41 +++++--- rdock-utils/rdock_utils/rbhtfinder/models.py | 52 ++++++++++ rdock-utils/rdock_utils/rbhtfinder/parser.py | 40 +------- .../rdock_utils/rbhtfinder/rbhtfinder.py | 96 ++++++++++--------- 6 files changed, 158 insertions(+), 99 deletions(-) create mode 100644 rdock-utils/rdock_utils/rbhtfinder/models.py diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index 980531a8..e18138f7 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -78,3 +78,5 @@ exclude = [ "tests/", "rdock_utils/sdtether_original.py", ] + +plugins = "numpy.typing.mypy_plugin" diff --git a/rdock-utils/rdock_utils/common/__init__.py b/rdock-utils/rdock_utils/common/__init__.py index 8f9411a0..898780af 100644 --- a/rdock-utils/rdock_utils/common/__init__.py +++ b/rdock-utils/rdock_utils/common/__init__.py @@ -2,11 +2,23 @@ from .SDFParser import FastSDMol, molecules_with_progress_log, read_molecules, read_molecules_from_all_inputs from .superpose3d import MolAlignmentData, Superpose3D, update_coordinates from .types import ( + Array1DFloat, + Array1DInt, + Array1DStr, + Array2DFloat, + Array3DFloat, AtomsMapping, AutomorphismRMSD, + ColumnNamesArray, CoordsArray, + FilterCombination, FloatArray, + InputData, Matrix3x3, + MinMaxDict, + MinMaxValues, + MinScoreIndices, + SDReportArray, SingularValueDecomposition, Superpose3DResult, Vector3D, @@ -25,11 +37,23 @@ "MolAlignmentData", "Superpose3D", # -- types -- + "Array1DFloat", + "Array1DInt", + "Array1DStr", + "Array2DFloat", + "Array3DFloat", + "AtomsMapping", 
"AutomorphismRMSD", + "ColumnNamesArray", "CoordsArray", + "FilterCombination", "FloatArray", - "AtomsMapping", + "InputData", "Matrix3x3", + "MinMaxDict", + "MinMaxValues", + "MinScoreIndices", + "SDReportArray", "SingularValueDecomposition", "Superpose3DResult", "Vector3D", diff --git a/rdock-utils/rdock_utils/common/types.py b/rdock-utils/rdock_utils/common/types.py index f972fd6e..a1d0c688 100644 --- a/rdock-utils/rdock_utils/common/types.py +++ b/rdock-utils/rdock_utils/common/types.py @@ -1,23 +1,40 @@ from typing import Any -import numpy +import numpy as np +import numpy.typing as npt -FloatArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] -CoordsArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] +# TODO: Review common types for all rdock_utils scripts +FloatArray = np.ndarray[Any, np.dtype[np.float64]] +CoordsArray = np.ndarray[Any, np.dtype[np.float64]] AutomorphismRMSD = tuple[float, CoordsArray | None] -Vector3D = numpy.ndarray[Any, numpy.dtype[numpy.float64]] -Matrix3x3 = numpy.ndarray[Any, numpy.dtype[numpy.float64]] +Vector3D = np.ndarray[Any, np.dtype[np.float64]] +Matrix3x3 = np.ndarray[Any, np.dtype[np.float64]] SingularValueDecomposition = tuple[Matrix3x3, Vector3D, Matrix3x3] Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] AtomsMapping = tuple[tuple[int, int], ...] 
-## Shape support for type hinting is not yet avaialable in numpy -## let's keep this as a guide for numpy 2.0 release -# FloatArray = numpy.ndarray[Literal["N"], numpy.dtype[float]] -# BoolArray = numpy.ndarray[Literal["N"], numpy.dtype[bool]] -# CoordsArray = numpy.ndarray[Literal["N", 3], numpy.dtype[float]] + +# RBHTFinder types +SDReportArray = np.ndarray[list[int | str | float], np.dtype[np.object_]] +Array1DFloat = npt.NDArray[np.float_] +Array3DFloat = npt.NDArray[np.float_] +Array2DFloat = npt.NDArray[np.float_] +Array1DStr = npt.NDArray[np.str_] +Array1DInt = npt.NDArray[np.int_] +ColumnNamesArray = Array1DStr | list[str] +InputData = tuple[SDReportArray, ColumnNamesArray] +MinMaxDict = dict[str, float] +MinMaxValues = dict[Any, MinMaxDict] +MinScoreIndices = dict[int, Array1DInt] +FilterCombination = tuple[float, float] + +## Shape support for type hinting is not yet avaialable in np +## let's keep this as a guide for np 2.0 release +# FloatArray = np.ndarray[Literal["N"], np.dtype[float]] +# BoolArray = np.ndarray[Literal["N"], np.dtype[bool]] +# CoordsArray = np.ndarray[Literal["N", 3], np.dtype[float]] # AutomorphismRMSD = tuple[float, CoordsArray | None] -# Vector3D = numpy.ndarray[Literal[3], numpy.dtype[float]] -# Matrix3x3 = numpy.ndarray[Literal[3, 3], numpy.dtype[float]] +# Vector3D = np.ndarray[Literal[3], np.dtype[float]] +# Matrix3x3 = np.ndarray[Literal[3, 3], np.dtype[float]] # SingularValueDecomposition = tuple[Matrix3x3, Vector3D, Matrix3x3] # Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] diff --git a/rdock-utils/rdock_utils/rbhtfinder/models.py b/rdock-utils/rdock_utils/rbhtfinder/models.py new file mode 100644 index 00000000..e3b8a67b --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/models.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Any + +from rdock_utils.common import Array1DFloat + +Filter = dict[str, Any] # The type for the values is either 'float' or 'int'; 'Any' is used to comply 
with mypy + + +@dataclass +class RBHTFinderConfig: + input: str + output: str + threshold: str | None + name: int + filters: list[Filter] + validation: int + header: bool + max_time: float + min_percentage: float + + def __post_init__(self) -> None: + self.filters = self.get_parsed_filters() + + def get_parsed_filters(self) -> list[Filter]: + filter_args: list[str] = self.filters # type: ignore + parsed_filters = [self._parse_filter(filter) for filter in filter_args] + # sort filters by step at which they are applied + parsed_filters.sort(key=lambda n: n["steps"]) + return parsed_filters + + @staticmethod + def _parse_filter(filter_str: str) -> Filter: + parsed_filter = {} + + for item in filter_str.split(","): + key, value = item.split("=") + parsed_filter[key] = float(value) if key in ("interval", "min", "max") else int(value) + # User inputs with 1-based numbering whereas python uses 0-based + parsed_filter["column"] -= 1 + + if "interval" not in parsed_filter: + parsed_filter["interval"] = 1.0 + + return parsed_filter + + +@dataclass +class FilterCombinationResult: + combination: Array1DFloat + perc_val: dict[int, float] + percentages: list[float] + time: float diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index f48049a0..41088314 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -1,44 +1,6 @@ import argparse -from dataclasses import dataclass -Filter = dict[str, float | int] - - -@dataclass -class RBHTFinderConfig: - input: str - output: str - threshold: str | None - name: int - filters: list[Filter] - validation: int - header: bool - max_time: float - min_percentage: float - - def __post_init__(self) -> None: - self.filters = self.get_parsed_filters() - - def get_parsed_filters(self) -> list[Filter]: - parsed_filters = [self._parse_filter(filter) for filter in self.filters] # type: ignore - # sort filters by step at which they are applied - 
parsed_filters.sort(key=lambda n: n["steps"]) - return parsed_filters - - @staticmethod - def _parse_filter(filter_str: str) -> Filter: - parsed_filter = {} - - for item in filter_str.split(","): - key, value = item.split("=") - parsed_filter[key] = float(value) if key in ("interval", "min", "max") else int(value) - # User inputs with 1-based numbering whereas python uses 0-based - parsed_filter["column"] -= 1 - - if "interval" not in parsed_filter: - parsed_filter["interval"] = 1.0 - - return parsed_filter +from .models import RBHTFinderConfig def get_parser() -> argparse.ArgumentParser: diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index 4d4fcdb9..1be2a252 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -3,27 +3,28 @@ import multiprocessing import os from collections import Counter, defaultdict -from dataclasses import dataclass from functools import partial -from typing import Any import numpy as np import pandas as pd -from .parser import Filter, RBHTFinderConfig +from rdock_utils.common import ( + Array1DFloat, + Array1DInt, + Array2DFloat, + Array3DFloat, + ColumnNamesArray, + FilterCombination, + InputData, + MinMaxValues, + MinScoreIndices, + SDReportArray, +) + +from .models import Filter, FilterCombinationResult, RBHTFinderConfig logger = logging.getLogger("RBHTFinder") -InputData = tuple[np.ndarray, list[str]] - - -@dataclass -class FilterCombinationResult: - combination: np.ndarray - perc_val: dict[float, float] - percentages: list[int] - time: float - class RBHTFinder: def __init__(self, config: RBHTFinderConfig) -> None: @@ -57,13 +58,12 @@ def run(self) -> None: results = self.process_filter_combinations(molecule_array, min_score_indices, distinct_combinations) self.write_output(results, column_names) best_filter_combination_index = self.get_best_filter_combination_index(results) - if self.config.threshold: - 
self.handle_threshold( - best_filter_combination_index, distinct_combinations, column_names, molecule_array.shape[1] - ) + if self.config.threshold is not None: + num_poses = molecule_array.shape[1] + self.handle_threshold(best_filter_combination_index, distinct_combinations, column_names, num_poses) def process_filter_combinations( - self, molecule_array: np.ndarray, min_score_indices: dict[float, np.ndarray], distinct_combinations: np.ndarray + self, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices, distinct_combinations: Array2DFloat ) -> list[FilterCombinationResult]: num_cpus = os.cpu_count() or 1 with multiprocessing.Pool(num_cpus) as pool: @@ -76,10 +76,14 @@ def process_filter_combinations( return results def handle_threshold( - self, combination_index: int, distinct_combinations: np.ndarray, column_names: list[str], num_poses: int + self, + combination_index: int, + distinct_combinations: Array2DFloat, + column_names: ColumnNamesArray, + num_poses: int, ) -> None: if combination_index: - best_filter_combination = distinct_combinations[combination_index] + best_filter_combination: Array1DFloat = distinct_combinations[combination_index] self.write_threshold(best_filter_combination, column_names, num_poses) else: message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." @@ -117,9 +121,7 @@ def read_data_using_numpy(self) -> InputData: return sdreport_array, column_names - def apply_threshold( - self, scored_poses: np.ndarray, column: float | int, steps: float | int, threshold: float - ) -> np.ndarray: + def apply_threshold(self, scored_poses: Array3DFloat, column: int, steps: int, threshold: float) -> Array1DInt: """ Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
""" @@ -129,7 +131,7 @@ def apply_threshold( passing_molecules = np.where(mins < threshold)[0] return passing_molecules - def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: + def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DFloat: """ Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses """ @@ -163,7 +165,7 @@ def prepare_array(self, data_array: np.ndarray, name_column: int) -> np.ndarray: return final_array def calculate_results_for_filter_combination( - self, filter_combination: np.ndarray, molecule_array: np.ndarray, min_score_indices: dict[float, np.ndarray] + self, filter_combination: Array2DFloat, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices ) -> FilterCombinationResult: """ For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking @@ -173,15 +175,13 @@ def calculate_results_for_filter_combination( # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing passing_molecule_indices = np.arange(num_molecules) filter_percentages = [] - number_of_simulated_poses: float | int = ( - 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output - ) + number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output for i, threshold in enumerate(filter_combination): number_of_simulated_poses += self.calculate_simulated_poses_increment(i, passing_molecule_indices) - passing_indices = self.apply_threshold( - molecule_array, self.config.filters[i]["column"], self.config.filters[i]["steps"], threshold - ) + column: int = self.config.filters[i]["column"] + step: int = self.config.filters[i]["steps"] + passing_indices = self.apply_threshold(molecule_array, column, step, threshold) # All mols which pass the threshold and which were already in passing_molecule_indices, i.e. passed all previous filters passing_molecule_indices = np.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) filter_percentages.append(len(passing_molecule_indices) / num_molecules) @@ -191,7 +191,7 @@ def calculate_results_for_filter_combination( k: len(np.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation for k, v in min_score_indices.items() } - time = number_of_simulated_poses / np.prod(molecule_array.shape[:2]) + time = float(number_of_simulated_poses / np.prod(molecule_array.shape[:2])) result = FilterCombinationResult( combination=filter_combination, perc_val=perc_val, @@ -200,10 +200,10 @@ def calculate_results_for_filter_combination( ) return result - def calculate_simulated_poses_increment(self, index: int, passing_molecule_indices: np.ndarray) -> int: + def calculate_simulated_poses_increment(self, index: int, passing_molecule_indices: Array1DInt) -> int: if index: # e.g. 
if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - increment = len(passing_molecule_indices) * ( + increment: int = len(passing_molecule_indices) * ( self.config.filters[index]["steps"] - self.config.filters[index - 1]["steps"] ) else: @@ -213,7 +213,7 @@ def calculate_simulated_poses_increment(self, index: int, passing_molecule_indic def write_output( self, results: list[FilterCombinationResult], - column_names: list[str], + column_names: ColumnNamesArray, sep: str = "\t", end: str = "\n", ) -> None: @@ -226,7 +226,7 @@ def write_output( content_lines = [sep.join(self._get_output_content(result, column_names)) + end for result in results] f.writelines(content_lines) - def _get_output_header(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + def _get_output_header(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: header = [] for i in range(len(result.combination)): header.extend([f"FILTER{i + 1}", f"NSTEPS{i + 1}", f"THR{i + 1}", f"PERC{i + 1}"]) @@ -239,7 +239,7 @@ def _get_output_header(self, result: FilterCombinationResult, column_names: list return header - def _get_output_content(self, result: FilterCombinationResult, column_names: list[str]) -> list[str]: + def _get_output_content(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: content = [] for i, threshold in enumerate(result.combination): @@ -265,7 +265,7 @@ def get_best_filter_combination_index(self, results: list[FilterCombinationResul (= percentage of validation compounds / percentage of all compounds); we select the threshold with the highest enrichment factor """ - min_max_values = {} + min_max_values: MinMaxValues = {} # Transpose the `perc_val` data to get columns perc_vals = {col: [result.perc_val[col] for result in results] for col in results[0].perc_val} min_max_values.update({col: {"min": min(vals), "max": max(vals)} 
for col, vals in perc_vals.items()}) @@ -273,9 +273,9 @@ def get_best_filter_combination_index(self, results: list[FilterCombinationResul min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] index = np.argmax(combination_scores) - return index + return int(index) - def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: dict[str, int]) -> float: + def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: MinMaxValues) -> float: if result.time < self.config.max_time and result.percentages[-1] >= self.config.min_percentage / 100: col_scores = [ (result.perc_val[col] - min_max_values[col]["min"]) @@ -294,19 +294,19 @@ def calculate_combination_score(self, result: FilterCombinationResult, min_max_v def write_threshold( self, - best_filter_combination: np.ndarray, - column_names: list[str], + best_filter_combination: Array1DFloat, + column_names: ColumnNamesArray, max_number_of_runs: int, sep: str = "\n", end: str = "\n", ) -> None: - path: str = self.config.threshold + path = self.config.threshold or "default_threshold.txt" with open(path, "w") as f: content = self._get_threshold_content(best_filter_combination, column_names, max_number_of_runs) f.write(sep.join(content) + end) def _get_threshold_content( - self, best_filter_combination: np.ndarray, column_names: list[str], max_number_of_runs: int + self, best_filter_combination: Array1DFloat, column_names: ColumnNamesArray, max_number_of_runs: int ) -> list[str]: content = [] # Number of filters to apply @@ -332,13 +332,15 @@ def _get_threshold_content( content.extend(filter_min_values) return content - def generate_filters_combinations(self, filters: list[Filter]) -> list[tuple]: + def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: filter_ranges = ((filter["min"], filter["max"] + filter["interval"], 
filter["interval"]) for filter in filters) combinations = (np.arange(*range) for range in filter_ranges) filters_combinations = list(itertools.product(*combinations)) return filters_combinations - def remove_redundant_combinations(self, all_combinations: list[tuple], filters: list[Filter]) -> np.ndarray: + def remove_redundant_combinations( + self, all_combinations: list[FilterCombination], filters: list[Filter] + ) -> Array2DFloat: all_combinations_array = np.array(all_combinations) columns = [filter["column"] for filter in filters] indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} @@ -352,5 +354,5 @@ def remove_redundant_combinations(self, all_combinations: list[tuple], filters: is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) mask &= is_valid & is_unique - valid_combinations = all_combinations_array[mask] + valid_combinations: Array2DFloat = all_combinations_array[mask] return valid_combinations From a00732d61eeb1b1a776c4383b3561672bf932021 Mon Sep 17 00:00:00 2001 From: lpardey Date: Thu, 25 Jul 2024 17:06:18 +0000 Subject: [PATCH 14/18] refactor types add filter, minmax and minmaxvalues schemas mypy fix final refactor --- rdock-utils/rdock_utils/common/__init__.py | 4 - rdock-utils/rdock_utils/common/types.py | 6 +- rdock-utils/rdock_utils/rbhtfinder/models.py | 52 ---------- rdock-utils/rdock_utils/rbhtfinder/parser.py | 44 ++++++++- .../rdock_utils/rbhtfinder/rbhtfinder.py | 96 +++++++++---------- rdock-utils/rdock_utils/rbhtfinder/schemas.py | 37 +++++++ 6 files changed, 127 insertions(+), 112 deletions(-) delete mode 100644 rdock-utils/rdock_utils/rbhtfinder/models.py create mode 100644 rdock-utils/rdock_utils/rbhtfinder/schemas.py diff --git a/rdock-utils/rdock_utils/common/__init__.py b/rdock-utils/rdock_utils/common/__init__.py index 898780af..ca2b713b 100644 --- a/rdock-utils/rdock_utils/common/__init__.py +++ b/rdock-utils/rdock_utils/common/__init__.py @@ -15,8 +15,6 @@ 
FloatArray, InputData, Matrix3x3, - MinMaxDict, - MinMaxValues, MinScoreIndices, SDReportArray, SingularValueDecomposition, @@ -50,8 +48,6 @@ "FloatArray", "InputData", "Matrix3x3", - "MinMaxDict", - "MinMaxValues", "MinScoreIndices", "SDReportArray", "SingularValueDecomposition", diff --git a/rdock-utils/rdock_utils/common/types.py b/rdock-utils/rdock_utils/common/types.py index a1d0c688..2f00a294 100644 --- a/rdock-utils/rdock_utils/common/types.py +++ b/rdock-utils/rdock_utils/common/types.py @@ -4,6 +4,7 @@ import numpy.typing as npt # TODO: Review common types for all rdock_utils scripts +# SDRMSD types FloatArray = np.ndarray[Any, np.dtype[np.float64]] CoordsArray = np.ndarray[Any, np.dtype[np.float64]] AutomorphismRMSD = tuple[float, CoordsArray | None] @@ -13,18 +14,15 @@ Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] AtomsMapping = tuple[tuple[int, int], ...] - # RBHTFinder types SDReportArray = np.ndarray[list[int | str | float], np.dtype[np.object_]] Array1DFloat = npt.NDArray[np.float_] -Array3DFloat = npt.NDArray[np.float_] Array2DFloat = npt.NDArray[np.float_] +Array3DFloat = npt.NDArray[np.float_] Array1DStr = npt.NDArray[np.str_] Array1DInt = npt.NDArray[np.int_] ColumnNamesArray = Array1DStr | list[str] InputData = tuple[SDReportArray, ColumnNamesArray] -MinMaxDict = dict[str, float] -MinMaxValues = dict[Any, MinMaxDict] MinScoreIndices = dict[int, Array1DInt] FilterCombination = tuple[float, float] diff --git a/rdock-utils/rdock_utils/rbhtfinder/models.py b/rdock-utils/rdock_utils/rbhtfinder/models.py deleted file mode 100644 index e3b8a67b..00000000 --- a/rdock-utils/rdock_utils/rbhtfinder/models.py +++ /dev/null @@ -1,52 +0,0 @@ -from dataclasses import dataclass -from typing import Any - -from rdock_utils.common import Array1DFloat - -Filter = dict[str, Any] # The type for the values is either 'float' or 'int'; 'Any' is used to comply with mypy - - -@dataclass -class RBHTFinderConfig: - input: str - output: str - threshold: str | None 
- name: int - filters: list[Filter] - validation: int - header: bool - max_time: float - min_percentage: float - - def __post_init__(self) -> None: - self.filters = self.get_parsed_filters() - - def get_parsed_filters(self) -> list[Filter]: - filter_args: list[str] = self.filters # type: ignore - parsed_filters = [self._parse_filter(filter) for filter in filter_args] - # sort filters by step at which they are applied - parsed_filters.sort(key=lambda n: n["steps"]) - return parsed_filters - - @staticmethod - def _parse_filter(filter_str: str) -> Filter: - parsed_filter = {} - - for item in filter_str.split(","): - key, value = item.split("=") - parsed_filter[key] = float(value) if key in ("interval", "min", "max") else int(value) - # User inputs with 1-based numbering whereas python uses 0-based - parsed_filter["column"] -= 1 - - if "interval" not in parsed_filter: - parsed_filter["interval"] = 1.0 - - return parsed_filter - - -@dataclass -class FilterCombinationResult: - combination: Array1DFloat - perc_val: dict[int, float] - percentages: list[float] - time: float diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 41088314..23f914c8 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -1,6 +1,42 @@ import argparse +from dataclasses import dataclass -from .models import RBHTFinderConfig +from .schemas import Filter + + +@dataclass +class RBHTFinderConfig: + input: str + output: str + threshold: str | None + name: int + filters: list[Filter] + validation: int + header: bool + max_time: float + min_percentage: float + + def __post_init__(self) -> None: + self.filters = self.get_parsed_filters() + + def get_parsed_filters(self) -> list[Filter]: + filter_args: list[str] = self.filters # type: ignore + parsed_filters = [self._parse_filter(arg) for arg in filter_args] + # sort filters by step at which they are applied + parsed_filters.sort(key=lambda 
filter: filter.steps) + return parsed_filters + + @staticmethod + def _parse_filter(argument: str) -> Filter: + parsed_filter = Filter() + + for item in argument.split(","): + key, value = item.split("=") + setattr(parsed_filter, key, float(value) if key in ("interval", "min", "max") else int(value)) + # User inputs with 1-based numbering whereas python uses 0-based + parsed_filter.column -= 1 + parsed_filter.interval = parsed_filter.interval or 1.0 + return parsed_filter def get_parser() -> argparse.ArgumentParser: @@ -15,9 +51,9 @@ def get_parser() -> argparse.ArgumentParser: using the -f option, for example, "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". This example would simulate the effect of applying thresholds on column 6 after 5 poses have been generated, for values between 0.5 and 1.0 (i.e., 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). - More than one threshold can be specified, e.g., "-f column=4,steps=5,min=-12,max=-10, - interval=1 column=4,steps=15,min=-16,max=-15,interval=1" will test the following - combinations of thresholds on column 4: + More than one threshold can be specified, e.g., + "-f column=4,steps=5,min=-12,max=-10,interval=1 column=4,steps=15,min=-16,max=-15,interval=1" + will test the following combinations of thresholds on column 4: 5 -10 15 -15 5 -11 15 -15 5 -12 15 -15 diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index 1be2a252..cc459352 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -16,12 +16,12 @@ ColumnNamesArray, FilterCombination, InputData, - MinMaxValues, MinScoreIndices, SDReportArray, ) -from .models import Filter, FilterCombinationResult, RBHTFinderConfig +from .parser import RBHTFinderConfig +from .schemas import Filter, FilterCombinationResult, MinMaxValues logger = logging.getLogger("RBHTFinder") @@ -47,24 +47,21 @@ def run(self) -> None: print("Data read in from input file.") # Convert to 
3D array (molecules x poses x columns) molecule_array = self.prepare_array(sdreport_array, self.config.name) + results = self.process_filter_combinations(molecule_array, distinct_combinations) + self.write_output(results, column_names) + self.handle_threshold(results, distinct_combinations, column_names, molecule_array.shape[1]) + + def process_filter_combinations( + self, molecule_array: Array3DFloat, distinct_combinations: Array2DFloat + ) -> list[FilterCombinationResult]: # Find the top scoring compounds for validation of the filter combinations - columns = set(filter["column"] for filter in self.config.filters) + columns = set(filter.column for filter in self.config.filters) min_score_indices = { column: np.argpartition(np.min(molecule_array[:, :, column], axis=1), self.config.validation)[ : self.config.validation ] for column in columns } - results = self.process_filter_combinations(molecule_array, min_score_indices, distinct_combinations) - self.write_output(results, column_names) - best_filter_combination_index = self.get_best_filter_combination_index(results) - if self.config.threshold is not None: - num_poses = molecule_array.shape[1] - self.handle_threshold(best_filter_combination_index, distinct_combinations, column_names, num_poses) - - def process_filter_combinations( - self, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices, distinct_combinations: Array2DFloat - ) -> list[FilterCombinationResult]: num_cpus = os.cpu_count() or 1 with multiprocessing.Pool(num_cpus) as pool: function_to_apply = partial( @@ -77,14 +74,19 @@ def process_filter_combinations( def handle_threshold( self, - combination_index: int, + filter_combinations: list[FilterCombinationResult], distinct_combinations: Array2DFloat, column_names: ColumnNamesArray, num_poses: int, ) -> None: - if combination_index: - best_filter_combination: Array1DFloat = distinct_combinations[combination_index] - self.write_threshold(best_filter_combination, column_names, num_poses) + 
threshold_file = self.config.threshold or "" + if not threshold_file: + return + + best_combination_index = self.get_best_filter_combination_index(filter_combinations) + if best_combination_index: + best_combination = distinct_combinations[best_combination_index] + self.write_threshold(best_combination, column_names, num_poses) else: message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." logger.warning(message) @@ -93,7 +95,7 @@ def read_data(self) -> InputData: try: data_array, column_names = self.read_data_using_pandas() except Exception as e: - logging.error(f"Error reading data with pandas: {e}") + logging.warning(f"Error reading data with pandas: {e}") data_array, column_names = self.read_data_using_numpy() return data_array, column_names @@ -168,25 +170,27 @@ def calculate_results_for_filter_combination( self, filter_combination: Array2DFloat, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices ) -> FilterCombinationResult: """ - For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + For a particular combination of filters, calculate the percentage of molecules that will be filtered, + the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking """ num_molecules = molecule_array.shape[0] num_steps = molecule_array.shape[1] - # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. + # As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing passing_molecule_indices = np.arange(num_molecules) filter_percentages = [] number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output for i, threshold in enumerate(filter_combination): number_of_simulated_poses += self.calculate_simulated_poses_increment(i, passing_molecule_indices) - column: int = self.config.filters[i]["column"] - step: int = self.config.filters[i]["steps"] + column = self.config.filters[i].column + step = self.config.filters[i].steps passing_indices = self.apply_threshold(molecule_array, column, step, threshold) # All mols which pass the threshold and which were already in passing_molecule_indices, i.e. passed all previous filters passing_molecule_indices = np.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) filter_percentages.append(len(passing_molecule_indices) / num_molecules) - number_of_simulated_poses += len(passing_molecule_indices) * (num_steps - self.config.filters[-1]["steps"]) + number_of_simulated_poses += len(passing_molecule_indices) * (num_steps - self.config.filters[-1].steps) perc_val = { k: len(np.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation for k, v in min_score_indices.items() @@ -203,19 +207,15 @@ def calculate_results_for_filter_combination( def calculate_simulated_poses_increment(self, index: int, passing_molecule_indices: Array1DInt) -> int: if index: # e.g. 
if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - increment: int = len(passing_molecule_indices) * ( - self.config.filters[index]["steps"] - self.config.filters[index - 1]["steps"] + increment = len(passing_molecule_indices) * ( + self.config.filters[index].steps - self.config.filters[index - 1].steps ) else: - increment = len(passing_molecule_indices) * self.config.filters[index]["steps"] + increment = len(passing_molecule_indices) * self.config.filters[index].steps return increment def write_output( - self, - results: list[FilterCombinationResult], - column_names: ColumnNamesArray, - sep: str = "\t", - end: str = "\n", + self, results: list[FilterCombinationResult], column_names: ColumnNamesArray, sep: str = "\t", end: str = "\n" ) -> None: """ Print results as a table. The number of columns varies depending how many columns the user picked. @@ -243,8 +243,8 @@ def _get_output_content(self, result: FilterCombinationResult, column_names: Col content = [] for i, threshold in enumerate(result.combination): - column_name = column_names[self.config.filters[i]["column"]] - steps = self.config.filters[i]["steps"] + column_name = column_names[self.config.filters[i].column] + steps = self.config.filters[i].steps filter_percentage = result.percentages[i] * 100 content.extend([f"{column_name}", f"{steps}", f"{threshold:.2f}", f"{filter_percentage:.2f}"]) @@ -265,12 +265,13 @@ def get_best_filter_combination_index(self, results: list[FilterCombinationResul (= percentage of validation compounds / percentage of all compounds); we select the threshold with the highest enrichment factor """ - min_max_values: MinMaxValues = {} + min_max_values = MinMaxValues() # Transpose the `perc_val` data to get columns perc_vals = {col: [result.perc_val[col] for result in results] for col in results[0].perc_val} - min_max_values.update({col: {"min": min(vals), "max": max(vals)} for col, vals in 
perc_vals.items()}) + for col, vals in perc_vals.items(): + min_max_values.update(col, min(vals), max(vals)) time_vals = [result.time for result in results] - min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} + min_max_values.update("time", min(time_vals), max(time_vals)) combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] index = np.argmax(combination_scores) return int(index) @@ -278,13 +279,13 @@ def get_best_filter_combination_index(self, results: list[FilterCombinationResul def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: MinMaxValues) -> float: if result.time < self.config.max_time and result.percentages[-1] >= self.config.min_percentage / 100: col_scores = [ - (result.perc_val[col] - min_max_values[col]["min"]) - / (min_max_values[col]["max"] - min_max_values[col]["min"]) - for col in min_max_values - if col != "time" + (result.perc_val[col] - min_max_values.get(col).min) + / (min_max_values.get(col).max - min_max_values.get(col).min) + for col in min_max_values.values + if isinstance(col, int) ] - time_score = (min_max_values["time"]["max"] - result.time) / ( - min_max_values["time"]["max"] - min_max_values["time"]["min"] + time_score = (min_max_values.get("time").max - result.time) / ( + min_max_values.get("time").max - min_max_values.get("time").min ) score = sum(col_scores) + time_score else: @@ -313,8 +314,8 @@ def _get_threshold_content( content.append(f"{len(self.config.filters) + 1}") # Get each filter to a separate line filter_lines = [ - f'if - {best_filter_combination[i]:.2f} {column_names[filter["column"]]} 1.0 ' - f'if - SCORE.NRUNS {filter["steps"]} 0.0 -1.0,' + f"if - {best_filter_combination[i]:.2f} {column_names[filter.column]} 1.0 " + f"if - SCORE.NRUNS {filter.steps} 0.0 -1.0," for i, filter in enumerate(self.config.filters) ] content.extend(filter_lines) @@ -323,8 +324,7 @@ def _get_threshold_content( # Find strictest filters for 
all columns and apply them again filters_by_column = defaultdict(list) for i, filter in enumerate(self.config.filters): - col = filter["column"] - filters_by_column[col].append(best_filter_combination[i]) + filters_by_column[filter.column].append(best_filter_combination[i]) # Number of filters (same as number of columns filtered on) content.append(f"{len(filters_by_column)}") # Filter @@ -333,7 +333,7 @@ def _get_threshold_content( return content def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: - filter_ranges = ((filter["min"], filter["max"] + filter["interval"], filter["interval"]) for filter in filters) + filter_ranges = ((filter.min, filter.max + filter.interval, filter.interval) for filter in filters) combinations = (np.arange(*range) for range in filter_ranges) filters_combinations = list(itertools.product(*combinations)) return filters_combinations @@ -342,7 +342,7 @@ def remove_redundant_combinations( self, all_combinations: list[FilterCombination], filters: list[Filter] ) -> Array2DFloat: all_combinations_array = np.array(all_combinations) - columns = [filter["column"] for filter in filters] + columns = [filter.column for filter in filters] indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} # Create a mask to keep only valid combinations mask = np.ones(len(all_combinations_array), dtype=bool) diff --git a/rdock-utils/rdock_utils/rbhtfinder/schemas.py b/rdock-utils/rdock_utils/rbhtfinder/schemas.py new file mode 100644 index 00000000..43ce073d --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/schemas.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass, field + +from rdock_utils.common import Array1DFloat + + +@dataclass +class Filter: + column: int = 0 + steps: int = 0 + min: float = 0.0 + max: float = 0.0 + interval: float = 0.0 + + +@dataclass +class FilterCombinationResult: + combination: Array1DFloat + perc_val: dict[int, float] + percentages: list[float] + 
time: float + + +@dataclass +class MinMax: + min: float + max: float + + +@dataclass +class MinMaxValues: + values: dict[int | str, MinMax] = field(default_factory=dict) + + def update(self, column: int | str, min_val: float, max_val: float) -> None: + self.values[column] = MinMax(min=min_val, max=max_val) + + def get(self, column: int | str) -> MinMax: + return self.values[column] From 10bfb504279f24efb144eec3d97e8641bfa5ed78 Mon Sep 17 00:00:00 2001 From: lpardey Date: Thu, 25 Jul 2024 20:23:06 +0000 Subject: [PATCH 15/18] final refactor --- .../rdock_utils/rbhtfinder/rbhtfinder.py | 163 +++++++++--------- 1 file changed, 81 insertions(+), 82 deletions(-) diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index cc459352..b9f99df3 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -51,45 +51,30 @@ def run(self) -> None: self.write_output(results, column_names) self.handle_threshold(results, distinct_combinations, column_names, molecule_array.shape[1]) - def process_filter_combinations( - self, molecule_array: Array3DFloat, distinct_combinations: Array2DFloat - ) -> list[FilterCombinationResult]: - # Find the top scoring compounds for validation of the filter combinations - columns = set(filter.column for filter in self.config.filters) - min_score_indices = { - column: np.argpartition(np.min(molecule_array[:, :, column], axis=1), self.config.validation)[ - : self.config.validation - ] - for column in columns - } - num_cpus = os.cpu_count() or 1 - with multiprocessing.Pool(num_cpus) as pool: - function_to_apply = partial( - self.calculate_results_for_filter_combination, - molecule_array=molecule_array, - min_score_indices=min_score_indices, - ) - results = pool.map(function_to_apply, distinct_combinations) - return results + def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: + filter_ranges = 
((filter.min, filter.max + filter.interval, filter.interval) for filter in filters) + combinations = (np.arange(*range) for range in filter_ranges) + filters_combinations = list(itertools.product(*combinations)) + return filters_combinations - def handle_threshold( - self, - filter_combinations: list[FilterCombinationResult], - distinct_combinations: Array2DFloat, - column_names: ColumnNamesArray, - num_poses: int, - ) -> None: - threshold_file = self.config.threshold or "" - if not threshold_file: - return + def remove_redundant_combinations( + self, all_combinations: list[FilterCombination], filters: list[Filter] + ) -> Array2DFloat: + all_combinations_array = np.array(all_combinations) + columns = [filter.column for filter in filters] + indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} + # Create a mask to keep only valid combinations + mask = np.ones(len(all_combinations_array), dtype=bool) - best_combination_index = self.get_best_filter_combination_index(filter_combinations) - if best_combination_index: - best_combination = distinct_combinations[best_combination_index] - self.write_threshold(best_combination, column_names, num_poses) - else: - message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." 
- logger.warning(message) + for indices in indices_per_col.values(): + col_data = all_combinations_array[:, indices] + sorted_data = np.sort(col_data, axis=1)[:, ::-1] # Sort descending + is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original + is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + mask &= is_valid & is_unique + + valid_combinations: Array2DFloat = all_combinations_array[mask] + return valid_combinations def read_data(self) -> InputData: try: @@ -123,16 +108,6 @@ def read_data_using_numpy(self) -> InputData: return sdreport_array, column_names - def apply_threshold(self, scored_poses: Array3DFloat, column: int, steps: int, threshold: float) -> Array1DInt: - """ - Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. - """ - # Minimum score after `steps` per molecule - mins = np.min(scored_poses[:, :steps, column], axis=1) - # Return those molecules where the minimum score is less than the threshold - passing_molecules = np.where(mins < threshold)[0] - return passing_molecules - def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DFloat: """ Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses @@ -166,6 +141,27 @@ def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DF final_array = molecule_3d_array.astype(float) return final_array + def process_filter_combinations( + self, molecule_array: Array3DFloat, distinct_combinations: Array2DFloat + ) -> list[FilterCombinationResult]: + # Find the top scoring compounds for validation of the filter combinations + columns = set(filter.column for filter in self.config.filters) + min_score_indices = { + column: np.argpartition(np.min(molecule_array[:, :, column], axis=1), self.config.validation)[ + : 
self.config.validation + ] + for column in columns + } + num_cpus = os.cpu_count() or 1 + with multiprocessing.Pool(num_cpus) as pool: + process_combination = partial( + self.calculate_results_for_filter_combination, + molecule_array=molecule_array, + min_score_indices=min_score_indices, + ) + results = pool.map(process_combination, distinct_combinations) + return results + def calculate_results_for_filter_combination( self, filter_combination: Array2DFloat, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices ) -> FilterCombinationResult: @@ -214,6 +210,16 @@ def calculate_simulated_poses_increment(self, index: int, passing_molecule_indic increment = len(passing_molecule_indices) * self.config.filters[index].steps return increment + def apply_threshold(self, scored_poses: Array3DFloat, column: int, steps: int, threshold: float) -> Array1DInt: + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. + """ + # Minimum score after `steps` per molecule + mins = np.min(scored_poses[:, :steps, column], axis=1) + # Return those molecules where the minimum score is less than the threshold + passing_molecules = np.where(mins < threshold)[0] + return passing_molecules + def write_output( self, results: list[FilterCombinationResult], column_names: ColumnNamesArray, sep: str = "\t", end: str = "\n" ) -> None: @@ -221,25 +227,25 @@ def write_output( Print results as a table. The number of columns varies depending how many columns the user picked. 
""" with open(self.config.output, "w") as f: - header = self._get_output_header(results[0], column_names) + header = self.get_output_header(results[0], column_names) f.write(sep.join(header) + end) - content_lines = [sep.join(self._get_output_content(result, column_names)) + end for result in results] + content_lines = [sep.join(self.get_output_content(result, column_names)) + end for result in results] f.writelines(content_lines) - def _get_output_header(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: + def get_output_header(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: header = [] for i in range(len(result.combination)): header.extend([f"FILTER{i + 1}", f"NSTEPS{i + 1}", f"THR{i + 1}", f"PERC{i + 1}"]) for col_index in result.perc_val.keys(): - header.append(f"TOP{self.config.validation}_{column_names[col_index]}") - header.append(f"ENRICH_{column_names[col_index]}") + column_name = column_names[col_index] + header.extend([f"TOP{self.config.validation}_{column_name}", f"ENRICH_{column_name}"]) header.append("TIME") return header - def _get_output_content(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: + def get_output_content(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: content = [] for i, threshold in enumerate(result.combination): @@ -251,13 +257,31 @@ def _get_output_content(self, result: FilterCombinationResult, column_names: Col for value in result.perc_val.values(): perc_val_percent = value * 100 enrichment = value / result.percentages[-1] if result.percentages[-1] else float("nan") - content.append(f"{perc_val_percent:.2f}") - content.append(f"{enrichment:.2f}") + content.extend([f"{perc_val_percent:.2f}", f"{enrichment:.2f}"]) content.append(f"{result.time:.4f}") return content + def handle_threshold( + self, + filter_combinations: list[FilterCombinationResult], + distinct_combinations: Array2DFloat, + 
column_names: ColumnNamesArray, + num_poses: int, + ) -> None: + threshold_file = self.config.threshold or "" + if not threshold_file: + return + + best_combination_index = self.get_best_filter_combination_index(filter_combinations) + if best_combination_index: + best_combination = distinct_combinations[best_combination_index] + self.write_threshold(best_combination, column_names, num_poses) + else: + message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + logger.warning(message) + def get_best_filter_combination_index(self, results: list[FilterCombinationResult]) -> int: """ Very debatable how to do this... @@ -303,10 +327,10 @@ def write_threshold( ) -> None: path = self.config.threshold or "default_threshold.txt" with open(path, "w") as f: - content = self._get_threshold_content(best_filter_combination, column_names, max_number_of_runs) + content = self.get_threshold_content(best_filter_combination, column_names, max_number_of_runs) f.write(sep.join(content) + end) - def _get_threshold_content( + def get_threshold_content( self, best_filter_combination: Array1DFloat, column_names: ColumnNamesArray, max_number_of_runs: int ) -> list[str]: content = [] @@ -331,28 +355,3 @@ def _get_threshold_content( filter_min_values = [f"- {column_names[col]} {min(values)}," for col, values in filters_by_column.items()] content.extend(filter_min_values) return content - - def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: - filter_ranges = ((filter.min, filter.max + filter.interval, filter.interval) for filter in filters) - combinations = (np.arange(*range) for range in filter_ranges) - filters_combinations = list(itertools.product(*combinations)) - return filters_combinations - - def remove_redundant_combinations( - self, all_combinations: list[FilterCombination], filters: list[Filter] - ) -> Array2DFloat: - all_combinations_array = np.array(all_combinations) - columns = 
[filter.column for filter in filters] - indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} - # Create a mask to keep only valid combinations - mask = np.ones(len(all_combinations_array), dtype=bool) - - for indices in indices_per_col.values(): - col_data = all_combinations_array[:, indices] - sorted_data = np.sort(col_data, axis=1)[:, ::-1] # Sort descending - is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original - is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) - mask &= is_valid & is_unique - - valid_combinations: Array2DFloat = all_combinations_array[mask] - return valid_combinations From 1d3295c62d9a46f65019cffdd587cdb45f017f5e Mon Sep 17 00:00:00 2001 From: lpardey Date: Thu, 25 Jul 2024 20:30:18 +0000 Subject: [PATCH 16/18] add project script to toml --- rdock-utils/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index e18138f7..01a6cb42 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -10,6 +10,8 @@ dependencies = { file = ["requirements.txt"] } optional-dependencies = { dev = { file = ["requirements-dev.txt"] } } [project.scripts] +rbhtfinder = "rdock_utils.rbhtfinder:main" +rbhtfinder_old = "rdock_utils.rbhtfinder_original_copy:main" sdfield = "rdock_utils.sdfield:main" sdrmsd_old = "rdock_utils.sdrmsd_original:main" sdrmsd = "rdock_utils.sdrmsd.main:main" From 63be342e0c9672ad351e92a87459af3483aef287 Mon Sep 17 00:00:00 2001 From: lpardey Date: Thu, 25 Jul 2024 21:30:45 +0000 Subject: [PATCH 17/18] exclude original rbhtfinder code from mypy check --- rdock-utils/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index 01a6cb42..fe3b7f11 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -79,6 +79,7 @@ exclude = [ "rdock_utils/sdrmsd_original.py", "tests/", 
"rdock_utils/sdtether_original.py", + "rdock_utils/rbhtfinder_original_copy.py", ] plugins = "numpy.typing.mypy_plugin" From 5703cbe7d16354a6c9eb171273d982ecbfe7cb11 Mon Sep 17 00:00:00 2001 From: lpardey Date: Wed, 31 Jul 2024 22:56:41 +0000 Subject: [PATCH 18/18] fix pandas import exception add cpu_count argument to parser update pandas and numpy imports --- rdock-utils/rdock_utils/common/types.py | 34 +++++----- rdock-utils/rdock_utils/rbhtfinder/parser.py | 4 ++ .../rdock_utils/rbhtfinder/rbhtfinder.py | 64 ++++++++++--------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/rdock-utils/rdock_utils/common/types.py b/rdock-utils/rdock_utils/common/types.py index 2f00a294..392df4b7 100644 --- a/rdock-utils/rdock_utils/common/types.py +++ b/rdock-utils/rdock_utils/common/types.py @@ -1,26 +1,26 @@ from typing import Any -import numpy as np -import numpy.typing as npt +import numpy +import numpy.typing # TODO: Review common types for all rdock_utils scripts # SDRMSD types -FloatArray = np.ndarray[Any, np.dtype[np.float64]] -CoordsArray = np.ndarray[Any, np.dtype[np.float64]] +FloatArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] +CoordsArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] AutomorphismRMSD = tuple[float, CoordsArray | None] -Vector3D = np.ndarray[Any, np.dtype[np.float64]] -Matrix3x3 = np.ndarray[Any, np.dtype[np.float64]] +Vector3D = numpy.ndarray[Any, numpy.dtype[numpy.float64]] +Matrix3x3 = numpy.ndarray[Any, numpy.dtype[numpy.float64]] SingularValueDecomposition = tuple[Matrix3x3, Vector3D, Matrix3x3] Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] AtomsMapping = tuple[tuple[int, int], ...] 
# RBHTFinder types -SDReportArray = np.ndarray[list[int | str | float], np.dtype[np.object_]] -Array1DFloat = npt.NDArray[np.float_] -Array2DFloat = npt.NDArray[np.float_] -Array3DFloat = npt.NDArray[np.float_] -Array1DStr = npt.NDArray[np.str_] -Array1DInt = npt.NDArray[np.int_] +SDReportArray = numpy.ndarray[list[int | str | float], numpy.dtype[numpy.object_]] +Array1DFloat = numpy.typing.NDArray[numpy.float_] +Array2DFloat = numpy.typing.NDArray[numpy.float_] +Array3DFloat = numpy.typing.NDArray[numpy.float_] +Array1DStr = numpy.typing.NDArray[numpy.str_] +Array1DInt = numpy.typing.NDArray[numpy.int_] ColumnNamesArray = Array1DStr | list[str] InputData = tuple[SDReportArray, ColumnNamesArray] MinScoreIndices = dict[int, Array1DInt] @@ -28,11 +28,11 @@ ## Shape support for type hinting is not yet avaialable in np ## let's keep this as a guide for np 2.0 release -# FloatArray = np.ndarray[Literal["N"], np.dtype[float]] -# BoolArray = np.ndarray[Literal["N"], np.dtype[bool]] -# CoordsArray = np.ndarray[Literal["N", 3], np.dtype[float]] +# FloatArray = numpy.ndarray[Literal["N"], numpy.dtype[float]] +# BoolArray = numpy.ndarray[Literal["N"], numpy.dtype[bool]] +# CoordsArray = numpy.ndarray[Literal["N", 3], numpy.dtype[float]] # AutomorphismRMSD = tuple[float, CoordsArray | None] -# Vector3D = np.ndarray[Literal[3], np.dtype[float]] -# Matrix3x3 = np.ndarray[Literal[3, 3], np.dtype[float]] +# Vector3D = numpy.ndarray[Literal[3], numpy.dtype[float]] +# Matrix3x3 = numpy.ndarray[Literal[3, 3], numpy.dtype[float]] # SingularValueDecomposition = tuple[Matrix3x3, Vector3D, Matrix3x3] # Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py index 23f914c8..b0a9c5a9 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/parser.py +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -15,6 +15,7 @@ class RBHTFinderConfig: header: bool max_time: float min_percentage: float + 
cpu_count: int def __post_init__(self) -> None: self.filters = self.get_parsed_filters() @@ -111,6 +112,7 @@ def get_parser() -> argparse.ArgumentParser: header_help = "Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5." max_time_help = "Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take." min_perc_help = "Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1." + cpu_count_help = "Specify the number of CPU cores to use for multiprocessing. Defaults to '1' if not provided." parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-i", "--input", help=input_help, type=str, required=True) @@ -119,6 +121,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("-n", "--name", type=int, default=1, help=name_help) parser.add_argument("-f", "--filters", nargs="+", type=str, help=filter_help, required=True) # Review 'required' parser.add_argument("-v", "--validation", type=int, default=500, help=validation_help) + parser.add_argument("-c", "--cpu-count", type=int, default=1, help=cpu_count_help) parser.add_argument("--header", action="store_true", help=header_help) parser.add_argument("--max-time", type=float, default=0.1, help=max_time_help) parser.add_argument("--min-perc", type=float, default=1.0, help=min_perc_help) @@ -138,4 +141,5 @@ def get_config(argv: list[str] | None = None) -> RBHTFinderConfig: header=args.header, max_time=args.max_time, min_percentage=args.min_perc, + cpu_count=args.cpu_count, ) diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py index b9f99df3..a3b37951 100644 --- a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py +++ 
b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -1,12 +1,10 @@ import itertools import logging import multiprocessing -import os from collections import Counter, defaultdict from functools import partial -import numpy as np -import pandas as pd +import numpy from rdock_utils.common import ( Array1DFloat, @@ -23,6 +21,13 @@ from .parser import RBHTFinderConfig from .schemas import Filter, FilterCombinationResult, MinMaxValues +try: + import pandas +except ImportError: + PANDAS_IS_AVAILABLE = False +else: + PANDAS_IS_AVAILABLE = True + logger = logging.getLogger("RBHTFinder") @@ -53,39 +58,39 @@ def run(self) -> None: def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: filter_ranges = ((filter.min, filter.max + filter.interval, filter.interval) for filter in filters) - combinations = (np.arange(*range) for range in filter_ranges) + combinations = (numpy.arange(*range) for range in filter_ranges) filters_combinations = list(itertools.product(*combinations)) return filters_combinations def remove_redundant_combinations( self, all_combinations: list[FilterCombination], filters: list[Filter] ) -> Array2DFloat: - all_combinations_array = np.array(all_combinations) + all_combinations_array = numpy.array(all_combinations) columns = [filter.column for filter in filters] indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} # Create a mask to keep only valid combinations - mask = np.ones(len(all_combinations_array), dtype=bool) + mask = numpy.ones(len(all_combinations_array), dtype=bool) for indices in indices_per_col.values(): col_data = all_combinations_array[:, indices] - sorted_data = np.sort(col_data, axis=1)[:, ::-1] # Sort descending - is_valid = np.all(col_data == sorted_data, axis=1) # Check if sorted matches original - is_unique = np.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + sorted_data = numpy.sort(col_data, axis=1)[:, ::-1] # Sort descending + is_valid = 
numpy.all(col_data == sorted_data, axis=1) # Check if sorted matches original + is_unique = numpy.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) mask &= is_valid & is_unique valid_combinations: Array2DFloat = all_combinations_array[mask] return valid_combinations def read_data(self) -> InputData: - try: + if PANDAS_IS_AVAILABLE: data_array, column_names = self.read_data_using_pandas() - except Exception as e: - logging.warning(f"Error reading data with pandas: {e}") + else: + logging.warning("Pandas is not available to read the data") data_array, column_names = self.read_data_using_numpy() return data_array, column_names def read_data_using_pandas(self) -> InputData: - sdreport_dataframe = pd.read_csv(self.config.input, sep="\t", header=0 if self.config.header else None) + sdreport_dataframe = pandas.read_csv(self.config.input, sep="\t", header=0 if self.config.header else None) if self.config.header: column_names = sdreport_dataframe.columns.values @@ -97,7 +102,7 @@ def read_data_using_pandas(self) -> InputData: return sdreport_array, column_names def read_data_using_numpy(self) -> InputData: - np_array = np.loadtxt(self.config.input, dtype=str) + np_array = numpy.loadtxt(self.config.input, dtype=str) if self.config.header: column_names = np_array[0] @@ -113,20 +118,20 @@ def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DF Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses """ split_indices = ( - np.where( - data_array[:, name_column] != np.hstack((data_array[1:, name_column], data_array[0, name_column])) + numpy.where( + data_array[:, name_column] != numpy.hstack((data_array[1:, name_column], data_array[0, name_column])) )[0] + 1 ) - split_array = np.split(data_array, split_indices) + split_array = numpy.split(data_array, split_indices) modal_shape = Counter([array.shape for array in split_array]).most_common(1)[0] 
number_of_poses = modal_shape[0][0] # Find modal number of poses per molecule in the array valid_split_arrays = [ - np.array_split(array, array.shape[0] / number_of_poses) # type: ignore + numpy.array_split(array, array.shape[0] / number_of_poses) # type: ignore for array in split_array if not array.shape[0] % number_of_poses and array.shape[0] ] - flattened_split_array = np.concatenate(valid_split_arrays) + flattened_split_array = numpy.concatenate(valid_split_arrays) if len(flattened_split_array) * number_of_poses < data_array.shape[0] * 0.99: message = ( @@ -135,7 +140,7 @@ def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DF ) logger.warning(message) - molecule_3d_array = np.array(flattened_split_array) + molecule_3d_array = numpy.array(flattened_split_array) # Overwrite the name column (should be the only one with dtype=str) so we can force everything to float molecule_3d_array[:, :, name_column] = 0 final_array = molecule_3d_array.astype(float) @@ -147,13 +152,12 @@ def process_filter_combinations( # Find the top scoring compounds for validation of the filter combinations columns = set(filter.column for filter in self.config.filters) min_score_indices = { - column: np.argpartition(np.min(molecule_array[:, :, column], axis=1), self.config.validation)[ + column: numpy.argpartition(numpy.min(molecule_array[:, :, column], axis=1), self.config.validation)[ : self.config.validation ] for column in columns } - num_cpus = os.cpu_count() or 1 - with multiprocessing.Pool(num_cpus) as pool: + with multiprocessing.Pool(self.config.cpu_count) as pool: process_combination = partial( self.calculate_results_for_filter_combination, molecule_array=molecule_array, @@ -173,7 +177,7 @@ def calculate_results_for_filter_combination( num_steps = molecule_array.shape[1] # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. # As more filters are applied, it gets smaller. 
Before any iteration, we initialise with all molecules passing - passing_molecule_indices = np.arange(num_molecules) + passing_molecule_indices = numpy.arange(num_molecules) filter_percentages = [] number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output @@ -183,15 +187,15 @@ def calculate_results_for_filter_combination( step = self.config.filters[i].steps passing_indices = self.apply_threshold(molecule_array, column, step, threshold) # All mols which pass the threshold and which were already in passing_molecule_indices, i.e. passed all previous filters - passing_molecule_indices = np.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) + passing_molecule_indices = numpy.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) filter_percentages.append(len(passing_molecule_indices) / num_molecules) number_of_simulated_poses += len(passing_molecule_indices) * (num_steps - self.config.filters[-1].steps) perc_val = { - k: len(np.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation + k: len(numpy.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation for k, v in min_score_indices.items() } - time = float(number_of_simulated_poses / np.prod(molecule_array.shape[:2])) + time = float(number_of_simulated_poses / numpy.prod(molecule_array.shape[:2])) result = FilterCombinationResult( combination=filter_combination, perc_val=perc_val, @@ -215,9 +219,9 @@ def apply_threshold(self, scored_poses: Array3DFloat, column: int, steps: int, t Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. 
""" # Minimum score after `steps` per molecule - mins = np.min(scored_poses[:, :steps, column], axis=1) + mins = numpy.min(scored_poses[:, :steps, column], axis=1) # Return those molecules where the minimum score is less than the threshold - passing_molecules = np.where(mins < threshold)[0] + passing_molecules = numpy.where(mins < threshold)[0] return passing_molecules def write_output( @@ -297,7 +301,7 @@ def get_best_filter_combination_index(self, results: list[FilterCombinationResul time_vals = [result.time for result in results] min_max_values.update("time", min(time_vals), max(time_vals)) combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] - index = np.argmax(combination_scores) + index = numpy.argmax(combination_scores) return int(index) def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: MinMaxValues) -> float: