Skip to content
Merged
5 changes: 3 additions & 2 deletions scripts/run_e2e_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"notebooks/how_to/run_unit_metrics.ipynb",
"notebooks/code_samples/custom_tests/integrate_external_test_providers.ipynb",
"notebooks/code_samples/custom_tests/implement_custom_tests.ipynb",
"notebooks/how_to/explore_tests.ipynb",
]

DATA_TEMPLATE_NOTEBOOKS = [
Expand All @@ -66,12 +67,12 @@
{
# [Demo] Hugging Face - Text Sentiment Analysis
"path": "notebooks/code_samples/nlp_and_llm/hugging_face_summarization_demo.ipynb",
"model": "cm4lr52qo00bc0jpbm0vmxxhy"
"model": "cm4lr52qo00bc0jpbm0vmxxhy",
},
{
# [Demo] Customer Churn Model
"path": "notebooks/code_samples/quickstart_customer_churn_full_suite.ipynb",
"model": "cm4lr52lw00a60jpbhmzh8cah"
"model": "cm4lr52lw00a60jpbhmzh8cah",
},
{
# [Demo] Credit Risk Model
Expand Down
106 changes: 99 additions & 7 deletions tests/test_validmind_tests_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,68 @@

import unittest
from unittest import TestCase
from typing import Callable
from typing import Callable, List

from validmind.tests import list_tests, load_test, describe_test, register_test_provider
import pandas as pd

from validmind.tests import (
list_tags,
list_tasks,
list_tasks_and_tags,
list_tests,
load_test,
describe_test,
register_test_provider,
test,
tags,
tasks,
)


class TestTestsModule(TestCase):
def test_list_tags(self):
tags = list_tags()
self.assertIsInstance(tags, list)
self.assertTrue(len(tags) > 0)
self.assertTrue(all(isinstance(tag, str) for tag in tags))

def test_list_tasks(self):
tasks = list_tasks()
self.assertIsInstance(tasks, list)
self.assertTrue(len(tasks) > 0)
self.assertTrue(all(isinstance(task, str) for task in tasks))

def test_list_tasks_and_tags(self):
tasks_and_tags = list_tasks_and_tags()
self.assertIsInstance(tasks_and_tags, pd.io.formats.style.Styler)
df = tasks_and_tags.data
self.assertTrue(len(df) > 0)
self.assertTrue(all(isinstance(task, str) for task in df["Task"]))
self.assertTrue(all(isinstance(tag, str) for tag in df["Tags"]))

def test_list_tests(self):
tests = list_tests(pretty=False)
self.assertIsInstance(tests, list)
self.assertTrue(len(tests) > 0)
self.assertTrue(all(isinstance(test, str) for test in tests))

def test_list_tests_pretty(self):
tests = list_tests(pretty=True)
self.assertIsInstance(tests, pd.io.formats.style.Styler)
df = tests.data
self.assertTrue(len(df) > 0)
# check has the columns: ID, Name, Description, Required Inputs, Params
self.assertTrue("ID" in df.columns)
self.assertTrue("Name" in df.columns)
self.assertTrue("Description" in df.columns)
self.assertTrue("Required Inputs" in df.columns)
self.assertTrue("Params" in df.columns)
# check types of columns
self.assertTrue(all(isinstance(test, str) for test in df["ID"]))
self.assertTrue(all(isinstance(test, str) for test in df["Name"]))
self.assertTrue(all(isinstance(test, str) for test in df["Description"]))
self.assertTrue(all(isinstance(test, list) for test in df["Required Inputs"]))
self.assertTrue(all(isinstance(test, dict) for test in df["Params"]))

def test_list_tests_filter(self):
tests = list_tests(filter="sklearn", pretty=False)
Expand All @@ -23,6 +76,15 @@ def test_list_tests_filter_2(self):
filter="validmind.model_validation.ModelMetadata", pretty=False
)
self.assertTrue(len(tests) == 1)
self.assertTrue(tests[0].startswith("validmind.model_validation.ModelMetadata"))

def test_list_tests_tasks(self):
task = list_tasks()[0]
tests = list_tests(task=task, pretty=False)
self.assertTrue(len(tests) > 0)
for test in tests:
_test = load_test(test)
self.assertTrue(task in _test.__tasks__)

def test_load_test(self):
test = load_test("validmind.model_validation.ModelMetadata")
Expand All @@ -33,7 +95,6 @@ def test_load_test(self):
self.assertTrue(test.params is not None)

def test_describe_test(self):
describe_test("validmind.model_validation.ModelMetadata")
description = describe_test(
"validmind.model_validation.ModelMetadata", raw=True
)
Expand All @@ -46,17 +107,48 @@ def test_describe_test(self):
self.assertTrue("Params" in description)

def test_test_provider_registration(self):
def fake_test():
return None

class TestProvider:
def list_tests(self):
return ["fake.fake_test_id"]
return ["fake_test_id"]

def load_test(self, _):
return lambda: None
return fake_test

register_test_provider("fake", TestProvider())

test = load_test(test_id="fake.fake_test_id")
self.assertEqual(test.test_id, "fake.fake_test_id")
# check that the test provider's test shows up in the list of tests
tests = list_tests(pretty=False)
self.assertTrue("fake.fake_test_id" in tests)

# check that the test provider's test can be loaded
_test = load_test("fake.fake_test_id")
self.assertEqual(_test, fake_test)

# check that the test provider's test can be described
description = describe_test("fake.fake_test_id", raw=True)
self.assertIsInstance(description, dict)
self.assertTrue("ID" in description)
self.assertTrue("Name" in description)
self.assertTrue("Description" in description)
self.assertTrue("Required Inputs" in description)
self.assertTrue("Params" in description)

def test_test_decorators(self):
@tags("fake_tag_1", "fake_tag_2")
@tasks("fake_task_1", "fake_task_2")
@test("fake.fake_test_id_2")
def fake_test_2():
return None

self.assertTrue(fake_test_2.test_id == "fake.fake_test_id_2")
self.assertEqual(fake_test_2.__tags__, ["fake_tag_1", "fake_tag_2"])
self.assertEqual(fake_test_2.__tasks__, ["fake_task_1", "fake_task_2"])

_test = load_test("fake.fake_test_id_2")
self.assertEqual(_test, fake_test_2)


if __name__ == "__main__":
Expand Down
34 changes: 18 additions & 16 deletions validmind/tests/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def _get_save_func(func, test_id):
test library.
"""

# get og source before its wrapped by the test decorator
source = inspect.getsource(func)
# remove decorator line
source = source.split("\n", 1)[1]

def save(root_folder=".", imports=None):
parts = test_id.split(".")

Expand All @@ -41,35 +46,32 @@ def save(root_folder=".", imports=None):

full_path = os.path.join(path, f"{test_name}.py")

source = inspect.getsource(func)
# remove decorator line
source = source.split("\n", 1)[1]
_source = source.replace(f"def {func.__name__}", f"def {test_name}")

if imports:
imports = "\n".join(imports)
source = f"{imports}\n\n\n{source}"
_source = f"{imports}\n\n\n{_source}"

# add comment to the top of the file
source = f"""
_source = f"""
# Saved from {func.__module__}.{func.__name__}
# Original Test ID: {test_id}
# New Test ID: {new_test_id}

{source}
{_source}
"""

# ensure that the function name matches the test name
source = source.replace(f"def {func.__name__}", f"def {test_name}")

# use black to format the code
try:
import black

source = black.format_str(source, mode=black.FileMode())
_source = black.format_str(_source, mode=black.FileMode())
except ImportError:
# ignore if not available
pass

with open(full_path, "w") as file:
file.writelines(source)
file.writelines(_source)

logger.info(
f"Saved to {os.path.abspath(full_path)}!"
Expand Down Expand Up @@ -119,12 +121,12 @@ def decorator(func):
test_func = load_test(test_id, func, reload=True)
test_store.register_test(test_id, test_func)

@wraps(test_func)
def wrapper(*args, **kwargs):
return test_func(*args, **kwargs)

# special function to allow the function to be saved to a file
wrapper.save = _get_save_func(test_func, test_id)
save_func = _get_save_func(func, test_id)

wrapper = wraps(func)(test_func)
wrapper.test_id = test_id
wrapper.save = save_func

return wrapper

Expand Down
32 changes: 9 additions & 23 deletions validmind/tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ def run_test( # noqa: C901
input_grid=input_grid,
params=params,
param_grid=param_grid,
generate_description=generate_description,
)

elif unit_metrics:
Expand All @@ -371,18 +370,6 @@ def run_test( # noqa: C901
input_grid=input_grid,
params=params,
param_grid=param_grid,
generate_description=generate_description,
title=title,
)

elif input_grid or param_grid:
result = _run_comparison_test(
test_id=test_id,
inputs=inputs,
input_grid=input_grid,
params=params,
param_grid=param_grid,
generate_description=generate_description,
title=title,
)

Expand All @@ -395,16 +382,15 @@ def run_test( # noqa: C901
if post_process_fn:
result = post_process_fn(result)

if generate_description:
result.description = get_result_description(
test_id=test_id,
test_description=result.doc,
tables=result.tables,
figures=result.figures,
metric=result.metric,
should_generate=generate_description,
title=title,
)
result.description = get_result_description(
test_id=test_id,
test_description=result.doc,
tables=result.tables,
figures=result.figures,
metric=result.metric,
should_generate=generate_description,
title=title,
)

if show:
result.show()
Expand Down
Loading