103 changes: 44 additions & 59 deletions validmind/tests/run.py
@@ -134,11 +134,9 @@ def _get_test_kwargs(
def build_test_result(
outputs: Union[Any, Tuple[Any, ...]],
test_id: str,
test_doc: str,
inputs: Dict[str, Union[VMInput, List[VMInput]]],
params: Union[Dict[str, Any], None],
doc: str,
description: str,
generate_description: bool = True,
title: Optional[str] = None,
):
"""Build a TestResult object from a set of raw test function outputs"""
@@ -150,7 +148,7 @@ def build_test_result(
ref_id=ref_id,
inputs=inputs,
params=params if params else None, # None if empty dict or None
doc=doc,
doc=test_doc,
)

if not isinstance(outputs, tuple):
@@ -159,16 +157,6 @@ def build_test_result(
for item in outputs:
process_output(item, result)

result.description = get_result_description(
test_id=test_id,
test_description=description,
tables=result.tables,
figures=result.figures,
metric=result.metric,
should_generate=generate_description,
title=title,
)

return result


@@ -179,7 +167,6 @@ def _run_composite_test(
input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
params: Union[Dict[str, Any], None],
param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
generate_description: bool,
title: Optional[str] = None,
):
"""Run a composite test i.e. a test made up of multiple metrics"""
@@ -201,9 +188,12 @@ def _run_composite_test(
if not all(result.metric is not None for result in results):
raise ValueError("All tests must return a metric when used as a composite test")

# Create composite doc from all test results
# Create composite docstring from all test results
composite_doc = "\n\n".join(
[f"{test_id_to_name(result.result_id)}:\n{result.doc}" for result in results]
[
f"{test_id_to_name(result.result_id)}:\n{_test_description(result.doc)}"
for result in results
]
)

return build_test_result(
@@ -215,13 +205,9 @@ def _run_composite_test(
for result in results
], # pass in a single table with metric values as our 'outputs'
test_id=test_id,
test_doc=composite_doc,
inputs=results[0].inputs,
params=results[0].params,
doc=composite_doc,
description="\n\n".join(
[_test_description(result.description, num_lines=1) for result in results]
), # join truncated (first line only) test descriptions
generate_description=generate_description,
title=title,
)

@@ -234,7 +220,6 @@ def _run_comparison_test(
input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
params: Union[Dict[str, Any], None],
param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
generate_description: bool,
title: Optional[str] = None,
):
"""Run a comparison test i.e. a test that compares multiple outputs of a test across
@@ -263,35 +248,43 @@ def _run_comparison_test(
# composite tests have a test_id that's built from the name
if not test_id:
test_id = results[0].result_id
description = results[0].description
test_doc = results[0].doc
else:
description = describe_test(test_id, raw=True)["Description"]
test_doc = describe_test(test_id, raw=True)["Description"]

combined_outputs, combined_inputs, combined_params = combine_results(results)

if unit_metrics:
doc = "\n\n".join(
[
f"{test_id_to_name(unit_metric)}:\n{getdoc(load_test(unit_metric))}"
for unit_metric in unit_metrics
]
)
else:
doc = getdoc(load_test(test_id))

return build_test_result(
outputs=tuple(combined_outputs),
test_id=test_id,
test_doc=test_doc,
inputs=combined_inputs,
params=combined_params,
doc=doc,
description=description,
generate_description=generate_description,
title=title,
)


def run_test(
def _run_test(test_id: TestID, inputs: Dict[str, Any], params: Dict[str, Any]):
"""Run a standard test and return a TestResult object"""
test_func = load_test(test_id)
input_kwargs, param_kwargs = _get_test_kwargs(
test_func=test_func,
inputs=inputs or {},
params=params or {},
)

raw_result = test_func(**input_kwargs, **param_kwargs)

return build_test_result(
outputs=raw_result,
test_id=test_id,
test_doc=getdoc(test_func),
inputs=input_kwargs,
params=param_kwargs,
)


def run_test( # noqa: C901
test_id: Union[TestID, None] = None,
name: Union[str, None] = None,
unit_metrics: Union[List[TestID], None] = None,
@@ -394,33 +387,25 @@ def run_test(
)

else:
test_func = load_test(test_id)

input_kwargs, param_kwargs = _get_test_kwargs(
test_func, inputs or {}, params or {}
)

raw_result = test_func(**input_kwargs, **param_kwargs)

doc = getdoc(test_func)

result = build_test_result(
outputs=raw_result,
test_id=test_id,
inputs=input_kwargs,
params=param_kwargs,
doc=doc,
description=doc,
generate_description=generate_description,
title=title,
)
result = _run_test(test_id, inputs, params)

end_time = time.perf_counter()
result.metadata = _get_run_metadata(duration_seconds=end_time - start_time)

if post_process_fn:
result = post_process_fn(result)

if generate_description:
result.description = get_result_description(
test_id=test_id,
test_description=result.doc,
tables=result.tables,
figures=result.figures,
metric=result.metric,
should_generate=generate_description,
title=title,
)

if show:
result.show()

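
Note on the refactor: `get_result_description` is no longer called inside `build_test_result`; `run_test` now generates the description after `post_process_fn` has run, so the generated text reflects any post-processed tables, figures, and metrics. A minimal usage sketch of that flow is below — it assumes `vm_dataset` is an already-initialized VMDataset and uses an illustrative test ID, neither of which is part of this diff.

```python
# Sketch only: assumes `vm_dataset` was created earlier (e.g. via vm.init_dataset)
# and that the test ID below exists in the local test library.
from validmind.tests import run_test


def my_post_process(result):
    # Adjust result.tables / result.figures / result.metric here if needed.
    # With this change the hook runs *before* the description is generated,
    # so the generated text reflects whatever is modified here.
    return result


result = run_test(
    "validmind.data_validation.DatasetDescription",  # illustrative test ID
    inputs={"dataset": vm_dataset},
    post_process_fn=my_post_process,
    generate_description=True,  # description is built from result.doc plus the post-processed outputs
    show=False,
)
result.show()
```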