Skip to content

Commit b17a1cb

Browse files
cmt0facebook-github-bot
authored andcommitted
Add pybindings for attribute tensors (#13579)
Summary: Add pybindings for grabbing attribute tensor information from method meta Reviewed By: JacobSzwejbka Differential Revision: D80631040
1 parent 3db27cd commit b17a1cb

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,21 @@ struct PyMethodMeta final {
661661
}
662662
}
663663

664+
size_t num_attributes() const {
665+
return meta_.num_attributes();
666+
}
667+
668+
std::unique_ptr<PyTensorInfo> attribute_tensor_meta(size_t index) const {
669+
const auto result = meta_.attribute_tensor_meta(index);
670+
THROW_INDEX_IF_ERROR(
671+
result.error(), "Cannot get attribute tensor meta at %zu", index);
672+
if (module_) {
673+
return std::make_unique<PyTensorInfo>(module_, result.get());
674+
} else {
675+
return std::make_unique<PyTensorInfo>(state_, result.get());
676+
}
677+
}
678+
664679
py::str repr() const {
665680
py::list input_meta_strs;
666681
for (size_t i = 0; i < meta_.num_inputs(); ++i) {
@@ -1641,6 +1656,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16411656
.def("name", &PyMethodMeta::name, call_guard)
16421657
.def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
16431658
.def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
1659+
.def("num_attributes", &PyMethodMeta::num_attributes, call_guard)
16441660
.def(
16451661
"input_tensor_meta",
16461662
&PyMethodMeta::input_tensor_meta,
@@ -1651,6 +1667,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16511667
&PyMethodMeta::output_tensor_meta,
16521668
py::arg("index"),
16531669
call_guard)
1670+
.def(
1671+
"attribute_tensor_meta",
1672+
&PyMethodMeta::attribute_tensor_meta,
1673+
py::arg("index"),
1674+
call_guard)
16541675
.def("__repr__", &PyMethodMeta::repr, call_guard);
16551676

16561677
m.def(

extension/pybindings/pybindings.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class MethodMeta:
133133
internal buffers"""
134134
...
135135

136+
def num_attributes(self) -> int:
137+
"""The number of attribute tensors from the method"""
138+
...
139+
136140
def input_tensor_meta(self, index: int) -> TensorInfo:
137141
"""The tensor info for the 'index'th input. Index must be in the interval
138142
[0, num_inputs()). Raises an IndexError if the index is out of bounds"""
@@ -143,6 +147,11 @@ class MethodMeta:
143147
[0, num_outputs()). Raises an IndexError if the index is out of bounds"""
144148
...
145149

150+
def attribute_tensor_meta(self, index: int) -> TensorInfo:
151+
"""The tensor info for the 'index'th attribute. Index must be in the interval
152+
[0, num_attributes()). Raises an IndexError if the index is out of bounds"""
153+
...
154+
146155
def __repr__(self) -> str: ...
147156

148157
@experimental("This API is experimental and subject to change without notice.")

extension/pybindings/test/test_pybindings.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,19 +518,32 @@ def test_method_attribute(self):
518518
)
519519

520520
def test_program_method_meta(self) -> None:
521-
exported_program, inputs = create_program(ModuleAdd())
521+
eager_module = ModuleAddWithAttributes()
522+
inputs = eager_module.get_inputs()
523+
524+
exported_program = export(eager_module, inputs, strict=True)
525+
exec_prog = to_edge(exported_program).to_executorch(
526+
config=ExecutorchBackendConfig(
527+
emit_mutable_buffer_names=True,
528+
)
529+
)
530+
531+
exec_prog.dump_executorch_program(verbose=True)
532+
533+
executorch_program = self.load_prog_fn(exec_prog.buffer)
522534

523-
executorch_program = self.load_prog_fn(exported_program.buffer)
524535
meta = executorch_program.method_meta("forward")
525536

526537
del executorch_program
527538
self.assertEqual(meta.name(), "forward")
528539
self.assertEqual(meta.num_inputs(), 2)
529540
self.assertEqual(meta.num_outputs(), 1)
541+
self.assertEqual(meta.num_attributes(), 1)
530542

531543
tensor_info = (
532544
"TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
533545
)
546+
534547
float_dtype = 6
535548
self.assertEqual(
536549
str(meta),
@@ -541,10 +554,14 @@ def test_program_method_meta(self) -> None:
541554

542555
input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
543556
output_tensor = meta.output_tensor_meta(0)
557+
attribute_tensor = meta.attribute_tensor_meta(0)
544558

545559
with self.assertRaises(IndexError):
546560
meta.input_tensor_meta(2)
547561

562+
with self.assertRaises(IndexError):
563+
meta.attribute_tensor_meta(1)
564+
548565
del meta
549566
self.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
550567
self.assertEqual([t.dtype() for t in input_tensors], [float_dtype, float_dtype])
@@ -558,6 +575,12 @@ def test_program_method_meta(self) -> None:
558575
self.assertEqual(output_tensor.nbytes(), 16)
559576
self.assertEqual(str(output_tensor), tensor_info)
560577

578+
self.assertEqual(attribute_tensor.sizes(), (2, 2))
579+
self.assertEqual(attribute_tensor.dtype(), float_dtype)
580+
self.assertEqual(attribute_tensor.is_memory_planned(), True)
581+
self.assertEqual(attribute_tensor.nbytes(), 16)
582+
self.assertEqual(str(attribute_tensor), tensor_info)
583+
561584
def test_method_method_meta(self) -> None:
562585
exported_program, inputs = create_program(ModuleAdd())
563586

0 commit comments

Comments
 (0)