diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py index ab6704a5..c1afffc4 100644 --- a/python/tvm_ffi/_ffi_api.py +++ b/python/tvm_ffi/_ffi_api.py @@ -29,10 +29,12 @@ # fmt: on # tvm-ffi-stubgen(end) -from . import registry - -# tvm-ffi-stubgen(begin): global/ffi +# tvm-ffi-stubgen(begin): global/ffi@.registry # fmt: off +# isort: off +from .registry import init_ffi_api as _INIT +_INIT("ffi", __name__) +# isort: on if TYPE_CHECKING: def Array(*args: Any) -> Any: ... def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ... @@ -72,8 +74,6 @@ def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ... # fmt: on # tvm-ffi-stubgen(end) -registry.init_ffi_api("ffi", __name__) - __all__ = [ # tvm-ffi-stubgen(begin): __all__ diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py index 5005bc65..d4eb22d5 100644 --- a/python/tvm_ffi/_tensor.py +++ b/python/tvm_ffi/_tensor.py @@ -57,6 +57,11 @@ class Shape(tuple, PyNativeObject): _tvm_ffi_cached_object: Any + # tvm-ffi-stubgen(begin): object/ffi.Shape + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __new__(cls, content: tuple[int, ...]) -> Shape: if any(not isinstance(x, Integral) for x in content): raise ValueError("Shape must be a tuple of integers") diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 06fb92e7..4bb77206 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -148,6 +148,11 @@ class Array(core.Object, Sequence[T]): """ + # tvm-ffi-stubgen(begin): object/ffi.Array + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, input_list: Iterable[T]) -> None: """Construct an Array from a Python sequence.""" self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) @@ -290,6 +295,11 @@ class Map(core.Object, Mapping[K, V]): """ + # tvm-ffi-stubgen(begin): object/ffi.Map + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, input_dict: Mapping[K, V]) -> None: """Construct a Map from a Python mapping.""" list_kvs: list[Any] = [] diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py index 03dbe364..69b5d88f 100644 --- a/python/tvm_ffi/stub/analysis.py +++ b/python/tvm_ffi/stub/analysis.py @@ -18,6 +18,7 @@ from __future__ import annotations +from tvm_ffi._ffi_api import GetRegisteredTypeKeys from tvm_ffi.registry import list_global_func_names from . import consts as C @@ -34,8 +35,27 @@ def collect_global_funcs() -> dict[str, list[FuncInfo]]: except ValueError: print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}") else: - global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name)) + try: + global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name)) + except Exception: + print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}") # Ensure stable ordering for deterministic output. for k in list(global_funcs.keys()): global_funcs[k].sort(key=lambda x: x.schema.name) return global_funcs + + +def collect_type_keys() -> dict[str, list[str]]: + """Collect registered object type keys from TVM FFI's global registry.""" + global_objects: dict[str, list[str]] = {} + for type_key in GetRegisteredTypeKeys(): + try: + prefix, _ = type_key.rsplit(".", 1) + except ValueError: + pass + else: + global_objects.setdefault(prefix, []).append(type_key) + # Ensure stable ordering for deterministic output. + for k in list(global_objects.keys()): + global_objects[k].sort() + return global_objects diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py index 85697f52..2504dce1 100644 --- a/python/tvm_ffi/stub/cli.py +++ b/python/tvm_ffi/stub/cli.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# tvm-ffi-stubgen(skip-file) """TVM-FFI Stub Generator (``tvm-ffi-stubgen``).""" from __future__ import annotations import argparse import ctypes +import importlib import sys import traceback from pathlib import Path @@ -28,9 +28,9 @@ from . import codegen as G from . import consts as C -from .analysis import collect_global_funcs +from .analysis import collect_global_funcs, collect_type_keys from .file_utils import FileInfo, collect_files -from .utils import Options +from .utils import FuncInfo, Options def _fn_ty_map(ty_map: dict[str, str], ty_used: set[str]) -> Callable[[str], str]: @@ -55,72 +55,44 @@ def __main__() -> int: overview and examples of the block syntax. """ opt = _parse_args() + for imp in opt.imports or []: + importlib.import_module(imp) + if opt.init_path: + opt.files.append(opt.init_path) dlls = [ctypes.CDLL(lib) for lib in opt.dlls] files: list[FileInfo] = collect_files([Path(f) for f in opt.files]) + global_funcs: dict[str, list[FuncInfo]] = collect_global_funcs() - # Stage 1: Process `tvm-ffi-stubgen(ty-map)` + # Stage 1: Collect information + # - type maps: `tvm-ffi-stubgen(ty-map)` + # - defined global functions: `tvm-ffi-stubgen(begin): global/...` + # - defined object types: `tvm-ffi-stubgen(begin): object/...` ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy() - - def _stage_1(file: FileInfo) -> None: - for code in file.code_blocks: - if code.kind == "ty-map": - try: - lhs, rhs = code.param.split("->") - except ValueError as e: - raise ValueError( - f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`" - ) from e - ty_map[lhs.strip()] = rhs.strip() - for file in files: try: - _stage_1(file) + _stage_1(file, ty_map) except Exception: print( f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}' ) - # Stage 2: Process + # Stage 2. Generate stubs if they are not defined on the file. + if opt.init_path: + _stage_2( + files, + init_path=Path(opt.init_path).resolve(), + global_funcs=global_funcs, + ) + + # Stage 3: Process # - `tvm-ffi-stubgen(begin): global/...` # - `tvm-ffi-stubgen(begin): object/...` - global_funcs = collect_global_funcs() - - def _stage_2(file: FileInfo) -> None: - all_defined = set() + for file in files: if opt.verbose: print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}") - ty_used: set[str] = set() - ty_on_file: set[str] = set() - fn_ty_map_fn = _fn_ty_map(ty_map, ty_used) - # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...` - for code in file.code_blocks: - if code.kind == "global": - funcs = global_funcs.get(code.param, []) - for func in funcs: - all_defined.add(func.schema.name) - G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt) - # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...` - for code in file.code_blocks: - if code.kind == "object": - type_key = code.param - ty_on_file.add(ty_map.get(type_key, type_key)) - G.generate_object(code, fn_ty_map_fn, opt) - # Stage 2.3. Add imports for used types. - for code in file.code_blocks: - if code.kind == "import": - G.generate_imports(code, ty_used - ty_on_file, opt) - break # Only one import block per file is supported for now. - # Stage 2.4. Add `__all__` for defined classes and functions. - for code in file.code_blocks: - if code.kind == "__all__": - G.generate_all(code, all_defined | ty_on_file, opt) - break # Only one __all__ block per file is supported for now. - file.update(show_diff=opt.verbose, dry_run=opt.dry_run) - - for file in files: try: - _stage_2(file) - except: + _stage_3(file, opt, ty_map, global_funcs) + except Exception: print( f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}' ) @@ -128,6 +100,122 @@ def _stage_2(file: FileInfo) -> None: return 0 +def _stage_1( + file: FileInfo, + ty_map: dict[str, str], +) -> None: + for code in file.code_blocks: + if code.kind == "ty-map": + try: + assert isinstance(code.param, str) + lhs, rhs = code.param.split("->") + except ValueError as e: + raise ValueError( + f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`" + ) from e + ty_map[lhs.strip()] = rhs.strip() + + +def _stage_2( + files: list[FileInfo], + init_path: Path, + global_funcs: dict[str, list[FuncInfo]], +) -> None: + def _find_or_insert_file(path: Path) -> FileInfo: + ret: FileInfo | None + if not path.exists(): + ret = FileInfo(path=path, lines=(), code_blocks=[]) + else: + for file in files: + if path.samefile(file.path): + return file + ret = FileInfo.from_file(file=path) + assert ret is not None, f"Failed to read file: {path}" + files.append(ret) + return ret + + # Step 0. Find out functions and classes already defined on files. + defined_func_prefixes: set[str] = { # type: ignore[union-attr] + code.param[0] for file in files for code in file.code_blocks if code.kind == "global" + } + defined_objs: set[str] = { # type: ignore[assignment] + code.param for file in files for code in file.code_blocks if code.kind == "object" + } | C.BUILTIN_TYPE_KEYS + + # Step 1. Generate missing `_ffi_api.py` and `__init__.py` under each prefix. + prefixes: dict[str, list[str]] = collect_type_keys() + for prefix in global_funcs: + prefixes.setdefault(prefix, []) + + for prefix, obj_names in prefixes.items(): + if prefix.startswith("testing") or prefix.startswith("ffi"): + continue + funcs = sorted( + [] if prefix in defined_func_prefixes else global_funcs.get(prefix, []), + key=lambda f: f.schema.name, + ) + objs = sorted(set(obj_names) - defined_objs) + if not funcs and not objs: + continue + # Step 1.1. Create target directory if not exists + directory = init_path / prefix.replace(".", "/") + directory.mkdir(parents=True, exist_ok=True) + # Step 1.2. Generate `_ffi_api.py` + target_path = directory / "_ffi_api.py" + target_file = _find_or_insert_file(target_path) + with target_path.open("a", encoding="utf-8") as f: + f.write(G.generate_ffi_api(target_file.code_blocks, prefix, objs)) + target_file.reload() + # Step 1.3. Generate `__init__.py` + target_path = directory / "__init__.py" + target_file = _find_or_insert_file(target_path) + with target_path.open("a", encoding="utf-8") as f: + f.write(G.generate_init(target_file.code_blocks, prefix, submodule="_ffi_api")) + target_file.reload() + + +def _stage_3( + file: FileInfo, + opt: Options, + ty_map: dict[str, str], + global_funcs: dict[str, list[FuncInfo]], +) -> None: + all_defined = set() + ty_used: set[str] = set() + ty_on_file: set[str] = set() + fn_ty_map_fn = _fn_ty_map(ty_map, ty_used) + # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...` + for code in file.code_blocks: + if code.kind == "global": + funcs = global_funcs.get(code.param[0], []) + for func in funcs: + all_defined.add(func.schema.name) + G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt) + # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...` + for code in file.code_blocks: + if code.kind == "object": + type_key = code.param + assert isinstance(type_key, str) + ty_on_file.add(ty_map.get(type_key, type_key)) + G.generate_object(code, fn_ty_map_fn, opt) + # Stage 2.3. Add imports for used types. + for code in file.code_blocks: + if code.kind == "import": + G.generate_imports(code, ty_used - ty_on_file, opt) + break # Only one import block per file is supported for now. + # Stage 2.4. Add `__all__` for defined classes and functions. + for code in file.code_blocks: + if code.kind == "__all__": + G.generate_all(code, all_defined | ty_on_file, opt) + break # Only one __all__ block per file is supported for now. + # Stage 2.5. Process `tvm-ffi-stubgen(begin): export/...` + for code in file.code_blocks: + if code.kind == "export": + G.generate_export(code) + # Finalize: write back to file + file.update(verbose=opt.verbose, dry_run=opt.dry_run) + + def _parse_args() -> Options: class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): pass @@ -149,16 +237,16 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp " # Preload TVM runtime / extension libraries\n" " tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n\n" "Stub block syntax (placed in your source):\n" - " # tvm-ffi-stubgen(begin): global/\n" + f" {C.STUB_BEGIN} global/\n" " ... generated function stubs ...\n" - " # tvm-ffi-stubgen(end)\n\n" - " # tvm-ffi-stubgen(begin): object/\n" - " # tvm-ffi-stubgen(ty_map): list -> Sequence\n" - " # tvm-ffi-stubgen(ty_map): dict -> Mapping\n" + f" {C.STUB_END}\n\n" + f" {C.STUB_BEGIN} object/\n" + f" {C.STUB_TY_MAP}: list -> Sequence\n" + f" {C.STUB_TY_MAP}: dict -> Mapping\n" " ... generated fields and methods ...\n" - " # tvm-ffi-stubgen(end)\n\n" + f" {C.STUB_END}\n\n" " # Skip a file entirely\n" - " # tvm-ffi-stubgen(skip-file)\n\n" + f" {C.STUB_SKIP_FILE}\n\n" "Tips:\n" " - Only .py/.pyi files are updated; directories are scanned recursively.\n" " - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n" @@ -167,6 +255,12 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp " is provided by native extensions.\n" ), ) + parser.add_argument( + "--imports", + nargs="*", + metavar="IMPORTS", + help=("Additional imports to load before generation."), + ) parser.add_argument( "--dlls", nargs="*", @@ -179,13 +273,19 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp ), default=[], ) + parser.add_argument( + "--init-path", + type=str, + default="", + help="If specified, generate stubs under the given package prefix.", + ) parser.add_argument( "--indent", type=int, default=4, help=( "Extra spaces added inside each generated block, relative to the " - "indentation of the corresponding '# tvm-ffi-stubgen(begin):' line." + f"indentation of the corresponding '{C.STUB_BEGIN}' line." ), ) parser.add_argument( diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py index c15624a9..08a71933 100644 --- a/python/tvm_ffi/stub/codegen.py +++ b/python/tvm_ffi/stub/codegen.py @@ -26,22 +26,27 @@ def generate_global_funcs( - code: CodeBlock, global_funcs: list[FuncInfo], fn_ty_map: Callable[[str], str], opt: Options + code: CodeBlock, + global_funcs: list[FuncInfo], + fn_ty_map: Callable[[str], str], + opt: Options, ) -> None: """Generate function signatures for global functions.""" assert len(code.lines) >= 2 if not global_funcs: return + assert isinstance(code.param, tuple) + prefix, import_from = code.param + if not import_from: + import_from = "tvm_ffi" results: list[str] = [ "# fmt: off", + "# isort: off", + f"from {import_from} import init_ffi_api as _INIT", + f'_INIT("{prefix}", __name__)', + "# isort: on", "if TYPE_CHECKING:", - *[ - func.gen( - fn_ty_map, - indent=opt.indent, - ) - for func in global_funcs - ], + *[func.gen(fn_ty_map, indent=opt.indent) for func in global_funcs], "# fmt: on", ] indent = " " * code.indent @@ -55,6 +60,7 @@ def generate_global_funcs( def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt: Options) -> None: """Generate a class definition for an object type.""" assert len(code.lines) >= 2 + assert isinstance(code.param, str) info = ObjectInfo.from_type_key(code.param) if info.methods: results = [ @@ -89,8 +95,6 @@ def generate_imports(code: CodeBlock, ty_used: set[str], opt: Options) -> None: module = module.replace(mod_prefix, mod_replacement, 1) break ty_collected.setdefault(module, []).append(name) - if not ty_collected: - return def _make_line(module: str, names: list[str], indent: int) -> str: names = ", ".join(sorted(set(names))) @@ -135,3 +139,91 @@ def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None: *[f'{indent}"{name}",' for name in sorted(names)], code.lines[-1], ] + + +def generate_export(code: CodeBlock) -> None: + """Generate an `__all__` variable for the given names.""" + assert len(code.lines) >= 2 + + mod = code.param + code.lines = [ + code.lines[0], + "# fmt: off", + "# isort: off", + f"from .{mod} import * # noqa: F403", + f"from .{mod} import __all__ as {mod}__all__", + 'if "__all__" not in globals(): __all__ = []', + f"__all__.extend({mod}__all__)", + "# isort: on", + "# fmt: on", + code.lines[-1], + ] + + +def generate_ffi_api( + code_blocks: list[CodeBlock], + module_name: str, + type_keys: list[str], +) -> str: + """Generate the initial FFI API stub code for a given module.""" + append = "" + if not code_blocks: + append += f"""\"\"\"FFI API bindings for {module_name}.\"\"\" +""" + # Part 1. Imports + if not any(code.kind == "import" for code in code_blocks): + append += f""" +{C.STUB_BEGIN} import +{C.STUB_END} +""" + # Part 2. Global functions + if not any(code.kind == "global" for code in code_blocks): + append += f""" +{C.STUB_BEGIN} global/{module_name} +{C.STUB_END} +""" + # Part 3. __all__ + if not any(code.kind == "all" for code in code_blocks): + append += f""" +__all__ = [ + {C.STUB_BEGIN} __all__ + {C.STUB_END} +] +""" + # Part 4. Object types + if type_keys: + append += """ + +# isort: off +import tvm_ffi +# isort: on + +""" + for type_key in sorted(type_keys): + type_cls_name = type_key.rsplit(".", 1)[-1] + append += f""" +@tvm_ffi.register_object("{type_key}") +class {type_cls_name}(tvm_ffi.Object): + \"\"\"FFI binding for `{type_key}`.\"\"\" + + {C.STUB_BEGIN} object/{type_key} + {C.STUB_END} +""" + return append + + +def generate_init( + code_blocks: list[CodeBlock], + module_name: str, + submodule: str = "_ffi_api", +) -> str: + """Generate the `__init__.py` file for the `tvm_ffi` package.""" + code = f""" +{C.STUB_BEGIN} export/{submodule} +{C.STUB_END} +""" + if not code_blocks: + return f"""\"\"\"Package {module_name}.\"\"\"\n""" + code + if not any(code.kind == "export" for code in code_blocks): + return code + return "" diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py index 6922254d..1caad280 100644 --- a/python/tvm_ffi/stub/consts.py +++ b/python/tvm_ffi/stub/consts.py @@ -58,3 +58,15 @@ FN_NAME_MAP = { "__ffi_init__": "__c_ffi_init__", } + +BUILTIN_TYPE_KEYS = { + "ffi.Bytes", + "ffi.Error", + "ffi.Function", + "ffi.Object", + "ffi.OpaquePyObject", + "ffi.SmallBytes", + "ffi.SmallStr", + "ffi.String", + "ffi.Tensor", +} diff --git a/python/tvm_ffi/stub/file_utils.py b/python/tvm_ffi/stub/file_utils.py index f100c553..a03c97ed 100644 --- a/python/tvm_ffi/stub/file_utils.py +++ b/python/tvm_ffi/stub/file_utils.py @@ -31,15 +31,15 @@ class CodeBlock: """A block of code to be generated in a stub file.""" - kind: Literal["global", "object", "ty-map", "import", "__all__", None] - param: str + kind: Literal["global", "object", "ty-map", "import", "export", "__all__", None] + param: str | tuple[str, ...] lineno_start: int lineno_end: int | None lines: list[str] def __post_init__(self) -> None: """Validate the code block after initialization.""" - assert self.kind in {"global", "object", "ty-map", "import", "__all__", None} + assert self.kind in {"global", "object", "ty-map", "import", "export", "__all__", None} @property def indent(self) -> int: @@ -61,19 +61,27 @@ def from_begin_line(lineo: int, line: str) -> CodeBlock: lines=[], ) assert line.startswith(C.STUB_BEGIN) + param: str | tuple[str, ...] stub = line[len(C.STUB_BEGIN) :].strip() if stub.startswith("global/"): kind = "global" param = stub[len("global/") :].strip() + if "@" in param: + param = tuple(param.split("@")) + else: + param = (param, "") elif stub.startswith("object/"): kind = "object" param = stub[len("object/") :].strip() elif stub.startswith("ty-map/"): kind = "ty-map" param = stub[len("ty-map/") :].strip() - elif stub.startswith("import"): + elif stub == "import": kind = "import" param = "" + elif stub.startswith("export/"): + kind = "export" + param = stub[len("export/") :].strip() elif stub == "__all__": kind = "__all__" param = "" @@ -96,12 +104,14 @@ class FileInfo: lines: tuple[str, ...] code_blocks: list[CodeBlock] - def update(self, show_diff: bool, dry_run: bool) -> bool: + def update(self, verbose: bool, dry_run: bool) -> bool: """Update the file's lines based on the current code blocks and optionally show a diff.""" new_lines = tuple(line for block in self.code_blocks for line in block.lines) if self.lines == new_lines: + if verbose: + print(f"{C.TERM_CYAN}-----> Unchanged{C.TERM_RESET}") return False - if show_diff: + if verbose: for line in difflib.unified_diff(self.lines, new_lines, lineterm=""): # Skip placeholder headers when fromfile/tofile are unspecified if line.startswith("---") or line.startswith("+++"): @@ -126,11 +136,8 @@ def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912 file = file.resolve() has_marker = False lines: list[str] = file.read_text(encoding="utf-8").splitlines() - for line_no, line in enumerate(lines, start=1): + for _, line in enumerate(lines, start=1): if line.strip().startswith(C.STUB_SKIP_FILE): - print( - f"{C.TERM_YELLOW}[Skipped] skip-file marker found on line {line_no}: {file}{C.TERM_RESET}" - ) return None if line.strip().startswith(C.STUB_PREFIX): has_marker = True @@ -175,6 +182,12 @@ def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912 raise ValueError("Unclosed stub block at end of file") return FileInfo(path=file, lines=tuple(lines), code_blocks=codes) + def reload(self) -> None: + """Reload the code blocks from disk while preserving original `lines`.""" + source = FileInfo.from_file(self.path) + assert source is not None, f"File no longer exists or valid: {self.path}" + self.code_blocks = source.code_blocks + def collect_files(paths: list[Path]) -> list[FileInfo]: """Collect all files from the given paths and parse them into FileInfo objects.""" diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py index e8beb416..d092c708 100644 --- a/python/tvm_ffi/stub/utils.py +++ b/python/tvm_ffi/stub/utils.py @@ -31,7 +31,9 @@ class Options: """Command line options for stub generation.""" + imports: list[str] = dataclasses.field(default_factory=list) dlls: list[str] = dataclasses.field(default_factory=list) + init_path: str = "" indent: int = 4 files: list[str] = dataclasses.field(default_factory=list) verbose: bool = False diff --git a/python/tvm_ffi/testing/__init__.py b/python/tvm_ffi/testing/__init__.py new file mode 100644 index 00000000..cd357364 --- /dev/null +++ b/python/tvm_ffi/testing/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities.""" + +from .testing import ( + TestIntPair, + TestObjectBase, + TestObjectDerived, + _SchemaAllTypes, + _TestCxxClassBase, + _TestCxxClassDerived, + _TestCxxClassDerivedDerived, + _TestCxxInitSubset, + add_one, + create_object, + make_unregistered_object, +) diff --git a/python/tvm_ffi/testing/_ffi_api.py b/python/tvm_ffi/testing/_ffi_api.py new file mode 100644 index 00000000..e587fc73 --- /dev/null +++ b/python/tvm_ffi/testing/_ffi_api.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for namespace `testing`.""" + +# tvm-ffi-stubgen(begin): import +# fmt: off +# isort: off +from __future__ import annotations +from typing import Any, Callable, TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from tvm_ffi import Device, Object, Tensor, dtype + from tvm_ffi.testing import TestIntPair +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + +# tvm-ffi-stubgen(begin): global/testing@..registry +# fmt: off +# isort: off +from ..registry import init_ffi_api as _INIT +_INIT("testing", __name__) +# isort: on +if TYPE_CHECKING: + def TestIntPairSum(_0: TestIntPair, /) -> int: ... + def add_one(_0: int, /) -> int: ... + def apply(*args: Any) -> Any: ... + def echo(*args: Any) -> Any: ... + def get_add_one_c_symbol() -> int: ... + def get_mlir_add_one_c_symbol() -> int: ... + def make_unregistered_object() -> Object: ... + def nop(*args: Any) -> Any: ... + def object_use_count(_0: Object, /) -> int: ... + def optional_tensor_view_has_value(_0: Tensor | None, /) -> bool: ... + def run_check_signal(_0: int, /) -> None: ... + def schema_arr_map_opt(_0: Sequence[int | None], _1: Mapping[str, Sequence[int]], _2: str | None, /) -> Mapping[str, Sequence[int]]: ... + def schema_id_any(_0: Any, /) -> Any: ... + def schema_id_arr(_0: Sequence[Any], /) -> Sequence[Any]: ... + def schema_id_arr_int(_0: Sequence[int], /) -> Sequence[int]: ... + def schema_id_arr_obj(_0: Sequence[Object], /) -> Sequence[Object]: ... + def schema_id_arr_str(_0: Sequence[str], /) -> Sequence[str]: ... + def schema_id_bool(_0: bool, /) -> bool: ... + def schema_id_bytes(_0: bytes, /) -> bytes: ... + def schema_id_device(_0: Device, /) -> Device: ... + def schema_id_dltensor(_0: Tensor, /) -> Tensor: ... + def schema_id_dtype(_0: dtype, /) -> dtype: ... + def schema_id_float(_0: float, /) -> float: ... + def schema_id_func(_0: Callable[..., Any], /) -> Callable[..., Any]: ... + def schema_id_func_typed(_0: Callable[[int, float, Callable[..., Any]], None], /) -> Callable[[int, float, Callable[..., Any]], None]: ... + def schema_id_int(_0: int, /) -> int: ... + def schema_id_map(_0: Mapping[Any, Any], /) -> Mapping[Any, Any]: ... + def schema_id_map_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]: ... + def schema_id_map_str_obj(_0: Mapping[str, Object], /) -> Mapping[str, Object]: ... + def schema_id_map_str_str(_0: Mapping[str, str], /) -> Mapping[str, str]: ... + def schema_id_object(_0: Object, /) -> Object: ... + def schema_id_opt_int(_0: int | None, /) -> int | None: ... + def schema_id_opt_obj(_0: Object | None, /) -> Object | None: ... + def schema_id_opt_str(_0: str | None, /) -> str | None: ... + def schema_id_string(_0: str, /) -> str: ... + def schema_id_tensor(_0: Tensor, /) -> Tensor: ... + def schema_id_variant_int_str(_0: int | str, /) -> int | str: ... + def schema_no_args() -> int: ... + def schema_no_args_no_return() -> None: ... + def schema_no_return(_0: int, /) -> None: ... + def schema_packed(*args: Any) -> Any: ... + def schema_tensor_view_input(_0: Tensor, /) -> None: ... + def schema_variant_mix(_0: int | str | Sequence[int], /) -> int | str | Sequence[int]: ... + def test_raise_error(_0: str, _1: str, /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) + +__all__ = [ + # tvm-ffi-stubgen(begin): __all__ + "TestIntPairSum", + "add_one", + "apply", + "echo", + "get_add_one_c_symbol", + "get_mlir_add_one_c_symbol", + "make_unregistered_object", + "nop", + "object_use_count", + "optional_tensor_view_has_value", + "run_check_signal", + "schema_arr_map_opt", + "schema_id_any", + "schema_id_arr", + "schema_id_arr_int", + "schema_id_arr_obj", + "schema_id_arr_str", + "schema_id_bool", + "schema_id_bytes", + "schema_id_device", + "schema_id_dltensor", + "schema_id_dtype", + "schema_id_float", + "schema_id_func", + "schema_id_func_typed", + "schema_id_int", + "schema_id_map", + "schema_id_map_str_int", + "schema_id_map_str_obj", + "schema_id_map_str_str", + "schema_id_object", + "schema_id_opt_int", + "schema_id_opt_obj", + "schema_id_opt_str", + "schema_id_string", + "schema_id_tensor", + "schema_id_variant_int_str", + "schema_no_args", + "schema_no_args_no_return", + "schema_no_return", + "schema_packed", + "schema_tensor_view_input", + "schema_variant_mix", + "test_raise_error", + # tvm-ffi-stubgen(end) +] diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing/testing.py similarity index 97% rename from python/tvm_ffi/testing.py rename to python/tvm_ffi/testing/testing.py index 2b157b1e..e9fcd7c4 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing/testing.py @@ -31,10 +31,10 @@ from typing import ClassVar -from . import _ffi_api -from .core import Object -from .dataclasses import c_class, field -from .registry import get_global_func, register_object +from .. import _ffi_api +from ..core import Object +from ..dataclasses import c_class, field +from ..registry import get_global_func, register_object @register_object("testing.TestObjectBase") diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py index ef531433..ab860b20 100644 --- a/tests/python/test_stubgen.py +++ b/tests/python/test_stubgen.py @@ -35,7 +35,8 @@ def _identity_ty_map(name: str) -> str: def test_codeblock_from_begin_line_variants() -> None: cases = [ - (f"{C.STUB_BEGIN} global/example", "global", "example"), + (f"{C.STUB_BEGIN} global/example", "global", ("example", "")), + (f"{C.STUB_BEGIN} global/example@.registry", "global", ("example", ".registry")), (f"{C.STUB_BEGIN} object/testing.TestObjectBase", "object", "testing.TestObjectBase"), (f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"), (f"{C.STUB_BEGIN} import", "import", ""), @@ -93,7 +94,7 @@ def test_fileinfo_from_file_parses_blocks(tmp_path: Path) -> None: assert first.kind is None and first.lines == ["first = 1"] assert stub.kind == "global" - assert stub.param == "demo.func" + assert stub.param == ("demo.func", "") assert stub.lineno_start == 2 assert stub.lineno_end == 4 assert stub.lines == [ @@ -212,7 +213,7 @@ def ty_map(name: str) -> str: def test_generate_global_funcs_updates_block() -> None: code = CodeBlock( kind="global", - param="testing", + param=("testing", ""), lineno_start=1, lineno_end=2, lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END], @@ -231,6 +232,10 @@ def test_generate_global_funcs_updates_block() -> None: assert code.lines == [ f"{C.STUB_BEGIN} global/testing", "# fmt: off", + "# isort: off", + "from tvm_ffi import init_ffi_api as _INIT", + '_INIT("testing", __name__)', + "# isort: on", "if TYPE_CHECKING:", " def add_one(_0: int, /) -> int: ...", "# fmt: on",