Skip to content

Commit 3821c0b

Browse files
authored
[ENH] docstring test suite for functions (#1955)
Adds a test suite that runs all function docstrings in the package, as pytest test.
1 parent 53f8cbd commit 3821c0b

File tree

2 files changed

+112
-1
lines changed

2 files changed

+112
-1
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# copyright: pytorch-forecasting developers, BSD-3-Clause License (see LICENSE file)
2+
# copy of the sktime utility of the same name (BSD-3)
3+
"""Doctest checks directed through pytest with conditional skipping."""
4+
5+
from functools import lru_cache
6+
import importlib
7+
import inspect
8+
import pkgutil
9+
10+
EXCLUDE_MODULES_STARTING_WITH = ("all", "test")
11+
12+
13+
def _all_functions(module_name):
14+
"""Get all functions from a module, including submodules.
15+
16+
Excludes:
17+
18+
* modules starting with 'all' or 'test'.
19+
* if the flag ``ONLY_CHANGED_MODULES`` is set, modules that have not changed,
20+
compared to the ``main`` branch.
21+
22+
Parameters
23+
----------
24+
module_name : str
25+
Name of the module.
26+
27+
Returns
28+
-------
29+
functions_list : list
30+
List of tuples (function_name, function_object).
31+
"""
32+
res = _all_functions_cached(module_name)
33+
# copy the result to avoid modifying the cached result
34+
return res.copy()
35+
36+
37+
@lru_cache
38+
def _all_functions_cached(module_name, only_changed_modules=False):
39+
"""Get all functions from a module, including submodules.
40+
41+
Excludes:
42+
43+
* modules starting with 'all' or 'test'.
44+
* if ``only_changed_modules`` is ``True``, modules that have not changed,
45+
compared to the ``main`` branch.
46+
47+
Parameters
48+
----------
49+
module_name : str
50+
Name of the module.
51+
only_changed_modules : bool, optional (default=False)
52+
If True, only functions from modules that have changed are returned.
53+
54+
Returns
55+
-------
56+
functions_list : list
57+
List of tuples (function_name, function_object).
58+
"""
59+
# Import the package
60+
package = importlib.import_module(module_name)
61+
62+
# Initialize an empty list to hold all functions
63+
functions_list = []
64+
65+
# Walk through the package's modules
66+
package_path = package.__path__[0]
67+
for _, modname, _ in pkgutil.walk_packages(
68+
path=[package_path], prefix=package.__name__ + "."
69+
):
70+
# Skip modules starting with 'all' or 'test'
71+
if modname.split(".")[-1].startswith(EXCLUDE_MODULES_STARTING_WITH):
72+
continue
73+
74+
# Import the module
75+
module = importlib.import_module(modname)
76+
77+
# Get all functions from the module
78+
for name, obj in inspect.getmembers(module, inspect.isfunction):
79+
# if function is imported from another module, skip it
80+
if obj.__module__ != module.__name__:
81+
continue
82+
# add the function to the list
83+
functions_list.append((name, obj))
84+
85+
return functions_list
86+
87+
88+
def pytest_generate_tests(metafunc):
89+
"""Test parameterization routine for pytest.
90+
91+
Fixtures parameterized
92+
----------------------
93+
func : all functions from sktime, as returned by _all_functions
94+
if ONLY_CHANGED_MODULES is set, only functions from modules that have changed
95+
"""
96+
# we assume all four arguments are present in the test below
97+
funcs_and_names = _all_functions("pytorch_forecasting")
98+
99+
if len(funcs_and_names) > 0:
100+
names, funcs = zip(*funcs_and_names)
101+
102+
metafunc.parametrize("func", funcs, ids=names)
103+
else:
104+
metafunc.parametrize("func", [])
105+
106+
107+
def test_all_functions_doctest(func):
108+
"""Run doctest for all functions in pytorch-forecasting."""
109+
from skbase.utils.doctest_run import run_doctest
110+
111+
run_doctest(func, name=f"function {func.__name__}")

pytorch_forecasting/utils/_dependencies/_safe_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _safe_import(import_path, pkg_name=None):
5959
6060
Examples
6161
--------
62-
>>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import
62+
>>> from pytorch_forecasting.utils._dependencies._safe_import import _safe_import
6363
6464
>>> # Import a top-level module
6565
>>> torch = _safe_import("torch")

0 commit comments

Comments
 (0)