Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/tvm_ffi/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
# fmt: on
# tvm-ffi-stubgen(end)

from . import registry

# tvm-ffi-stubgen(begin): global/ffi
# tvm-ffi-stubgen(begin): global/ffi@.registry
# fmt: off
# isort: off
from .registry import init_ffi_api as _INIT
_INIT("ffi", __name__)
# isort: on
if TYPE_CHECKING:
def Array(*args: Any) -> Any: ...
def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
Expand Down Expand Up @@ -72,8 +74,6 @@ def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
# fmt: on
# tvm-ffi-stubgen(end)

registry.init_ffi_api("ffi", __name__)


__all__ = [
# tvm-ffi-stubgen(begin): __all__
Expand Down
5 changes: 5 additions & 0 deletions python/tvm_ffi/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class Shape(tuple, PyNativeObject):

_tvm_ffi_cached_object: Any

# tvm-ffi-stubgen(begin): object/ffi.Shape
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)

def __new__(cls, content: tuple[int, ...]) -> Shape:
if any(not isinstance(x, Integral) for x in content):
raise ValueError("Shape must be a tuple of integers")
Expand Down
10 changes: 10 additions & 0 deletions python/tvm_ffi/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class Array(core.Object, Sequence[T]):

"""

# tvm-ffi-stubgen(begin): object/ffi.Array
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, input_list: Iterable[T]) -> None:
"""Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
Expand Down Expand Up @@ -290,6 +295,11 @@ class Map(core.Object, Mapping[K, V]):

"""

# tvm-ffi-stubgen(begin): object/ffi.Map
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, input_dict: Mapping[K, V]) -> None:
"""Construct a Map from a Python mapping."""
list_kvs: list[Any] = []
Expand Down
22 changes: 21 additions & 1 deletion python/tvm_ffi/stub/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

from tvm_ffi._ffi_api import GetRegisteredTypeKeys
from tvm_ffi.registry import list_global_func_names

from . import consts as C
Expand All @@ -34,8 +35,27 @@ def collect_global_funcs() -> dict[str, list[FuncInfo]]:
except ValueError:
print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}")
else:
global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name))
try:
global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name))
except Exception:
print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}")
# Ensure stable ordering for deterministic output.
for k in list(global_funcs.keys()):
global_funcs[k].sort(key=lambda x: x.schema.name)
return global_funcs


def collect_type_keys() -> dict[str, list[str]]:
"""Collect registered object type keys from TVM FFI's global registry."""
global_objects: dict[str, list[str]] = {}
for type_key in GetRegisteredTypeKeys():
try:
prefix, _ = type_key.rsplit(".", 1)
except ValueError:
pass
else:
global_objects.setdefault(prefix, []).append(type_key)
# Ensure stable ordering for deterministic output.
for k in list(global_objects.keys()):
global_objects[k].sort()
return global_objects
Loading