Skip to content

Commit a166001

Browse files
cmt0facebook-github-bot
authored andcommitted
Add pybindings for attribute tensors
Summary: Add pybindings for grabbing attribute tensor information from method meta Differential Revision: D80631040
1 parent c78b0fd commit a166001

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,21 @@ struct PyMethodMeta final {
661661
}
662662
}
663663

664+
size_t num_attributes() const {
665+
return meta_.num_attributes();
666+
}
667+
668+
std::unique_ptr<PyTensorInfo> attribute_tensor_meta(size_t index) const {
669+
const auto result = meta_.attribute_tensor_meta(index);
670+
THROW_INDEX_IF_ERROR(
671+
result.error(), "Cannot get attribute tensor meta at %zu", index);
672+
if (module_) {
673+
return std::make_unique<PyTensorInfo>(module_, result.get());
674+
} else {
675+
return std::make_unique<PyTensorInfo>(state_, result.get());
676+
}
677+
}
678+
664679
py::str repr() const {
665680
py::list input_meta_strs;
666681
for (size_t i = 0; i < meta_.num_inputs(); ++i) {
@@ -1641,6 +1656,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16411656
.def("name", &PyMethodMeta::name, call_guard)
16421657
.def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
16431658
.def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
1659+
.def("num_attributes", &PyMethodMeta::num_attributes, call_guard)
16441660
.def(
16451661
"input_tensor_meta",
16461662
&PyMethodMeta::input_tensor_meta,
@@ -1651,6 +1667,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16511667
&PyMethodMeta::output_tensor_meta,
16521668
py::arg("index"),
16531669
call_guard)
1670+
.def(
1671+
"attribute_tensor_meta",
1672+
&PyMethodMeta::attribute_tensor_meta,
1673+
py::arg("index"),
1674+
call_guard)
16541675
.def("__repr__", &PyMethodMeta::repr, call_guard);
16551676

16561677
m.def(

extension/pybindings/pybindings.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class MethodMeta:
133133
internal buffers"""
134134
...
135135

136+
def num_attributes(self) -> int:
137+
"""The number of attribute tensors from the method"""
138+
...
139+
136140
def input_tensor_meta(self, index: int) -> TensorInfo:
137141
"""The tensor info for the 'index'th input. Index must be in the interval
138142
[0, num_inputs()). Raises an IndexError if the index is out of bounds"""
@@ -143,6 +147,12 @@ class MethodMeta:
143147
[0, num_outputs()). Raises an IndexError if the index is out of bounds"""
144148
...
145149

150+
151+
def attribute_tensor_meta(self, index: int) -> TensorInfo:
152+
"""The tensor info for the 'index'th attribute. Index must be in the interval
153+
[0, num_attributes()). Raises an IndexError if the index is out of bounds"""
154+
...
155+
146156
def __repr__(self) -> str: ...
147157

148158
@experimental("This API is experimental and subject to change without notice.")

0 commit comments

Comments
 (0)