Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,21 @@ struct PyMethodMeta final {
}
}

size_t num_attributes() const {
return meta_.num_attributes();
}

std::unique_ptr<PyTensorInfo> attribute_tensor_meta(size_t index) const {
const auto result = meta_.attribute_tensor_meta(index);
THROW_INDEX_IF_ERROR(
result.error(), "Cannot get attribute tensor meta at %zu", index);
if (module_) {
return std::make_unique<PyTensorInfo>(module_, result.get());
} else {
return std::make_unique<PyTensorInfo>(state_, result.get());
}
}

py::str repr() const {
py::list input_meta_strs;
for (size_t i = 0; i < meta_.num_inputs(); ++i) {
Expand Down Expand Up @@ -1641,6 +1656,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
.def("name", &PyMethodMeta::name, call_guard)
.def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
.def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
.def("num_attributes", &PyMethodMeta::num_attributes, call_guard)
.def(
"input_tensor_meta",
&PyMethodMeta::input_tensor_meta,
Expand All @@ -1651,6 +1667,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
&PyMethodMeta::output_tensor_meta,
py::arg("index"),
call_guard)
.def(
"attribute_tensor_meta",
&PyMethodMeta::attribute_tensor_meta,
py::arg("index"),
call_guard)
.def("__repr__", &PyMethodMeta::repr, call_guard);

m.def(
Expand Down
9 changes: 9 additions & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class MethodMeta:
internal buffers"""
...

def num_attributes(self) -> int:
"""The number of attribute tensors from the method"""
...

def input_tensor_meta(self, index: int) -> TensorInfo:
"""The tensor info for the 'index'th input. Index must be in the interval
[0, num_inputs()). Raises an IndexError if the index is out of bounds"""
Expand All @@ -143,6 +147,11 @@ class MethodMeta:
[0, num_outputs()). Raises an IndexError if the index is out of bounds"""
...

def attribute_tensor_meta(self, index: int) -> TensorInfo:
"""The tensor info for the 'index'th attribute. Index must be in the interval
[0, num_attributes()). Raises an IndexError if the index is out of bounds"""
...

def __repr__(self) -> str: ...

@experimental("This API is experimental and subject to change without notice.")
Expand Down
27 changes: 25 additions & 2 deletions extension/pybindings/test/test_pybindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,19 +518,32 @@ def test_method_attribute(self):
)

def test_program_method_meta(self) -> None:
exported_program, inputs = create_program(ModuleAdd())
eager_module = ModuleAddWithAttributes()
inputs = eager_module.get_inputs()

exported_program = export(eager_module, inputs, strict=True)
exec_prog = to_edge(exported_program).to_executorch(
config=ExecutorchBackendConfig(
emit_mutable_buffer_names=True,
)
)

exec_prog.dump_executorch_program(verbose=True)

executorch_program = self.load_prog_fn(exec_prog.buffer)

executorch_program = self.load_prog_fn(exported_program.buffer)
meta = executorch_program.method_meta("forward")

del executorch_program
self.assertEqual(meta.name(), "forward")
self.assertEqual(meta.num_inputs(), 2)
self.assertEqual(meta.num_outputs(), 1)
self.assertEqual(meta.num_attributes(), 1)

tensor_info = (
"TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
)

float_dtype = 6
self.assertEqual(
str(meta),
Expand All @@ -541,10 +554,14 @@ def test_program_method_meta(self) -> None:

input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
output_tensor = meta.output_tensor_meta(0)
attribute_tensor = meta.attribute_tensor_meta(0)

with self.assertRaises(IndexError):
meta.input_tensor_meta(2)

with self.assertRaises(IndexError):
meta.attribute_tensor_meta(1)

del meta
self.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
self.assertEqual([t.dtype() for t in input_tensors], [float_dtype, float_dtype])
Expand All @@ -558,6 +575,12 @@ def test_program_method_meta(self) -> None:
self.assertEqual(output_tensor.nbytes(), 16)
self.assertEqual(str(output_tensor), tensor_info)

self.assertEqual(attribute_tensor.sizes(), (2, 2))
self.assertEqual(attribute_tensor.dtype(), float_dtype)
self.assertEqual(attribute_tensor.is_memory_planned(), True)
self.assertEqual(attribute_tensor.nbytes(), 16)
self.assertEqual(str(attribute_tensor), tensor_info)

def test_method_method_meta(self) -> None:
exported_program, inputs = create_program(ModuleAdd())

Expand Down
Loading