From 97d0c982bac00a12b9d07569d90a7e8773977ca6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 30 Nov 2025 10:57:54 -0800 Subject: [PATCH] [Stubgen] Introduce `--init-path` for package-level generation This PR introduces an additional flag `--init-path` to the tool `tvm-ffi-stubgen`. When this flag is specified, the tool will actively look for global functions and objects that are not yet exported into Python, and generate their type stubs into the path specified by `--init-path`. To give an example, if a C++ TVM-FFI extension defines two items: - A global function: `my_ffi_extension.add_one` - An object: `my_ffi_extension.IntPair` which are not yet exported to Python, the tool will generate the following two files if `--init-path=/path/` is set: ``` /path/__init__.py /path/_ffi_api.py ``` Where `__init__.py` contains: ```python """Package my_ffi_extension.""" /# tvm-ffi-stubgen(begin): export/_ffi_api /# fmt: off /# isort: off from ._ffi_api import * # noqa: F403 from ._ffi_api import __all__ as _ffi_api__all__ if "__all__" not in globals(): __all__ = [] __all__.extend(_ffi_api__all__) /# isort: on /# fmt: on /# tvm-ffi-stubgen(end) ``` and `_ffi_api.py` contains: ```python """FFI API bindings for my_ffi_extension.""" /# tvm-ffi-stubgen(begin): import /# fmt: off /# isort: off from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: from tvm_ffi import Object /# isort: on /# fmt: on /# tvm-ffi-stubgen(end) /# tvm-ffi-stubgen(begin): global/my_ffi_extension /# fmt: off /# isort: off from tvm_ffi import init_ffi_api as _INIT _INIT("my_ffi_extension", __name__) /# isort: on if TYPE_CHECKING: def raise_error(_0: str, /) -> None: ... 
/# fmt: on /# tvm-ffi-stubgen(end) __all__ = [ # tvm-ffi-stubgen(begin): __all__ "IntPair", "raise_error", # tvm-ffi-stubgen(end) ] /# isort: off import tvm_ffi /# isort: on @tvm_ffi.register_object("my_ffi_extension.IntPair") class IntPair(tvm_ffi.Object): """FFI binding for `my_ffi_extension.IntPair`.""" /# tvm-ffi-stubgen(begin): object/my_ffi_extension.IntPair /# fmt: off a: int b: int if TYPE_CHECKING: @staticmethod def __c_ffi_init__(_0: int, _1: int, /) -> Object: ... @staticmethod def static_get_second(_0: IntPair, /) -> int: ... def get_first(self, /) -> int: ... /# fmt: on /# tvm-ffi-stubgen(end) ``` --- python/tvm_ffi/_ffi_api.py | 10 +- python/tvm_ffi/_tensor.py | 5 + python/tvm_ffi/container.py | 10 ++ python/tvm_ffi/stub/analysis.py | 22 ++- python/tvm_ffi/stub/cli.py | 222 +++++++++++++++++------- python/tvm_ffi/stub/codegen.py | 112 ++++++++++-- python/tvm_ffi/stub/consts.py | 12 ++ python/tvm_ffi/stub/file_utils.py | 33 ++-- python/tvm_ffi/stub/utils.py | 2 + python/tvm_ffi/testing/__init__.py | 31 ++++ python/tvm_ffi/testing/_ffi_api.py | 133 ++++++++++++++ python/tvm_ffi/{ => testing}/testing.py | 8 +- tests/python/test_stubgen.py | 11 +- 13 files changed, 517 insertions(+), 94 deletions(-) create mode 100644 python/tvm_ffi/testing/__init__.py create mode 100644 python/tvm_ffi/testing/_ffi_api.py rename python/tvm_ffi/{ => testing}/testing.py (97%) diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py index ab6704a5..c1afffc4 100644 --- a/python/tvm_ffi/_ffi_api.py +++ b/python/tvm_ffi/_ffi_api.py @@ -29,10 +29,12 @@ # fmt: on # tvm-ffi-stubgen(end) -from . import registry - -# tvm-ffi-stubgen(begin): global/ffi +# tvm-ffi-stubgen(begin): global/ffi@.registry # fmt: off +# isort: off +from .registry import init_ffi_api as _INIT +_INIT("ffi", __name__) +# isort: on if TYPE_CHECKING: def Array(*args: Any) -> Any: ... def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ... 
@@ -72,8 +74,6 @@ def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ... # fmt: on # tvm-ffi-stubgen(end) -registry.init_ffi_api("ffi", __name__) - __all__ = [ # tvm-ffi-stubgen(begin): __all__ diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py index 5005bc65..d4eb22d5 100644 --- a/python/tvm_ffi/_tensor.py +++ b/python/tvm_ffi/_tensor.py @@ -57,6 +57,11 @@ class Shape(tuple, PyNativeObject): _tvm_ffi_cached_object: Any + # tvm-ffi-stubgen(begin): object/ffi.Shape + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __new__(cls, content: tuple[int, ...]) -> Shape: if any(not isinstance(x, Integral) for x in content): raise ValueError("Shape must be a tuple of integers") diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 06fb92e7..4bb77206 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -148,6 +148,11 @@ class Array(core.Object, Sequence[T]): """ + # tvm-ffi-stubgen(begin): object/ffi.Array + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, input_list: Iterable[T]) -> None: """Construct an Array from a Python sequence.""" self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) @@ -290,6 +295,11 @@ class Map(core.Object, Mapping[K, V]): """ + # tvm-ffi-stubgen(begin): object/ffi.Map + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, input_dict: Mapping[K, V]) -> None: """Construct a Map from a Python mapping.""" list_kvs: list[Any] = [] diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py index 03dbe364..69b5d88f 100644 --- a/python/tvm_ffi/stub/analysis.py +++ b/python/tvm_ffi/stub/analysis.py @@ -18,6 +18,7 @@ from __future__ import annotations +from tvm_ffi._ffi_api import GetRegisteredTypeKeys from tvm_ffi.registry import list_global_func_names from . 
import consts as C @@ -34,8 +35,27 @@ def collect_global_funcs() -> dict[str, list[FuncInfo]]: except ValueError: print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}") else: - global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name)) + try: + global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name)) + except Exception: + print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}") # Ensure stable ordering for deterministic output. for k in list(global_funcs.keys()): global_funcs[k].sort(key=lambda x: x.schema.name) return global_funcs + + +def collect_type_keys() -> dict[str, list[str]]: + """Collect registered object type keys from TVM FFI's global registry.""" + global_objects: dict[str, list[str]] = {} + for type_key in GetRegisteredTypeKeys(): + try: + prefix, _ = type_key.rsplit(".", 1) + except ValueError: + pass + else: + global_objects.setdefault(prefix, []).append(type_key) + # Ensure stable ordering for deterministic output. + for k in list(global_objects.keys()): + global_objects[k].sort() + return global_objects diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py index 85697f52..2504dce1 100644 --- a/python/tvm_ffi/stub/cli.py +++ b/python/tvm_ffi/stub/cli.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# tvm-ffi-stubgen(skip-file) """TVM-FFI Stub Generator (``tvm-ffi-stubgen``).""" from __future__ import annotations import argparse import ctypes +import importlib import sys import traceback from pathlib import Path @@ -28,9 +28,9 @@ from . import codegen as G from . 
import consts as C -from .analysis import collect_global_funcs +from .analysis import collect_global_funcs, collect_type_keys from .file_utils import FileInfo, collect_files -from .utils import Options +from .utils import FuncInfo, Options def _fn_ty_map(ty_map: dict[str, str], ty_used: set[str]) -> Callable[[str], str]: @@ -55,72 +55,44 @@ def __main__() -> int: overview and examples of the block syntax. """ opt = _parse_args() + for imp in opt.imports or []: + importlib.import_module(imp) + if opt.init_path: + opt.files.append(opt.init_path) dlls = [ctypes.CDLL(lib) for lib in opt.dlls] files: list[FileInfo] = collect_files([Path(f) for f in opt.files]) + global_funcs: dict[str, list[FuncInfo]] = collect_global_funcs() - # Stage 1: Process `tvm-ffi-stubgen(ty-map)` + # Stage 1: Collect information + # - type maps: `tvm-ffi-stubgen(ty-map)` + # - defined global functions: `tvm-ffi-stubgen(begin): global/...` + # - defined object types: `tvm-ffi-stubgen(begin): object/...` ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy() - - def _stage_1(file: FileInfo) -> None: - for code in file.code_blocks: - if code.kind == "ty-map": - try: - lhs, rhs = code.param.split("->") - except ValueError as e: - raise ValueError( - f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`" - ) from e - ty_map[lhs.strip()] = rhs.strip() - for file in files: try: - _stage_1(file) + _stage_1(file, ty_map) except Exception: print( f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}' ) - # Stage 2: Process + # Stage 2. Generate stubs if they are not defined on the file. 
+ if opt.init_path: + _stage_2( + files, + init_path=Path(opt.init_path).resolve(), + global_funcs=global_funcs, + ) + + # Stage 3: Process # - `tvm-ffi-stubgen(begin): global/...` # - `tvm-ffi-stubgen(begin): object/...` - global_funcs = collect_global_funcs() - - def _stage_2(file: FileInfo) -> None: - all_defined = set() + for file in files: if opt.verbose: print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}") - ty_used: set[str] = set() - ty_on_file: set[str] = set() - fn_ty_map_fn = _fn_ty_map(ty_map, ty_used) - # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...` - for code in file.code_blocks: - if code.kind == "global": - funcs = global_funcs.get(code.param, []) - for func in funcs: - all_defined.add(func.schema.name) - G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt) - # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...` - for code in file.code_blocks: - if code.kind == "object": - type_key = code.param - ty_on_file.add(ty_map.get(type_key, type_key)) - G.generate_object(code, fn_ty_map_fn, opt) - # Stage 2.3. Add imports for used types. - for code in file.code_blocks: - if code.kind == "import": - G.generate_imports(code, ty_used - ty_on_file, opt) - break # Only one import block per file is supported for now. - # Stage 2.4. Add `__all__` for defined classes and functions. - for code in file.code_blocks: - if code.kind == "__all__": - G.generate_all(code, all_defined | ty_on_file, opt) - break # Only one __all__ block per file is supported for now. 
- file.update(show_diff=opt.verbose, dry_run=opt.dry_run) - - for file in files: try: - _stage_2(file) - except: + _stage_3(file, opt, ty_map, global_funcs) + except Exception: print( f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}' ) @@ -128,6 +100,122 @@ def _stage_2(file: FileInfo) -> None: return 0 +def _stage_1( + file: FileInfo, + ty_map: dict[str, str], +) -> None: + for code in file.code_blocks: + if code.kind == "ty-map": + try: + assert isinstance(code.param, str) + lhs, rhs = code.param.split("->") + except ValueError as e: + raise ValueError( + f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`" + ) from e + ty_map[lhs.strip()] = rhs.strip() + + +def _stage_2( + files: list[FileInfo], + init_path: Path, + global_funcs: dict[str, list[FuncInfo]], +) -> None: + def _find_or_insert_file(path: Path) -> FileInfo: + ret: FileInfo | None + if not path.exists(): + ret = FileInfo(path=path, lines=(), code_blocks=[]) + else: + for file in files: + if path.samefile(file.path): + return file + ret = FileInfo.from_file(file=path) + assert ret is not None, f"Failed to read file: {path}" + files.append(ret) + return ret + + # Step 0. Find out functions and classes already defined on files. + defined_func_prefixes: set[str] = { # type: ignore[union-attr] + code.param[0] for file in files for code in file.code_blocks if code.kind == "global" + } + defined_objs: set[str] = { # type: ignore[assignment] + code.param for file in files for code in file.code_blocks if code.kind == "object" + } | C.BUILTIN_TYPE_KEYS + + # Step 1. Generate missing `_ffi_api.py` and `__init__.py` under each prefix. 
+ prefixes: dict[str, list[str]] = collect_type_keys() + for prefix in global_funcs: + prefixes.setdefault(prefix, []) + + for prefix, obj_names in prefixes.items(): + if prefix.startswith("testing") or prefix.startswith("ffi"): + continue + funcs = sorted( + [] if prefix in defined_func_prefixes else global_funcs.get(prefix, []), + key=lambda f: f.schema.name, + ) + objs = sorted(set(obj_names) - defined_objs) + if not funcs and not objs: + continue + # Step 1.1. Create target directory if not exists + directory = init_path / prefix.replace(".", "/") + directory.mkdir(parents=True, exist_ok=True) + # Step 1.2. Generate `_ffi_api.py` + target_path = directory / "_ffi_api.py" + target_file = _find_or_insert_file(target_path) + with target_path.open("a", encoding="utf-8") as f: + f.write(G.generate_ffi_api(target_file.code_blocks, prefix, objs)) + target_file.reload() + # Step 1.3. Generate `__init__.py` + target_path = directory / "__init__.py" + target_file = _find_or_insert_file(target_path) + with target_path.open("a", encoding="utf-8") as f: + f.write(G.generate_init(target_file.code_blocks, prefix, submodule="_ffi_api")) + target_file.reload() + + +def _stage_3( + file: FileInfo, + opt: Options, + ty_map: dict[str, str], + global_funcs: dict[str, list[FuncInfo]], +) -> None: + all_defined = set() + ty_used: set[str] = set() + ty_on_file: set[str] = set() + fn_ty_map_fn = _fn_ty_map(ty_map, ty_used) + # Stage 3.1. Process `tvm-ffi-stubgen(begin): global/...` + for code in file.code_blocks: + if code.kind == "global": + funcs = global_funcs.get(code.param[0], []) + for func in funcs: + all_defined.add(func.schema.name) + G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt) + # Stage 3.2. Process `tvm-ffi-stubgen(begin): object/...` + for code in file.code_blocks: + if code.kind == "object": + type_key = code.param + assert isinstance(type_key, str) + ty_on_file.add(ty_map.get(type_key, type_key)) + G.generate_object(code, fn_ty_map_fn, opt) + # Stage 3.3. 
Add imports for used types. + for code in file.code_blocks: + if code.kind == "import": + G.generate_imports(code, ty_used - ty_on_file, opt) + break # Only one import block per file is supported for now. + # Stage 3.4. Add `__all__` for defined classes and functions. + for code in file.code_blocks: + if code.kind == "__all__": + G.generate_all(code, all_defined | ty_on_file, opt) + break # Only one __all__ block per file is supported for now. + # Stage 3.5. Process `tvm-ffi-stubgen(begin): export/...` + for code in file.code_blocks: + if code.kind == "export": + G.generate_export(code) + # Finalize: write back to file + file.update(verbose=opt.verbose, dry_run=opt.dry_run) + + def _parse_args() -> Options: class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): pass @@ -149,16 +237,16 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp " # Preload TVM runtime / extension libraries\n" " tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n\n" "Stub block syntax (placed in your source):\n" - " # tvm-ffi-stubgen(begin): global/\n" + f" {C.STUB_BEGIN} global/\n" " ... generated function stubs ...\n" - " # tvm-ffi-stubgen(end)\n\n" - " # tvm-ffi-stubgen(begin): object/\n" - " # tvm-ffi-stubgen(ty_map): list -> Sequence\n" - " # tvm-ffi-stubgen(ty_map): dict -> Mapping\n" + f" {C.STUB_END}\n\n" + f" {C.STUB_BEGIN} object/\n" + f" {C.STUB_TY_MAP}: list -> Sequence\n" + f" {C.STUB_TY_MAP}: dict -> Mapping\n" " ... 
generated fields and methods ...\n" - " # tvm-ffi-stubgen(end)\n\n" + f" {C.STUB_END}\n\n" " # Skip a file entirely\n" - " # tvm-ffi-stubgen(skip-file)\n\n" + f" {C.STUB_SKIP_FILE}\n\n" "Tips:\n" " - Only .py/.pyi files are updated; directories are scanned recursively.\n" " - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n" @@ -167,6 +255,12 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp " is provided by native extensions.\n" ), ) + parser.add_argument( + "--imports", + nargs="*", + metavar="IMPORTS", + help=("Additional imports to load before generation."), + ) parser.add_argument( "--dlls", nargs="*", @@ -179,13 +273,19 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp ), default=[], ) + parser.add_argument( + "--init-path", + type=str, + default="", + help="If specified, generate stubs under the given package prefix.", + ) parser.add_argument( "--indent", type=int, default=4, help=( "Extra spaces added inside each generated block, relative to the " - "indentation of the corresponding '# tvm-ffi-stubgen(begin):' line." + f"indentation of the corresponding '{C.STUB_BEGIN}' line." 
), ) parser.add_argument( diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py index c15624a9..08a71933 100644 --- a/python/tvm_ffi/stub/codegen.py +++ b/python/tvm_ffi/stub/codegen.py @@ -26,22 +26,27 @@ def generate_global_funcs( - code: CodeBlock, global_funcs: list[FuncInfo], fn_ty_map: Callable[[str], str], opt: Options + code: CodeBlock, + global_funcs: list[FuncInfo], + fn_ty_map: Callable[[str], str], + opt: Options, ) -> None: """Generate function signatures for global functions.""" assert len(code.lines) >= 2 if not global_funcs: return + assert isinstance(code.param, tuple) + prefix, import_from = code.param + if not import_from: + import_from = "tvm_ffi" results: list[str] = [ "# fmt: off", + "# isort: off", + f"from {import_from} import init_ffi_api as _INIT", + f'_INIT("{prefix}", __name__)', + "# isort: on", "if TYPE_CHECKING:", - *[ - func.gen( - fn_ty_map, - indent=opt.indent, - ) - for func in global_funcs - ], + *[func.gen(fn_ty_map, indent=opt.indent) for func in global_funcs], "# fmt: on", ] indent = " " * code.indent @@ -55,6 +60,7 @@ def generate_global_funcs( def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt: Options) -> None: """Generate a class definition for an object type.""" assert len(code.lines) >= 2 + assert isinstance(code.param, str) info = ObjectInfo.from_type_key(code.param) if info.methods: results = [ @@ -89,8 +95,6 @@ def generate_imports(code: CodeBlock, ty_used: set[str], opt: Options) -> None: module = module.replace(mod_prefix, mod_replacement, 1) break ty_collected.setdefault(module, []).append(name) - if not ty_collected: - return def _make_line(module: str, names: list[str], indent: int) -> str: names = ", ".join(sorted(set(names))) @@ -135,3 +139,91 @@ def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None: *[f'{indent}"{name}",' for name in sorted(names)], code.lines[-1], ] + + +def generate_export(code: CodeBlock) -> None: + """Generate an `__all__` 
block re-exporting the submodule named by `code.param`.""" + assert len(code.lines) >= 2 + + mod = code.param + code.lines = [ + code.lines[0], + "# fmt: off", + "# isort: off", + f"from .{mod} import * # noqa: F403", + f"from .{mod} import __all__ as {mod}__all__", + 'if "__all__" not in globals(): __all__ = []', + f"__all__.extend({mod}__all__)", + "# isort: on", + "# fmt: on", + code.lines[-1], + ] + + +def generate_ffi_api( + code_blocks: list[CodeBlock], + module_name: str, + type_keys: list[str], +) -> str: + """Generate the initial FFI API stub code for a given module.""" + append = "" + if not code_blocks: + append += f"""\"\"\"FFI API bindings for {module_name}.\"\"\" +""" + # Part 1. Imports + if not any(code.kind == "import" for code in code_blocks): + append += f""" +{C.STUB_BEGIN} import +{C.STUB_END} +""" + # Part 2. Global functions + if not any(code.kind == "global" for code in code_blocks): + append += f""" +{C.STUB_BEGIN} global/{module_name} +{C.STUB_END} +""" + # Part 3. __all__ + if not any(code.kind == "__all__" for code in code_blocks): + append += f""" +__all__ = [ + {C.STUB_BEGIN} __all__ + {C.STUB_END} +] +""" + # Part 4. 
Object types + if type_keys: + append += """ + +# isort: off +import tvm_ffi +# isort: on + +""" + for type_key in sorted(type_keys): + type_cls_name = type_key.rsplit(".", 1)[-1] + append += f""" +@tvm_ffi.register_object("{type_key}") +class {type_cls_name}(tvm_ffi.Object): + \"\"\"FFI binding for `{type_key}`.\"\"\" + + {C.STUB_BEGIN} object/{type_key} + {C.STUB_END} +""" + return append + + +def generate_init( + code_blocks: list[CodeBlock], + module_name: str, + submodule: str = "_ffi_api", +) -> str: + """Generate the text to append to the `__init__.py` of package `module_name`.""" + code = f""" +{C.STUB_BEGIN} export/{submodule} +{C.STUB_END} +""" + if not code_blocks: + return f"""\"\"\"Package {module_name}.\"\"\"\n""" + code + if not any(code.kind == "export" for code in code_blocks): + return code + return "" diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py index 6922254d..1caad280 100644 --- a/python/tvm_ffi/stub/consts.py +++ b/python/tvm_ffi/stub/consts.py @@ -58,3 +58,15 @@ FN_NAME_MAP = { "__ffi_init__": "__c_ffi_init__", } + +BUILTIN_TYPE_KEYS = { + "ffi.Bytes", + "ffi.Error", + "ffi.Function", + "ffi.Object", + "ffi.OpaquePyObject", + "ffi.SmallBytes", + "ffi.SmallStr", + "ffi.String", + "ffi.Tensor", +} diff --git a/python/tvm_ffi/stub/file_utils.py b/python/tvm_ffi/stub/file_utils.py index f100c553..a03c97ed 100644 --- a/python/tvm_ffi/stub/file_utils.py +++ b/python/tvm_ffi/stub/file_utils.py @@ -31,15 +31,15 @@ class CodeBlock: """A block of code to be generated in a stub file.""" - kind: Literal["global", "object", "ty-map", "import", "__all__", None] - param: str + kind: Literal["global", "object", "ty-map", "import", "export", "__all__", None] + param: str | tuple[str, ...] 
lineno_start: int lineno_end: int | None lines: list[str] def __post_init__(self) -> None: """Validate the code block after initialization.""" - assert self.kind in {"global", "object", "ty-map", "import", "__all__", None} + assert self.kind in {"global", "object", "ty-map", "import", "export", "__all__", None} @property def indent(self) -> int: @@ -61,19 +61,27 @@ def from_begin_line(lineo: int, line: str) -> CodeBlock: lines=[], ) assert line.startswith(C.STUB_BEGIN) + param: str | tuple[str, ...] stub = line[len(C.STUB_BEGIN) :].strip() if stub.startswith("global/"): kind = "global" param = stub[len("global/") :].strip() + if "@" in param: + param = tuple(param.split("@")) + else: + param = (param, "") elif stub.startswith("object/"): kind = "object" param = stub[len("object/") :].strip() elif stub.startswith("ty-map/"): kind = "ty-map" param = stub[len("ty-map/") :].strip() - elif stub.startswith("import"): + elif stub == "import": kind = "import" param = "" + elif stub.startswith("export/"): + kind = "export" + param = stub[len("export/") :].strip() elif stub == "__all__": kind = "__all__" param = "" @@ -96,12 +104,14 @@ class FileInfo: lines: tuple[str, ...] 
code_blocks: list[CodeBlock] - def update(self, show_diff: bool, dry_run: bool) -> bool: + def update(self, verbose: bool, dry_run: bool) -> bool: """Update the file's lines based on the current code blocks and optionally show a diff.""" new_lines = tuple(line for block in self.code_blocks for line in block.lines) if self.lines == new_lines: + if verbose: + print(f"{C.TERM_CYAN}-----> Unchanged{C.TERM_RESET}") return False - if show_diff: + if verbose: for line in difflib.unified_diff(self.lines, new_lines, lineterm=""): # Skip placeholder headers when fromfile/tofile are unspecified if line.startswith("---") or line.startswith("+++"): @@ -126,11 +136,8 @@ def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912 file = file.resolve() has_marker = False lines: list[str] = file.read_text(encoding="utf-8").splitlines() - for line_no, line in enumerate(lines, start=1): + for line in lines: if line.strip().startswith(C.STUB_SKIP_FILE): - print( - f"{C.TERM_YELLOW}[Skipped] skip-file marker found on line {line_no}: {file}{C.TERM_RESET}" - ) return None if line.strip().startswith(C.STUB_PREFIX): has_marker = True @@ -175,6 +182,12 @@ def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912 raise ValueError("Unclosed stub block at end of file") return FileInfo(path=file, lines=tuple(lines), code_blocks=codes) + def reload(self) -> None: + """Reload the code blocks from disk while preserving original `lines`.""" + source = FileInfo.from_file(self.path) + assert source is not None, f"File no longer exists or is no longer valid: {self.path}" + self.code_blocks = source.code_blocks + def collect_files(paths: list[Path]) -> list[FileInfo]: """Collect all files from the given paths and parse them into FileInfo objects.""" diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py index e8beb416..d092c708 100644 --- a/python/tvm_ffi/stub/utils.py +++ b/python/tvm_ffi/stub/utils.py @@ -31,7 +31,9 @@ class Options: """Command line options for stub 
generation.""" + imports: list[str] = dataclasses.field(default_factory=list) dlls: list[str] = dataclasses.field(default_factory=list) + init_path: str = "" indent: int = 4 files: list[str] = dataclasses.field(default_factory=list) verbose: bool = False diff --git a/python/tvm_ffi/testing/__init__.py b/python/tvm_ffi/testing/__init__.py new file mode 100644 index 00000000..cd357364 --- /dev/null +++ b/python/tvm_ffi/testing/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities.""" + +from .testing import ( + TestIntPair, + TestObjectBase, + TestObjectDerived, + _SchemaAllTypes, + _TestCxxClassBase, + _TestCxxClassDerived, + _TestCxxClassDerivedDerived, + _TestCxxInitSubset, + add_one, + create_object, + make_unregistered_object, +) diff --git a/python/tvm_ffi/testing/_ffi_api.py b/python/tvm_ffi/testing/_ffi_api.py new file mode 100644 index 00000000..e587fc73 --- /dev/null +++ b/python/tvm_ffi/testing/_ffi_api.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for namespace `testing`.""" + +# tvm-ffi-stubgen(begin): import +# fmt: off +# isort: off +from __future__ import annotations +from typing import Any, Callable, TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from tvm_ffi import Device, Object, Tensor, dtype + from tvm_ffi.testing import TestIntPair +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + +# tvm-ffi-stubgen(begin): global/testing@..registry +# fmt: off +# isort: off +from ..registry import init_ffi_api as _INIT +_INIT("testing", __name__) +# isort: on +if TYPE_CHECKING: + def TestIntPairSum(_0: TestIntPair, /) -> int: ... + def add_one(_0: int, /) -> int: ... + def apply(*args: Any) -> Any: ... + def echo(*args: Any) -> Any: ... + def get_add_one_c_symbol() -> int: ... + def get_mlir_add_one_c_symbol() -> int: ... + def make_unregistered_object() -> Object: ... + def nop(*args: Any) -> Any: ... + def object_use_count(_0: Object, /) -> int: ... + def optional_tensor_view_has_value(_0: Tensor | None, /) -> bool: ... + def run_check_signal(_0: int, /) -> None: ... + def schema_arr_map_opt(_0: Sequence[int | None], _1: Mapping[str, Sequence[int]], _2: str | None, /) -> Mapping[str, Sequence[int]]: ... + def schema_id_any(_0: Any, /) -> Any: ... + def schema_id_arr(_0: Sequence[Any], /) -> Sequence[Any]: ... + def schema_id_arr_int(_0: Sequence[int], /) -> Sequence[int]: ... 
+ def schema_id_arr_obj(_0: Sequence[Object], /) -> Sequence[Object]: ... + def schema_id_arr_str(_0: Sequence[str], /) -> Sequence[str]: ... + def schema_id_bool(_0: bool, /) -> bool: ... + def schema_id_bytes(_0: bytes, /) -> bytes: ... + def schema_id_device(_0: Device, /) -> Device: ... + def schema_id_dltensor(_0: Tensor, /) -> Tensor: ... + def schema_id_dtype(_0: dtype, /) -> dtype: ... + def schema_id_float(_0: float, /) -> float: ... + def schema_id_func(_0: Callable[..., Any], /) -> Callable[..., Any]: ... + def schema_id_func_typed(_0: Callable[[int, float, Callable[..., Any]], None], /) -> Callable[[int, float, Callable[..., Any]], None]: ... + def schema_id_int(_0: int, /) -> int: ... + def schema_id_map(_0: Mapping[Any, Any], /) -> Mapping[Any, Any]: ... + def schema_id_map_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]: ... + def schema_id_map_str_obj(_0: Mapping[str, Object], /) -> Mapping[str, Object]: ... + def schema_id_map_str_str(_0: Mapping[str, str], /) -> Mapping[str, str]: ... + def schema_id_object(_0: Object, /) -> Object: ... + def schema_id_opt_int(_0: int | None, /) -> int | None: ... + def schema_id_opt_obj(_0: Object | None, /) -> Object | None: ... + def schema_id_opt_str(_0: str | None, /) -> str | None: ... + def schema_id_string(_0: str, /) -> str: ... + def schema_id_tensor(_0: Tensor, /) -> Tensor: ... + def schema_id_variant_int_str(_0: int | str, /) -> int | str: ... + def schema_no_args() -> int: ... + def schema_no_args_no_return() -> None: ... + def schema_no_return(_0: int, /) -> None: ... + def schema_packed(*args: Any) -> Any: ... + def schema_tensor_view_input(_0: Tensor, /) -> None: ... + def schema_variant_mix(_0: int | str | Sequence[int], /) -> int | str | Sequence[int]: ... + def test_raise_error(_0: str, _1: str, /) -> None: ... 
+# fmt: on +# tvm-ffi-stubgen(end) + +__all__ = [ + # tvm-ffi-stubgen(begin): __all__ + "TestIntPairSum", + "add_one", + "apply", + "echo", + "get_add_one_c_symbol", + "get_mlir_add_one_c_symbol", + "make_unregistered_object", + "nop", + "object_use_count", + "optional_tensor_view_has_value", + "run_check_signal", + "schema_arr_map_opt", + "schema_id_any", + "schema_id_arr", + "schema_id_arr_int", + "schema_id_arr_obj", + "schema_id_arr_str", + "schema_id_bool", + "schema_id_bytes", + "schema_id_device", + "schema_id_dltensor", + "schema_id_dtype", + "schema_id_float", + "schema_id_func", + "schema_id_func_typed", + "schema_id_int", + "schema_id_map", + "schema_id_map_str_int", + "schema_id_map_str_obj", + "schema_id_map_str_str", + "schema_id_object", + "schema_id_opt_int", + "schema_id_opt_obj", + "schema_id_opt_str", + "schema_id_string", + "schema_id_tensor", + "schema_id_variant_int_str", + "schema_no_args", + "schema_no_args_no_return", + "schema_no_return", + "schema_packed", + "schema_tensor_view_input", + "schema_variant_mix", + "test_raise_error", + # tvm-ffi-stubgen(end) +] diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing/testing.py similarity index 97% rename from python/tvm_ffi/testing.py rename to python/tvm_ffi/testing/testing.py index 2b157b1e..e9fcd7c4 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing/testing.py @@ -31,10 +31,10 @@ from typing import ClassVar -from . import _ffi_api -from .core import Object -from .dataclasses import c_class, field -from .registry import get_global_func, register_object +from .. 
import _ffi_api +from ..core import Object +from ..dataclasses import c_class, field +from ..registry import get_global_func, register_object @register_object("testing.TestObjectBase") diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py index ef531433..ab860b20 100644 --- a/tests/python/test_stubgen.py +++ b/tests/python/test_stubgen.py @@ -35,7 +35,8 @@ def _identity_ty_map(name: str) -> str: def test_codeblock_from_begin_line_variants() -> None: cases = [ - (f"{C.STUB_BEGIN} global/example", "global", "example"), + (f"{C.STUB_BEGIN} global/example", "global", ("example", "")), + (f"{C.STUB_BEGIN} global/example@.registry", "global", ("example", ".registry")), (f"{C.STUB_BEGIN} object/testing.TestObjectBase", "object", "testing.TestObjectBase"), (f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"), (f"{C.STUB_BEGIN} import", "import", ""), @@ -93,7 +94,7 @@ def test_fileinfo_from_file_parses_blocks(tmp_path: Path) -> None: assert first.kind is None and first.lines == ["first = 1"] assert stub.kind == "global" - assert stub.param == "demo.func" + assert stub.param == ("demo.func", "") assert stub.lineno_start == 2 assert stub.lineno_end == 4 assert stub.lines == [ @@ -212,7 +213,7 @@ def ty_map(name: str) -> str: def test_generate_global_funcs_updates_block() -> None: code = CodeBlock( kind="global", - param="testing", + param=("testing", ""), lineno_start=1, lineno_end=2, lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END], @@ -231,6 +232,10 @@ def test_generate_global_funcs_updates_block() -> None: assert code.lines == [ f"{C.STUB_BEGIN} global/testing", "# fmt: off", + "# isort: off", + "from tvm_ffi import init_ffi_api as _INIT", + '_INIT("testing", __name__)', + "# isort: on", "if TYPE_CHECKING:", " def add_one(_0: int, /) -> int: ...", "# fmt: on",