diff --git a/CHANGELOG.md b/CHANGELOG.md index 88955e6..1ad91b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Lazy module registration**: `D3Session.execute()` and `D3AsyncSession.execute()` now automatically register a `@d3function` module on first use, eliminating the need to declare all modules in `context_modules` upfront. - `registered_modules` tracking on session instances prevents duplicate registration calls. - **Jupyter notebook support**: `@d3function` now automatically replaces a previously registered function when the same name is re-registered in the same module, with a warning log. This enables iterative workflows in Jupyter notebooks where cells are re-executed. +- **Automatic import detection**: `@d3function` now automatically discovers file-level imports used by the decorated function and includes them in the registered module. In Jupyter notebooks, place imports inside the function body instead. + +### Removed +- `add_packages_in_current_file()`: Removed. Imports are now detected automatically by `@d3function`. +- `find_packages_in_current_file()`: Removed. Replaced by `find_imports_for_function()`. ### Changed - `d3_api_plugin` has been renamed to `d3_api_execute`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9f5e34d..fe4a450 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,14 +31,19 @@ Thank you for your interest in contributing to designer-plugin! This document pr ### Running Tests -Run the full test suite: +Run unit tests (default): ```bash uv run pytest ``` -Run tests with verbose output: +Run integration tests (requires a running d3 instance): ```bash -uv run pytest -v +uv run pytest -m integration +``` + +Run all tests: +```bash +uv run pytest -m "" ``` Run specific test file: diff --git a/README.md b/README.md index c1684ea..1d39b3c 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,14 @@ The Functional API offers two decorators: `@d3pythonscript` and `@d3function`: - Functions decorated with the same `module_name` are grouped together and can call each other, enabling function chaining and code reuse. - Registration happens automatically on the first call to `execute()` or `rpc()` that references the module — no need to declare modules upfront. You can also pre-register specific modules by passing them to the session context manager (e.g., `D3AsyncSession('localhost', 80, {"mymodule"})`). +> **Jupyter Notebook:** File-level imports (e.g., `import numpy as np` in a separate cell) cannot be automatically detected. In Jupyter, place any required imports inside the function body itself: +> ```python +> @d3function("mymodule") +> def my_fn(): +> import numpy as np +> return np.array([1, 2]) +> ``` + ### Session API Methods Both `D3AsyncSession` and `D3Session` provide two methods for executing functions: diff --git a/pyproject.toml b/pyproject.toml index 675f8e5..5c2eb1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,11 @@ python_classes = ["Test*"] python_functions = ["test_*"] addopts = [ "-v", + "-m", "not integration", "--strict-markers", "--strict-config", ] +markers = [ + "integration: tests that require a running d3 instance", +] + diff --git a/src/designer_plugin/d3sdk/__init__.py b/src/designer_plugin/d3sdk/__init__.py index 7f5989e..905c283 100644 --- a/src/designer_plugin/d3sdk/__init__.py +++ b/src/designer_plugin/d3sdk/__init__.py @@ -5,7 +5,7 @@ from .client import D3PluginClient from .function import ( - add_packages_in_current_file, + PackageInfo, d3function, d3pythonscript, get_all_d3functions, @@ -18,9 +18,9 @@ "D3AsyncSession", "D3PluginClient", "D3Session", + "PackageInfo", "d3pythonscript", "d3function", - "add_packages_in_current_file", "get_register_payload", "get_all_d3functions", "get_all_modules", diff --git a/src/designer_plugin/d3sdk/ast_utils.py b/src/designer_plugin/d3sdk/ast_utils.py index 696cad1..72485c2 100644 --- a/src/designer_plugin/d3sdk/ast_utils.py +++ b/src/designer_plugin/d3sdk/ast_utils.py @@ -4,11 +4,71 @@ """ import ast +import functools import inspect +import logging import textwrap import types +from collections.abc import Callable from typing import Any +from pydantic import BaseModel, Field + +from designer_plugin.d3sdk.builtin_modules import SUPPORTED_MODULES + +logger = logging.getLogger(__name__) + + +############################################################################### +# Package info models +class ImportAlias(BaseModel): + """Represents a single imported name with an optional alias. + + Mirrors the structure of ast.alias for Pydantic compatibility. + """ + + name: str = Field( + description="The imported name (e.g., 'Path' in 'from pathlib import Path')" + ) + asname: str | None = Field( + default=None, + description="The alias (e.g., 'np' in 'import numpy as np')", + ) + + +class PackageInfo(BaseModel): + """Structured representation of a Python import statement. + + Rendering rules (via to_import_statement using ast.unparse): + - package only → import package + - package + alias → import package as alias + - package + methods → from package import method1, method2 + - package + methods w/alias → from package import method1 as alias1 + """ + + package: str = Field(description="The module/package name to import") + alias: str | None = Field( + default=None, + description="Alias for the package (e.g., 'np' in 'import numpy as np')", + ) + methods: list[ImportAlias] = Field( + default_factory=list, + description="Imported names for 'from X import ...' style imports", + ) + + def to_import_statement(self) -> str: + """Render back to a Python import statement using ast.unparse.""" + node: ast.stmt + if self.methods: + node = ast.ImportFrom( + module=self.package, + names=[ast.alias(name=m.name, asname=m.asname) for m in self.methods], + level=0, + ) + else: + node = ast.Import(names=[ast.alias(name=self.package, asname=self.alias)]) + return ast.unparse(node) + ############################################################################### # Source code extraction utilities @@ -369,94 +429,144 @@ def validate_and_extract_args( ############################################################################### -# Python package finder utility -def find_packages_in_current_file(caller_stack: int = 1) -> list[str]: - """Find all import statements in the caller's file by inspecting the call stack. +# Function-scoped import extraction utility +def _collect_used_names(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + """Collect all identifier names used inside a function body. + + Walks the function's AST body and extracts: + - Simple names (ast.Name nodes, e.g., ``foo`` in ``foo()``) + - Root names of attribute chains (e.g., ``np`` in ``np.array()``) + + Args: + func_node: The function AST node to analyse. + + Returns: + Set of identifier strings used in the function body. + """ + names: set[str] = set() + for node in ast.walk(func_node): + if isinstance(node, ast.Name): + names.add(node.id) + elif isinstance(node, ast.Attribute): + # Walk down the attribute chain to find the root name + root: ast.expr = node + while isinstance(root, ast.Attribute): + root = root.value + if isinstance(root, ast.Name): + names.add(root.id) + return names + + +def _is_type_checking_block(node: ast.If) -> bool: + """Check if an if statement is ``if TYPE_CHECKING:``.""" + return isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING" + + +def _is_supported_module(module_name: str) -> bool: + """Check if a module (or its top-level parent) is Designer-supported.""" + top_level = module_name.split(".")[0] + return top_level in SUPPORTED_MODULES - This function walks up the call stack to find the module where it was called from, - then parses that module's source code to extract all import statements that are - compatible with Python 2.7 and safe to send to Designer. + +@functools.cache +def _get_module_ast(module: types.ModuleType) -> ast.Module | None: + """Return the parsed AST for *module*, cached by module identity.""" + try: + return ast.parse(inspect.getsource(module)) + except (OSError, TypeError): + return None + + +def find_imports_for_function(func: Callable[..., Any]) -> list[PackageInfo]: + """Extract import statements used by a function from its source file. + + Inspects the module containing *func*, parses all top-level imports, then + filters them down to only those whose imported names are actually referenced + inside the function body. Args: - caller_stack: Number of frames to go up the call stack. Default is 1 (immediate caller). - Use higher values to inspect files further up the call chain. + func: The callable to analyse. Returns: - Sorted list of unique import statement strings (e.g., "import ast", "from pathlib import Path"). + Sorted list of :class:`PackageInfo` objects representing the imports + used by *func*. Filters applied: - - Excludes imports inside `if TYPE_CHECKING:` blocks (type checking only) - - Excludes imports from the 'd3blobgen' package (client-side only) - - Excludes imports from the 'typing' module (not supported in Python 2.7) - - Excludes imports of this function itself to avoid circular references + - Excludes imports inside ``if TYPE_CHECKING:`` blocks + - Only includes imports from Designer-supported builtin modules + (see ``SUPPORTED_MODULES`` in ``builtin_modules.py``) + - Only includes imports whose names are actually used in the function body """ - # Get the this file frame - current_frame: types.FrameType | None = inspect.currentframe() - if not current_frame: + # --- 1. Get the function's module source --- + module = inspect.getmodule(func) + if not module: return [] - # Get the caller's frame (file where this function is called) - caller_frame: types.FrameType | None = current_frame - for _ in range(caller_stack): - if not caller_frame or not caller_frame.f_back: - return [] - caller_frame = caller_frame.f_back - - if not caller_frame: + module_tree = _get_module_ast(module) + if module_tree is None: + logger.warning( + "Cannot detect file-level imports for '%s': module source unavailable " + "(e.g. Jupyter notebook). Place imports inside the function body instead.", + func.__qualname__, + ) return [] - modules: types.ModuleType | None = inspect.getmodule(caller_frame) - if not modules: + # --- 2. Collect names used inside the function body --- + func_source = textwrap.dedent(inspect.getsource(func)) + func_tree = ast.parse(func_source) + if not func_tree.body: return [] - source: str = inspect.getsource(modules) - - # Parse the source code - tree = ast.parse(source) - - # Get the name of this function to filter it out - # For example, we don't want `from core import find_packages_in_current_file` - function_name: str = current_frame.f_code.co_name - # Skip any package from d3blobgen - d3blobgen_package_name: str = "d3blobgen" - # typing not supported in python2.7 - typing_package_name: str = "typing" + func_node = func_tree.body[0] + if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + return [] - def is_type_checking_block(node: ast.If) -> bool: - """Check if an if statement is 'if TYPE_CHECKING:'""" - return isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING" + used_names = _collect_used_names(func_node) - imports: list[str] = [] - for node in tree.body: - # Skip TYPE_CHECKING blocks entirely - if isinstance(node, ast.If) and is_type_checking_block(node): + # --- 3. Parse file-level imports and filter to used ones --- + packages: list[PackageInfo] = [] + for node in module_tree.body: + # Skip TYPE_CHECKING blocks + if isinstance(node, ast.If) and _is_type_checking_block(node): continue if isinstance(node, ast.Import): - imported_modules: list[str] = [alias.name for alias in node.names] - # Skip imports that include d3blobgen - if any(d3blobgen_package_name in module for module in imported_modules): - continue - if any(typing_package_name in module for module in imported_modules): - continue - import_text: str = f"import {', '.join(imported_modules)}" - imports.append(import_text) + for alias in node.names: + if not _is_supported_module(alias.name): + continue + + # The name used in code is the alias if present, otherwise the module name + effective_name = alias.asname if alias.asname else alias.name + if effective_name in used_names: + packages.append( + PackageInfo( + package=alias.name, + alias=alias.asname, + ) + ) elif isinstance(node, ast.ImportFrom): - imported_module: str | None = node.module - imported_names: list[str] = [alias.name for alias in node.names] - if not imported_module: - continue - # Skip imports that include d3blobgen - if d3blobgen_package_name in imported_module: + if not node.module: continue - elif typing_package_name in imported_module: - continue - # Skip imports that include this function itself - if function_name in imported_names: + if not _is_supported_module(node.module): continue - line_text = f"from {imported_module} import {', '.join(imported_names)}" - imports.append(line_text) + # Filter to only methods actually used by the function + matched_methods: list[ImportAlias] = [] + for alias in node.names: + effective_name = alias.asname if alias.asname else alias.name + if effective_name in used_names: + matched_methods.append( + ImportAlias(name=alias.name, asname=alias.asname) + ) + + if matched_methods: + packages.append( + PackageInfo( + package=node.module, + methods=matched_methods, + ) + ) - return sorted(set(imports)) + # Sort by import statement string for deterministic output + return sorted(packages, key=lambda p: p.to_import_statement()) diff --git a/src/designer_plugin/d3sdk/builtin_modules.py b/src/designer_plugin/d3sdk/builtin_modules.py new file mode 100644 index 0000000..8b933a8 --- /dev/null +++ b/src/designer_plugin/d3sdk/builtin_modules.py @@ -0,0 +1,220 @@ +SUPPORTED_MODULES: frozenset[str] = frozenset( + [ + "Bastion", + "ConfigParser", + "Cookie", + "HTMLParser", + "SocketServer", + "StringIO", + "UserDict", + "UserList", + "_winreg", + "abc", + "aifc", + "anydbm", + "array", + "ast", + "atexit", + "audioop", + "base64", + "binascii", + "bisect", + "bz2", + "cPickle", + "cStringIO", + "chunk", + "cmath", + "cmd", + "codecs", + "codeop", + "collections", + "copy", + "copy_reg", + "csv", + "ctypes", + "datetime", + "difflib", + "dircache", + "dis", + "dumbdbm", + "dummy_thread", + "errno", + "filecmp", + "fnmatch", + "functools", + "future_builtins", + "gc", + "getopt", + "hashlib", + "heapq", + "hmac", + "htmlentitydefs", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "itertools", + "json", + "keyword", + "linecache", + "locale", + "logging", + "mailcap", + "marshal", + "math", + "mmap", + "msvcrt", + "mutex", + "netrc", + "new", + "nntplib", + "numbers", + "operator", + "os", + "parser", + "pkgutil", + "plistlib", + "poplib", + "pprint", + "quopri", + "random", + "re", + "repr", + "rfc822", + "rlcompleter", + "sched", + "select", + "sets", + "sgmllib", + "sha", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "sndhdr", + "socket", + "sqlite3", + "stat", + "statvfs", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sysconfig", + "tempfile", + "textwrap", + "thread", + "time", + "token", + "tokenize", + "traceback", + "types", + "unicodedata", + "unittest", + "urlparse", + "uuid", + "warnings", + "weakref", + "webbrowser", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "zipfile", + "zipimport", + "zlib", + ] +) + +NOT_SUPPORTED_MODULES: frozenset[str] = frozenset( + [ + "BaseHTTPServer", + "CGIHTTPServer", + "DocXMLRPCServer", + "Queue", + "ScrolledText", + "SimpleHTTPServer", + "SimpleXMLRPCServer", + "Tix", + "Tkinter", + "UserString", + "argparse", + "asynchat", + "asyncore", + "bdb", + "binhex", + "bsddb", + "calendar", + "cgi", + "cgitb", + "code", + "colorsys", + "compileall", + "compiler", + "contextlib", + "cookielib", + "dbhash", + "dbm", + "decimal", + "distutils", + "doctest", + "dummy_threading", + "email", + "ensurepip", + "fileinput", + "formatter", + "fractions", + "ftplib", + "getpass", + "gettext", + "glob", + "gzip", + "htmllib", + "httplib", + "imaplib", + "mailbox", + "mhlib", + "mimetools", + "mimetypes", + "mimify", + "modulefinder", + "msilib", + "multiprocessing", + "optparse", + "pdb", + "pickle", + "pickletools", + "platform", + "popen2", + "profile", + "py_compile", + "pyclbr", + "pydoc", + "robotparser", + "runpy", + "smtpd", + "smtplib", + "ssl", + "sys", + "tabnanny", + "tarfile", + "telnetlib", + "test", + "threading", + "timeit", + "trace", + "ttk", + "turtle", + "urllib", + "urllib2", + "uu", + "wave", + "whichdb", + "xmlrpclib", + ] +) diff --git a/src/designer_plugin/d3sdk/function.py b/src/designer_plugin/d3sdk/function.py index c3d1c50..d64b008 100644 --- a/src/designer_plugin/d3sdk/function.py +++ b/src/designer_plugin/d3sdk/function.py @@ -15,8 +15,9 @@ from pydantic import BaseModel, Field from designer_plugin.d3sdk.ast_utils import ( + PackageInfo, convert_function_to_py27, - find_packages_in_current_file, + find_imports_for_function, validate_and_bind_signature, validate_and_extract_args, ) @@ -51,6 +52,9 @@ class FunctionInfo(BaseModel): args: list[str] = Field( default=[], description="list of arguments from extracted function" ) + packages: list[PackageInfo] = Field( + default=[], description="list of packages/imports used by the function" + ) def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: @@ -114,6 +118,8 @@ def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: for stmt in body_nodes_py27: body_py27 += ast.unparse(stmt) + "\n" + packages = find_imports_for_function(func) + return FunctionInfo( source_code=source_code_py3, source_code_py27=source_code_py27, @@ -121,6 +127,7 @@ def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: body=body.strip(), body_py27=body_py27.strip(), args=args, + packages=packages, ) @@ -256,6 +263,11 @@ def __init__(self, module_name: str, func: Callable[P, T]): super().__init__(func) + # Auto-register packages used by this function + D3Function._available_packages[module_name].update( + pkg.to_import_statement() for pkg in self._function_info.packages + ) + # Update the function in case the function was updated in the same session. # For example, jupyter notebook server can be running, but function signature can # change constantly. @@ -311,7 +323,7 @@ def get_module_register_payload(module_name: str) -> RegisterPayload | None: return None contents_packages: str = "\n".join( - list(D3Function._available_packages[module_name]) + sorted(D3Function._available_packages[module_name]) ) contents_functions: str = "\n\n".join( [ @@ -460,35 +472,6 @@ def decorator(func: Callable[P, T]) -> D3Function[P, T]: return decorator -def add_packages_in_current_file(module_name: str) -> None: - """Add all import statements from the caller's file to a d3function module's package list. - - This function scans the calling file's import statements and registers them with - the specified module name, making those imports available when the module is - registered with Designer. This is useful for ensuring all dependencies are included - when deploying Python functions to Designer. - - Args: - module_name: The name of the d3function module to associate the packages with. - Must match the module_name used in @d3function decorator. - - Example: - ```python - import numpy as np - - @d3function("my_module") - def my_function(): - return np.array([1, 2, 3]) - - # Register all imports in the file (numpy) - add_packages_in_current_file("my_module") - ``` - """ - # caller_stack is 2, 1 for this, 1 for caller of this function. - packages: list[str] = find_packages_in_current_file(2) - D3Function._available_packages[module_name].update(packages) - - def get_register_payload(module_name: str) -> RegisterPayload | None: """Get the registration payload for a specific module. diff --git a/tests/test_ast_utils.py b/tests/test_ast_utils.py index 599edf9..4ebe361 100644 --- a/tests/test_ast_utils.py +++ b/tests/test_ast_utils.py @@ -7,16 +7,19 @@ import inspect import textwrap import types +from os.path import join as path_join import pytest from designer_plugin.d3sdk.ast_utils import ( ConvertToPython27, + ImportAlias, + PackageInfo, convert_class_to_py27, convert_function_to_py27, filter_base_classes, filter_init_args, - find_packages_in_current_file, + find_imports_for_function, get_class_node, get_source, ) @@ -890,64 +893,6 @@ def __init__(self): assert param_names == [] -class TestFindPackagesInCurrentFile: - """Tests for find_packages_in_current_file function.""" - - def test_finds_imports_from_current_file(self): - """Test that the function finds import statements from the calling file.""" - # This test file has imports at the top - they should be found - imports = find_packages_in_current_file() - - # Should find at least some of our imports - assert isinstance(imports, list) - assert len(imports) > 0 - - # Should be sorted - assert imports == sorted(imports) - - # Check for specific imports we know exist in this file - assert "import ast" in imports - assert "import pytest" in imports - assert "import textwrap" in imports - - def test_excludes_typing_imports(self): - """Test that typing module imports are excluded.""" - # Since this file doesn't import typing, we can't directly test exclusion here - # But we can verify the function doesn't crash and returns valid results - imports = find_packages_in_current_file() - - # Verify no typing imports are present - typing_imports = [imp for imp in imports if "typing" in imp] - assert len(typing_imports) == 0 - - def test_excludes_d3blobgen_imports(self): - """Test that d3blobgen package imports are excluded.""" - imports = find_packages_in_current_file() - - # Verify no d3blobgen imports are present - d3blobgen_imports = [imp for imp in imports if "d3blobgen" in imp] - assert len(d3blobgen_imports) == 0 - - def test_excludes_find_packages_function_itself(self): - """Test that the function itself is excluded from imports.""" - imports = find_packages_in_current_file() - - # Should not include import of find_packages_in_current_file itself - # even though we import it at the top of this file - function_imports = [imp for imp in imports if "find_packages_in_current_file" in imp] - assert len(function_imports) == 0 - - def test_returns_unique_sorted_imports(self): - """Test that returned imports are unique and sorted.""" - imports = find_packages_in_current_file() - - # Check uniqueness - assert len(imports) == len(set(imports)) - - # Check sorting - assert imports == sorted(imports) - - class TestDecoratorHandling: """Tests for handling decorators in AST transformations.""" @@ -1127,5 +1072,142 @@ def my_function(x, y): assert len(func.body) == 3 # Two assignments and one return +class TestPackageInfo: + """Tests for PackageInfo and ImportAlias models.""" + + def test_import_package_only(self): + """import numpy""" + pkg = PackageInfo(package="numpy") + assert pkg.to_import_statement() == "import numpy" + + def test_import_package_with_alias(self): + """import numpy as np""" + pkg = PackageInfo(package="numpy", alias="np") + assert pkg.to_import_statement() == "import numpy as np" + + def test_from_import_single_method(self): + """from pathlib import Path""" + pkg = PackageInfo( + package="pathlib", + methods=[ImportAlias(name="Path")], + ) + assert pkg.to_import_statement() == "from pathlib import Path" + + def test_from_import_multiple_methods(self): + """from os.path import join, exists""" + pkg = PackageInfo( + package="os.path", + methods=[ + ImportAlias(name="join"), + ImportAlias(name="exists"), + ], + ) + assert pkg.to_import_statement() == "from os.path import join, exists" + + def test_from_import_method_with_alias(self): + """from collections import defaultdict as dd""" + pkg = PackageInfo( + package="collections", + methods=[ImportAlias(name="defaultdict", asname="dd")], + ) + assert pkg.to_import_statement() == "from collections import defaultdict as dd" + + def test_from_import_mixed_aliases(self): + """from collections import OrderedDict, defaultdict as dd""" + pkg = PackageInfo( + package="collections", + methods=[ + ImportAlias(name="OrderedDict"), + ImportAlias(name="defaultdict", asname="dd"), + ], + ) + result = pkg.to_import_statement() + assert result == "from collections import OrderedDict, defaultdict as dd" + + +class TestFindImportsForFunction: + """Tests for find_imports_for_function.""" + + def test_finds_used_import(self): + """Function using ast should get 'import ast' extracted.""" + # This function uses ast.parse which is from 'import ast' at file top + def uses_ast(): + return ast.parse("x = 1") + + packages = find_imports_for_function(uses_ast) + statements = [p.to_import_statement() for p in packages] + assert "import ast" in statements + + def test_excludes_unused_import(self): + """Function not using a module should not include it.""" + def uses_nothing(): + return 42 + + packages = find_imports_for_function(uses_nothing) + statements = [p.to_import_statement() for p in packages] + # Should not include ast, textwrap, etc. since they're not used + assert "import types" not in statements + + def test_finds_from_import(self): + """Function using a 'from X import Y' name should include it.""" + def uses_textwrap(): + return textwrap.dedent(" hello") + + packages = find_imports_for_function(uses_textwrap) + statements = [p.to_import_statement() for p in packages] + assert "import textwrap" in statements + + def test_returns_package_info_objects(self): + """Return type should be list of PackageInfo.""" + def simple_func(): + return ast.dump(ast.parse("1")) + + packages = find_imports_for_function(simple_func) + assert all(isinstance(p, PackageInfo) for p in packages) + + def test_sorted_output(self): + """Output should be sorted by import statement.""" + def uses_multiple(): + _ = textwrap.dedent("x") + _ = ast.parse("y") + return inspect.getsource(uses_multiple) + + packages = find_imports_for_function(uses_multiple) + statements = [p.to_import_statement() for p in packages] + assert statements == sorted(statements) + + def test_excludes_typing_imports(self): + """Typing imports should be excluded.""" + # The 'Any' import from typing at the file top should never appear + def uses_nothing(): + return 1 + + packages = find_imports_for_function(uses_nothing) + statements = [p.to_import_statement() for p in packages] + typing_imports = [s for s in statements if "typing" in s] + assert len(typing_imports) == 0 + + def test_finds_submodule_import(self): + """from os.path import join (sub-module) should be detected.""" + + def uses_path_join(): + return path_join("a", "b") + + packages = find_imports_for_function(uses_path_join) + statements = [p.to_import_statement() for p in packages] + assert "from os.path import join as path_join" in statements + + def test_no_source_module_returns_empty(self): + """Function whose module source is unavailable should return empty list.""" + # Simulate a function from an unsourceable module (like Jupyter __main__) + def dummy(): + return 1 + + # Patch __module__ to a non-existent module + dummy.__module__ = "_nonexistent_module_for_test" + packages = find_imports_for_function(dummy) + assert packages == [] + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_core.py b/tests/test_core.py index 2df0478..f064aa3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -465,3 +465,53 @@ def test_func(a: int, b: int) -> int: with pytest.raises(TypeError, match="multiple values for argument"): test_func.payload(1, a=2) + + +class TestAutoPackageRegistration: + """Test that @d3function auto-registers imports used by the function.""" + + def test_extract_function_info_populates_packages(self): + """extract_function_info should populate the packages field.""" + def func_using_logging(): + return logging.getLogger("test") + + info = extract_function_info(func_using_logging) + statements = [p.to_import_statement() for p in info.packages] + assert "import logging" in statements + + def test_extract_function_info_packages_default_empty_for_no_imports(self): + """Function using no imports should have empty packages.""" + def func_no_imports(): + return 42 + + info = extract_function_info(func_no_imports) + assert info.packages == [] + + def test_d3function_auto_registers_packages(self): + """D3Function should auto-register packages.""" + module = "test_auto_pkg_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def func_using_logging(): + return logging.getLogger("test") + + # Packages should be auto-registered + assert "import logging" in D3Function._available_packages[module] + + def test_d3function_register_payload_includes_auto_packages(self): + """get_register_payload should include auto-extracted imports.""" + module = "test_auto_payload_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def func_using_logging(): + return logging.getLogger("test") + + payload = get_register_payload(module) + assert payload is not None + assert "import logging" in payload.contents + + diff --git a/tests/test_supported_modules.py b/tests/test_supported_modules.py new file mode 100644 index 0000000..aa73dab --- /dev/null +++ b/tests/test_supported_modules.py @@ -0,0 +1,56 @@ +import asyncio + +import pytest + +from designer_plugin.d3sdk import D3AsyncSession, d3function +from designer_plugin.d3sdk.builtin_modules import NOT_SUPPORTED_MODULES, SUPPORTED_MODULES + + +@d3function('test_supported_modules') +def check_import(module_str) -> bool: + try: + module = __import__(module_str) + return True + except ImportError as e: + return False + + +class TestSupportedModules: + """ + Test if supported and not supported modules are handled properly on Designer side. + This is integration test so Designer must be running to pass the test. + """ + + @pytest.mark.integration + def test_supported_modules(self): + """Test if all supported modules are able to be imported on Designer side.""" + + async def run(): + failed = [] + async with D3AsyncSession("localhost", 80) as session: + for module_str in SUPPORTED_MODULES: + import_success: bool = await session.rpc( + check_import.payload(module_str) + ) + if not import_success: + failed.append(module_str) + assert not failed, f"Failed to import: {failed}" + + asyncio.run(run()) + + @pytest.mark.integration + def test_not_supported_modules(self): + """Test if all not supported modules are not importable on Designer side.""" + + async def run(): + failed = [] + async with D3AsyncSession("localhost", 80) as session: + for module_str in NOT_SUPPORTED_MODULES: + import_success: bool = await session.rpc( + check_import.payload(module_str) + ) + if import_success: + failed.append(module_str) + assert not failed, f"Unexpectedly imported: {failed}" + + asyncio.run(run())