1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17- # tvm-ffi-stubgen(skip-file)
1817"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``)."""
1918
2019from __future__ import annotations
2120
2221import argparse
2322import ctypes
23+ import importlib
2424import sys
2525import traceback
2626from pathlib import Path
2727from typing import Callable
2828
2929from . import codegen as G
3030from . import consts as C
31- from .analysis import collect_global_funcs
31+ from .analysis import collect_global_funcs , collect_type_keys
3232from .file_utils import FileInfo , collect_files
33- from .utils import Options
33+ from .utils import FuncInfo , Options
3434
3535
3636def _fn_ty_map (ty_map : dict [str , str ], ty_used : set [str ]) -> Callable [[str ], str ]:
@@ -55,71 +55,43 @@ def __main__() -> int:
5555 overview and examples of the block syntax.
5656 """
5757 opt = _parse_args ()
58+ for imp in opt .imports or []:
59+ importlib .import_module (imp )
60+ if opt .init_path :
61+ opt .files .append (opt .init_path )
5862 dlls = [ctypes .CDLL (lib ) for lib in opt .dlls ]
5963 files : list [FileInfo ] = collect_files ([Path (f ) for f in opt .files ])
64+ global_funcs : dict [str , list [FuncInfo ]] = collect_global_funcs ()
6065
61- # Stage 1: Process `tvm-ffi-stubgen(ty-map)`
66+ # Stage 1: Collect information
67+ # - type maps: `tvm-ffi-stubgen(ty-map)`
68+ # - defined global functions: `tvm-ffi-stubgen(begin): global/...`
69+ # - defined object types: `tvm-ffi-stubgen(begin): object/...`
6270 ty_map : dict [str , str ] = C .TY_MAP_DEFAULTS .copy ()
63-
64- def _stage_1 (file : FileInfo ) -> None :
65- for code in file .code_blocks :
66- if code .kind == "ty-map" :
67- try :
68- lhs , rhs = code .param .split ("->" )
69- except ValueError as e :
70- raise ValueError (
71- f"Invalid ty_map format at line { code .lineno_start } . Example: `A.B -> C.D`"
72- ) from e
73- ty_map [lhs .strip ()] = rhs .strip ()
74-
7571 for file in files :
7672 try :
77- _stage_1 (file )
73+ _stage_1 (file , ty_map )
7874 except Exception :
7975 print (
8076 f'{ C .TERM_RED } [Failed] File "{ file .path } ": { traceback .format_exc ()} { C .TERM_RESET } '
8177 )
8278
83- # Stage 2: Process
79+ # Stage 2. Generate stubs if they are not defined on the file.
80+ if opt .init_path :
81+ _stage_2 (
82+ files ,
83+ init_path = Path (opt .init_path ).resolve (),
84+ global_funcs = global_funcs ,
85+ )
86+
87+ # Stage 3: Process
8488 # - `tvm-ffi-stubgen(begin): global/...`
8589 # - `tvm-ffi-stubgen(begin): object/...`
86- global_funcs = collect_global_funcs ()
87-
88- def _stage_2 (file : FileInfo ) -> None :
89- all_defined = set ()
90+ for file in files :
9091 if opt .verbose :
9192 print (f"{ C .TERM_CYAN } [File] { file .path } { C .TERM_RESET } " )
92- ty_used : set [str ] = set ()
93- ty_on_file : set [str ] = set ()
94- fn_ty_map_fn = _fn_ty_map (ty_map , ty_used )
95- # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
96- for code in file .code_blocks :
97- if code .kind == "global" :
98- funcs = global_funcs .get (code .param , [])
99- for func in funcs :
100- all_defined .add (func .schema .name )
101- G .generate_global_funcs (code , funcs , fn_ty_map_fn , opt )
102- # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
103- for code in file .code_blocks :
104- if code .kind == "object" :
105- type_key = code .param
106- ty_on_file .add (ty_map .get (type_key , type_key ))
107- G .generate_object (code , fn_ty_map_fn , opt )
108- # Stage 2.3. Add imports for used types.
109- for code in file .code_blocks :
110- if code .kind == "import" :
111- G .generate_imports (code , ty_used - ty_on_file , opt )
112- break # Only one import block per file is supported for now.
113- # Stage 2.4. Add `__all__` for defined classes and functions.
114- for code in file .code_blocks :
115- if code .kind == "__all__" :
116- G .generate_all (code , all_defined | ty_on_file , opt )
117- break # Only one __all__ block per file is supported for now.
118- file .update (show_diff = opt .verbose , dry_run = opt .dry_run )
119-
120- for file in files :
12193 try :
122- _stage_2 (file )
94+ _stage_3 (file , opt , ty_map , global_funcs )
12395 except :
12496 print (
12597 f'{ C .TERM_RED } [Failed] File "{ file .path } ": { traceback .format_exc ()} { C .TERM_RESET } '
@@ -128,6 +100,122 @@ def _stage_2(file: FileInfo) -> None:
128100 return 0
129101
130102
103+ def _stage_1 (
104+ file : FileInfo ,
105+ ty_map : dict [str , str ],
106+ ) -> None :
107+ for code in file .code_blocks :
108+ if code .kind == "ty-map" :
109+ try :
110+ assert isinstance (code .param , str )
111+ lhs , rhs = code .param .split ("->" )
112+ except ValueError as e :
113+ raise ValueError (
114+ f"Invalid ty_map format at line { code .lineno_start } . Example: `A.B -> C.D`"
115+ ) from e
116+ ty_map [lhs .strip ()] = rhs .strip ()
117+
118+
119+ def _stage_2 (
120+ files : list [FileInfo ],
121+ init_path : Path ,
122+ global_funcs : dict [str , list [FuncInfo ]],
123+ ) -> None :
124+ def _find_or_insert_file (path : Path ) -> FileInfo :
125+ ret : FileInfo | None
126+ if not path .exists ():
127+ ret = FileInfo (path = path , lines = (), code_blocks = [])
128+ else :
129+ for file in files :
130+ if path .samefile (file .path ):
131+ return file
132+ ret = FileInfo .from_file (file = path )
133+ assert ret is not None , f"Failed to read file: { path } "
134+ files .append (ret )
135+ return ret
136+
137+ # Step 0. Find out functions and classes already defined on files.
138+ defined_func_prefixes : set [str ] = { # type: ignore[union-attr]
139+ code .param [0 ] for file in files for code in file .code_blocks if code .kind == "global"
140+ }
141+ defined_objs : set [str ] = { # type: ignore[assignment]
142+ code .param for file in files for code in file .code_blocks if code .kind == "object"
143+ } | C .BUILTIN_TYPE_KEYS
144+
145+ # Step 1. Generate missing `_ffi_api.py` and `__init__.py` under each prefix.
146+ prefixes : dict [str , list [str ]] = collect_type_keys ()
147+ for prefix in global_funcs :
148+ prefixes .setdefault (prefix , [])
149+
150+ for prefix , obj_names in prefixes .items ():
151+ if prefix .startswith ("testing" ) or prefix .startswith ("ffi" ):
152+ continue
153+ funcs = sorted (
154+ [] if prefix in defined_func_prefixes else global_funcs .get (prefix , []),
155+ key = lambda f : f .schema .name ,
156+ )
157+ objs = sorted (set (obj_names ) - defined_objs )
158+ if not funcs and not objs :
159+ continue
160+ # Step 1.1. Create target directory if not exists
161+ directory = init_path / prefix .replace ("." , "/" )
162+ directory .mkdir (parents = True , exist_ok = True )
163+ # Step 1.2. Generate `_ffi_api.py`
164+ target_path = directory / "_ffi_api.py"
165+ target_file = _find_or_insert_file (target_path )
166+ with target_path .open ("a" , encoding = "utf-8" ) as f :
167+ f .write (G .generate_ffi_api (target_file .code_blocks , prefix , objs ))
168+ target_file .reload ()
169+ # Step 1.3. Generate `__init__.py`
170+ target_path = directory / "__init__.py"
171+ target_file = _find_or_insert_file (target_path )
172+ with target_path .open ("a" , encoding = "utf-8" ) as f :
173+ f .write (G .generate_init (target_file .code_blocks , prefix , submodule = "_ffi_api" ))
174+ target_file .reload ()
175+
176+
177+ def _stage_3 (
178+ file : FileInfo ,
179+ opt : Options ,
180+ ty_map : dict [str , str ],
181+ global_funcs : dict [str , list [FuncInfo ]],
182+ ) -> None :
183+ all_defined = set ()
184+ ty_used : set [str ] = set ()
185+ ty_on_file : set [str ] = set ()
186+ fn_ty_map_fn = _fn_ty_map (ty_map , ty_used )
187+ # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
188+ for code in file .code_blocks :
189+ if code .kind == "global" :
190+ funcs = global_funcs .get (code .param [0 ], [])
191+ for func in funcs :
192+ all_defined .add (func .schema .name )
193+ G .generate_global_funcs (code , funcs , fn_ty_map_fn , opt )
194+ # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
195+ for code in file .code_blocks :
196+ if code .kind == "object" :
197+ type_key = code .param
198+ assert isinstance (type_key , str )
199+ ty_on_file .add (ty_map .get (type_key , type_key ))
200+ G .generate_object (code , fn_ty_map_fn , opt )
201+ # Stage 2.3. Add imports for used types.
202+ for code in file .code_blocks :
203+ if code .kind == "import" :
204+ G .generate_imports (code , ty_used - ty_on_file , opt )
205+ break # Only one import block per file is supported for now.
206+ # Stage 2.4. Add `__all__` for defined classes and functions.
207+ for code in file .code_blocks :
208+ if code .kind == "__all__" :
209+ G .generate_all (code , all_defined | ty_on_file , opt )
210+ break # Only one __all__ block per file is supported for now.
211+ # Stage 2.5. Process `tvm-ffi-stubgen(begin): export/...`
212+ for code in file .code_blocks :
213+ if code .kind == "export" :
214+ G .generate_export (code )
215+ # Finalize: write back to file
216+ file .update (verbose = opt .verbose , dry_run = opt .dry_run )
217+
218+
131219def _parse_args () -> Options :
132220 class HelpFormatter (argparse .ArgumentDefaultsHelpFormatter , argparse .RawTextHelpFormatter ):
133221 pass
@@ -149,16 +237,16 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
149237 " # Preload TVM runtime / extension libraries\n "
150238 " tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n \n "
151239 "Stub block syntax (placed in your source):\n "
152- " # tvm-ffi-stubgen(begin): global/<registry-prefix>\n "
240+ f " { C . STUB_BEGIN } global/<registry-prefix>\n "
153241 " ... generated function stubs ...\n "
154- " # tvm-ffi-stubgen(end) \n \n "
155- " # tvm-ffi-stubgen(begin): object/<type_key>\n "
156- " # tvm-ffi-stubgen(ty_map) : list -> Sequence\n "
157- " # tvm-ffi-stubgen(ty_map) : dict -> Mapping\n "
242+ f " { C . STUB_END } \n \n "
243+ f " { C . STUB_BEGIN } object/<type_key>\n "
244+ f " { C . STUB_TY_MAP } : list -> Sequence\n "
245+ f " { C . STUB_TY_MAP } : dict -> Mapping\n "
158246 " ... generated fields and methods ...\n "
159- " # tvm-ffi-stubgen(end) \n \n "
247+ f " { C . STUB_END } \n \n "
160248 " # Skip a file entirely\n "
161- " # tvm-ffi-stubgen(skip-file) \n \n "
249+ f " { C . STUB_SKIP_FILE } \n \n "
162250 "Tips:\n "
163251 " - Only .py/.pyi files are updated; directories are scanned recursively.\n "
164252 " - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n "
@@ -167,6 +255,12 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
167255 " is provided by native extensions.\n "
168256 ),
169257 )
258+ parser .add_argument (
259+ "--imports" ,
260+ nargs = "*" ,
261+ metavar = "IMPORTS" ,
262+ help = ("Additional imports to load before generation." ),
263+ )
170264 parser .add_argument (
171265 "--dlls" ,
172266 nargs = "*" ,
@@ -179,13 +273,19 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelp
179273 ),
180274 default = [],
181275 )
276+ parser .add_argument (
277+ "--init-path" ,
278+ type = str ,
279+ default = "" ,
280+ help = "If specified, generate stubs under the given package prefix." ,
281+ )
182282 parser .add_argument (
183283 "--indent" ,
184284 type = int ,
185285 default = 4 ,
186286 help = (
187287 "Extra spaces added inside each generated block, relative to the "
188- "indentation of the corresponding '# tvm-ffi-stubgen(begin): ' line."
288+ f "indentation of the corresponding '{ C . STUB_BEGIN } ' line."
189289 ),
190290 )
191291 parser .add_argument (
0 commit comments