Skip to content

Commit f03b959

Browse files
committed
[Stubgen] Introduce --init-path for package-level generation
This PR introduces an additional flag `--init-path` to the tool `tvm-ffi-stubgen`. When this flag is specified, the tool will actively look for global functions and objects that are not yet exported into Python, and generate their type stubs into the path specified by `--init-path`. To give an example, if a C++ TVM-FFI extension defines two items: - A global function: `my_ffi_extension.add_one` - An object: `my_ffi_extension.IntPair` which are not yet exported to Python, the tool will generate the following two files if `--init-path=/path/` is set: ``` /path/__init__.py /path/_ffi_api.py ``` Where `__init__.py` contains: ```python """Package my_ffi_extension.""" /# tvm-ffi-stubgen(begin): export/_ffi_api /# fmt: off /# isort: off from ._ffi_api import * # noqa: F403 from ._ffi_api import __all__ as _ffi_api__all__ if "__all__" not in globals(): __all__ = [] __all__.extend(_ffi_api__all__) /# isort: on /# fmt: on /# tvm-ffi-stubgen(end) ``` and `_ffi_api.py` contains: ```python """FFI API bindings for my_ffi_extension.""" /# tvm-ffi-stubgen(begin): import /# fmt: off /# isort: off from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: from tvm_ffi import Object /# isort: on /# fmt: on /# tvm-ffi-stubgen(end) /# tvm-ffi-stubgen(begin): global/my_ffi_extension /# fmt: off /# isort: off from tvm_ffi import init_ffi_api as _INIT _INIT("my_ffi_extension", __name__) /# isort: on if TYPE_CHECKING: def raise_error(_0: str, /) -> None: ... /# fmt: on /# tvm-ffi-stubgen(end) __all__ = [ # tvm-ffi-stubgen(begin): __all__ "IntPair", "raise_error", # tvm-ffi-stubgen(end) ] /# isort: off import tvm_ffi /# isort: on @tvm_ffi.register_object("my_ffi_extension.IntPair") class IntPair(tvm_ffi.Object): """FFI binding for `my_ffi_extension.IntPair`.""" /# tvm-ffi-stubgen(begin): object/my_ffi_extension.IntPair /# fmt: off a: int b: int if TYPE_CHECKING: @staticmethod def __c_ffi_init__(_0: int, _1: int, /) -> Object: ... @staticmethod def static_get_second(_0: IntPair, /) -> int: ... def get_first(self, /) -> int: ... /# fmt: on /# tvm-ffi-stubgen(end) ```
1 parent 6887892 commit f03b959

File tree

13 files changed

+516
-93
lines changed

13 files changed

+516
-93
lines changed

python/tvm_ffi/_ffi_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
# fmt: on
3030
# tvm-ffi-stubgen(end)
3131

32-
from . import registry
33-
34-
# tvm-ffi-stubgen(begin): global/ffi
32+
# tvm-ffi-stubgen(begin): global/ffi@.registry
3533
# fmt: off
34+
# isort: off
35+
from .registry import init_ffi_api as _INIT
36+
_INIT("ffi", __name__)
37+
# isort: on
3638
if TYPE_CHECKING:
3739
def Array(*args: Any) -> Any: ...
3840
def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
@@ -72,8 +74,6 @@ def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
7274
# fmt: on
7375
# tvm-ffi-stubgen(end)
7476

75-
registry.init_ffi_api("ffi", __name__)
76-
7777

7878
__all__ = [
7979
# tvm-ffi-stubgen(begin): __all__

python/tvm_ffi/_tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class Shape(tuple, PyNativeObject):
5757

5858
_tvm_ffi_cached_object: Any
5959

60+
# tvm-ffi-stubgen(begin): object/ffi.Shape
61+
# fmt: off
62+
# fmt: on
63+
# tvm-ffi-stubgen(end)
64+
6065
def __new__(cls, content: tuple[int, ...]) -> Shape:
6166
if any(not isinstance(x, Integral) for x in content):
6267
raise ValueError("Shape must be a tuple of integers")

python/tvm_ffi/container.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ class Array(core.Object, Sequence[T]):
148148
149149
"""
150150

151+
# tvm-ffi-stubgen(begin): object/ffi.Array
152+
# fmt: off
153+
# fmt: on
154+
# tvm-ffi-stubgen(end)
155+
151156
def __init__(self, input_list: Iterable[T]) -> None:
152157
"""Construct an Array from a Python sequence."""
153158
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
@@ -290,6 +295,11 @@ class Map(core.Object, Mapping[K, V]):
290295
291296
"""
292297

298+
# tvm-ffi-stubgen(begin): object/ffi.Map
299+
# fmt: off
300+
# fmt: on
301+
# tvm-ffi-stubgen(end)
302+
293303
def __init__(self, input_dict: Mapping[K, V]) -> None:
294304
"""Construct a Map from a Python mapping."""
295305
list_kvs: list[Any] = []

python/tvm_ffi/stub/analysis.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
from tvm_ffi._ffi_api import GetRegisteredTypeKeys
2122
from tvm_ffi.registry import list_global_func_names
2223

2324
from . import consts as C
@@ -34,8 +35,27 @@ def collect_global_funcs() -> dict[str, list[FuncInfo]]:
3435
except ValueError:
3536
print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}")
3637
else:
37-
global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name))
38+
try:
39+
global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name))
40+
except:
41+
print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}")
3842
# Ensure stable ordering for deterministic output.
3943
for k in list(global_funcs.keys()):
4044
global_funcs[k].sort(key=lambda x: x.schema.name)
4145
return global_funcs
46+
47+
48+
def collect_type_keys() -> dict[str, list[str]]:
49+
"""Collect registered object type keys from TVM FFI's global registry."""
50+
global_objects: dict[str, list[str]] = {}
51+
for type_key in GetRegisteredTypeKeys():
52+
try:
53+
prefix, _ = type_key.rsplit(".", 1)
54+
except ValueError:
55+
pass
56+
else:
57+
global_objects.setdefault(prefix, []).append(type_key)
58+
# Ensure stable ordering for deterministic output.
59+
for k in list(global_objects.keys()):
60+
global_objects[k].sort()
61+
return global_objects

python/tvm_ffi/stub/cli.py

Lines changed: 160 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,23 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
# tvm-ffi-stubgen(skip-file)
1817
"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``)."""
1918

2019
from __future__ import annotations
2120

2221
import argparse
2322
import ctypes
23+
import importlib
2424
import sys
2525
import traceback
2626
from pathlib import Path
2727
from typing import Callable
2828

2929
from . import codegen as G
3030
from . import consts as C
31-
from .analysis import collect_global_funcs
31+
from .analysis import collect_global_funcs, collect_type_keys
3232
from .file_utils import FileInfo, collect_files
33-
from .utils import Options
33+
from .utils import FuncInfo, Options
3434

3535

3636
def _fn_ty_map(ty_map: dict[str, str], ty_used: set[str]) -> Callable[[str], str]:
@@ -55,71 +55,43 @@ def __main__() -> int:
5555
overview and examples of the block syntax.
5656
"""
5757
opt = _parse_args()
58+
for imp in opt.imports or []:
59+
importlib.import_module(imp)
60+
if opt.init_path:
61+
opt.files.append(opt.init_path)
5862
dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
5963
files: list[FileInfo] = collect_files([Path(f) for f in opt.files])
64+
global_funcs: dict[str, list[FuncInfo]] = collect_global_funcs()
6065

61-
# Stage 1: Process `tvm-ffi-stubgen(ty-map)`
66+
# Stage 1: Collect information
67+
# - type maps: `tvm-ffi-stubgen(ty-map)`
68+
# - defined global functions: `tvm-ffi-stubgen(begin): global/...`
69+
# - defined object types: `tvm-ffi-stubgen(begin): object/...`
6270
ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
63-
64-
def _stage_1(file: FileInfo) -> None:
65-
for code in file.code_blocks:
66-
if code.kind == "ty-map":
67-
try:
68-
lhs, rhs = code.param.split("->")
69-
except ValueError as e:
70-
raise ValueError(
71-
f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`"
72-
) from e
73-
ty_map[lhs.strip()] = rhs.strip()
74-
7571
for file in files:
7672
try:
77-
_stage_1(file)
73+
_stage_1(file, ty_map)
7874
except Exception:
7975
print(
8076
f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}'
8177
)
8278

83-
# Stage 2: Process
79+
# Stage 2. Generate stubs if they are not defined on the file.
80+
if opt.init_path:
81+
_stage_2(
82+
files,
83+
init_path=Path(opt.init_path).resolve(),
84+
global_funcs=global_funcs,
85+
)
86+
87+
# Stage 3: Process
8488
# - `tvm-ffi-stubgen(begin): global/...`
8589
# - `tvm-ffi-stubgen(begin): object/...`
86-
global_funcs = collect_global_funcs()
87-
88-
def _stage_2(file: FileInfo) -> None:
89-
all_defined = set()
90+
for file in files:
9091
if opt.verbose:
9192
print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
92-
ty_used: set[str] = set()
93-
ty_on_file: set[str] = set()
94-
fn_ty_map_fn = _fn_ty_map(ty_map, ty_used)
95-
# Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
96-
for code in file.code_blocks:
97-
if code.kind == "global":
98-
funcs = global_funcs.get(code.param, [])
99-
for func in funcs:
100-
all_defined.add(func.schema.name)
101-
G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt)
102-
# Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
103-
for code in file.code_blocks:
104-
if code.kind == "object":
105-
type_key = code.param
106-
ty_on_file.add(ty_map.get(type_key, type_key))
107-
G.generate_object(code, fn_ty_map_fn, opt)
108-
# Stage 2.3. Add imports for used types.
109-
for code in file.code_blocks:
110-
if code.kind == "import":
111-
G.generate_imports(code, ty_used - ty_on_file, opt)
112-
break # Only one import block per file is supported for now.
113-
# Stage 2.4. Add `__all__` for defined classes and functions.
114-
for code in file.code_blocks:
115-
if code.kind == "__all__":
116-
G.generate_all(code, all_defined | ty_on_file, opt)
117-
break # Only one __all__ block per file is supported for now.
118-
file.update(show_diff=opt.verbose, dry_run=opt.dry_run)
119-
120-
for file in files:
12193
try:
122-
_stage_2(file)
94+
_stage_3(file, opt, ty_map, global_funcs)
12395
except:
12496
print(
12597
f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}'
@@ -128,6 +100,122 @@ def _stage_2(file: FileInfo) -> None:
128100
return 0
129101

130102

103+
def _stage_1(
104+
file: FileInfo,
105+
ty_map: dict[str, str],
106+
) -> None:
107+
for code in file.code_blocks:
108+
if code.kind == "ty-map":
109+
try:
110+
assert isinstance(code.param, str)
111+
lhs, rhs = code.param.split("->")
112+
except ValueError as e:
113+
raise ValueError(
114+
f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`"
115+
) from e
116+
ty_map[lhs.strip()] = rhs.strip()
117+
118+
119+
def _stage_2(
120+
files: list[FileInfo],
121+
init_path: Path,
122+
global_funcs: dict[str, list[FuncInfo]],
123+
) -> None:
124+
def _find_or_insert_file(path: Path) -> FileInfo:
125+
ret: FileInfo | None
126+
if not path.exists():
127+
ret = FileInfo(path=path, lines=(), code_blocks=[])
128+
else:
129+
for file in files:
130+
if path.samefile(file.path):
131+
return file
132+
ret = FileInfo.from_file(file=path)
133+
assert ret is not None, f"Failed to read file: {path}"
134+
files.append(ret)
135+
return ret
136+
137+
# Step 0. Find out functions and classes already defined on files.
138+
defined_func_prefixes: set[str] = { # type: ignore[union-attr]
139+
code.param[0] for file in files for code in file.code_blocks if code.kind == "global"
140+
}
141+
defined_objs: set[str] = { # type: ignore[assignment]
142+
code.param for file in files for code in file.code_blocks if code.kind == "object"
143+
} | C.BUILTIN_TYPE_KEYS
144+
145+
# Step 1. Generate missing `_ffi_api.py` and `__init__.py` under each prefix.
146+
prefixes: dict[str, list[str]] = collect_type_keys()
147+
for prefix in global_funcs:
148+
prefixes.setdefault(prefix, [])
149+
150+
for prefix, obj_names in prefixes.items():
151+
if prefix.startswith("testing") or prefix.startswith("ffi"):
152+
continue
153+
funcs = sorted(
154+
[] if prefix in defined_func_prefixes else global_funcs.get(prefix, []),
155+
key=lambda f: f.schema.name,
156+
)
157+
objs = sorted(set(obj_names) - defined_objs)
158+
if not funcs and not objs:
159+
continue
160+
# Step 1.1. Create target directory if not exists
161+
directory = init_path / prefix.replace(".", "/")
162+
directory.mkdir(parents=True, exist_ok=True)
163+
# Step 1.2. Generate `_ffi_api.py`
164+
target_path = directory / "_ffi_api.py"
165+
target_file = _find_or_insert_file(target_path)
166+
with target_path.open("a", encoding="utf-8") as f:
167+
f.write(G.generate_ffi_api(target_file.code_blocks, prefix, objs))
168+
target_file.reload()
169+
# Step 1.3. Generate `__init__.py`
170+
target_path = directory / "__init__.py"
171+
target_file = _find_or_insert_file(target_path)
172+
with target_path.open("a", encoding="utf-8") as f:
173+
f.write(G.generate_init(target_file.code_blocks, prefix, submodule="_ffi_api"))
174+
target_file.reload()
175+
176+
177+
def _stage_3(
178+
file: FileInfo,
179+
opt: Options,
180+
ty_map: dict[str, str],
181+
global_funcs: dict[str, list[FuncInfo]],
182+
) -> None:
183+
all_defined = set()
184+
ty_used: set[str] = set()
185+
ty_on_file: set[str] = set()
186+
fn_ty_map_fn = _fn_ty_map(ty_map, ty_used)
187+
# Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
188+
for code in file.code_blocks:
189+
if code.kind == "global":
190+
funcs = global_funcs.get(code.param[0], [])
191+
for func in funcs:
192+
all_defined.add(func.schema.name)
193+
G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt)
194+
# Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
195+
for code in file.code_blocks:
196+
if code.kind == "object":
197+
type_key = code.param
198+
assert isinstance(type_key, str)
199+
ty_on_file.add(ty_map.get(type_key, type_key))
200+
G.generate_object(code, fn_ty_map_fn, opt)
201+
# Stage 2.3. Add imports for used types.
202+
for code in file.code_blocks:
203+
if code.kind == "import":
204+
G.generate_imports(code, ty_used - ty_on_file, opt)
205+
break # Only one import block per file is supported for now.
206+
# Stage 2.4. Add `__all__` for defined classes and functions.
207+
for code in file.code_blocks:
208+
if code.kind == "__all__":
209+
G.generate_all(code, all_defined | ty_on_file, opt)
210+
break # Only one __all__ block per file is supported for now.
211+
# Stage 2.5. Process `tvm-ffi-stubgen(begin): export/...`
212+
for code in file.code_blocks:
213+
if code.kind == "export":
214+
G.generate_export(code)
215+
# Finalize: write back to file
216+
file.update(verbose=opt.verbose, dry_run=opt.dry_run)
217+
218+
131219
def _parse_args() -> Options:
132220
class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
133221
pass
@@ -149,16 +237,16 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
149237
" # Preload TVM runtime / extension libraries\n"
150238
" tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n\n"
151239
"Stub block syntax (placed in your source):\n"
152-
" # tvm-ffi-stubgen(begin): global/<registry-prefix>\n"
240+
f" {C.STUB_BEGIN} global/<registry-prefix>\n"
153241
" ... generated function stubs ...\n"
154-
" # tvm-ffi-stubgen(end)\n\n"
155-
" # tvm-ffi-stubgen(begin): object/<type_key>\n"
156-
" # tvm-ffi-stubgen(ty_map): list -> Sequence\n"
157-
" # tvm-ffi-stubgen(ty_map): dict -> Mapping\n"
242+
f" {C.STUB_END}\n\n"
243+
f" {C.STUB_BEGIN} object/<type_key>\n"
244+
f" {C.STUB_TY_MAP}: list -> Sequence\n"
245+
f" {C.STUB_TY_MAP}: dict -> Mapping\n"
158246
" ... generated fields and methods ...\n"
159-
" # tvm-ffi-stubgen(end)\n\n"
247+
f" {C.STUB_END}\n\n"
160248
" # Skip a file entirely\n"
161-
" # tvm-ffi-stubgen(skip-file)\n\n"
249+
f" {C.STUB_SKIP_FILE}\n\n"
162250
"Tips:\n"
163251
" - Only .py/.pyi files are updated; directories are scanned recursively.\n"
164252
" - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n"
@@ -167,6 +255,12 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
167255
" is provided by native extensions.\n"
168256
),
169257
)
258+
parser.add_argument(
259+
"--imports",
260+
nargs="*",
261+
metavar="IMPORTS",
262+
help=("Additional imports to load before generation."),
263+
)
170264
parser.add_argument(
171265
"--dlls",
172266
nargs="*",
@@ -179,13 +273,19 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
179273
),
180274
default=[],
181275
)
276+
parser.add_argument(
277+
"--init-path",
278+
type=str,
279+
default="",
280+
help="If specified, generate stubs under the given package prefix.",
281+
)
182282
parser.add_argument(
183283
"--indent",
184284
type=int,
185285
default=4,
186286
help=(
187287
"Extra spaces added inside each generated block, relative to the "
188-
"indentation of the corresponding '# tvm-ffi-stubgen(begin):' line."
288+
f"indentation of the corresponding '{C.STUB_BEGIN}' line."
189289
),
190290
)
191291
parser.add_argument(

0 commit comments

Comments
 (0)