From 011673ebb48312309d43c3c41c85667feca85413 Mon Sep 17 00:00:00 2001 From: Juan Date: Fri, 7 Feb 2025 22:56:27 +0100 Subject: [PATCH] Handle post-processing in comparison tests --- validmind/tests/run.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/validmind/tests/run.py b/validmind/tests/run.py index 66dd40e7d..a438abe39 100644 --- a/validmind/tests/run.py +++ b/validmind/tests/run.py @@ -222,6 +222,7 @@ def _run_comparison_test( params: Union[Dict[str, Any], None], param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None], title: Optional[str] = None, + post_process_fn: Optional[Callable[[TestResult], None]] = None, ): """Run a comparison test i.e. a test that compares multiple outputs of a test across different input and/or param combinations""" @@ -232,8 +233,9 @@ def _run_comparison_test( params=params, ) - results = [ - run_test( + results = [] + for config in run_test_configs: + result = run_test( test_id=test_id, name=name, unit_metrics=unit_metrics, @@ -243,8 +245,9 @@ def _run_comparison_test( generate_description=False, title=title, ) - for config in run_test_configs - ] + if post_process_fn: + result = post_process_fn(result) + results.append(result) # composite tests have a test_id thats built from the name if not test_id: @@ -358,7 +361,9 @@ def run_test( # noqa: C901 input_grid=input_grid, params=params, param_grid=param_grid, + post_process_fn=post_process_fn, ) + post_process_fn = None elif unit_metrics: name = "".join(word.capitalize() for word in name.split())