Commit 06880b9: Add custom class pt2 tutorial
1 parent a96b470

1 file changed: advanced_source/custom_class_pt2.rst (+214, -0)

Supporting Custom C++ Classes in PyTorch 2
===========================================

This tutorial is a follow-on to the
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
introduces the additional steps that are needed to support custom C++ classes
in PyTorch 2.

Concretely, there are a few steps:

1. Implement an ``__obj_flatten__`` method on the C++ custom class to allow us
   to inspect its states and guard against changes. The method should return a
   tuple of (attribute_name, value) tuples
   (``tuple[tuple[str, value] * n]``); see the sketch after this list.
2. Register a Python fake class using ``@torch._library.register_fake_class``:

   a. Implement "fake methods" for each of the class's C++ methods. These
      should have the same schemas as the C++ implementations.
   b. Additionally, implement an ``__obj_unflatten__`` classmethod in the
      Python fake class that tells us how to create a fake class from the
      flattened states returned by ``__obj_flatten__``.
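
For example, for the tensor queue used throughout this tutorial, the flattened
state might look like the following Python value (a hypothetical sketch;
``tensor_a``, ``tensor_b``, and ``init_tensor`` stand in for real tensors):

.. code-block:: python

    # A tuple of (attribute_name, value) pairs: one entry per piece of
    # state we want PyTorch 2 to be able to inspect.
    flattened = (
        ("queue", [tensor_a, tensor_b]),   # the queued tensors, in order
        ("init_tensor_", init_tensor),     # the fallback tensor for pop()
    )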

Here is a breakdown of the diff. Following the guide in
:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`,
we can create a thread-safe tensor queue and build it.

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
     private:
      std::deque<at::Tensor> queue_;
      std::mutex mutex_;
      at::Tensor init_tensor_;
    };

    // The torch binding code
    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          .def("push", &TensorQueue::push)
          .def("pop", &TensorQueue::pop)
          .def("top", &TensorQueue::top)
          .def("size", &TensorQueue::size)
          .def("clone_queue", &TensorQueue::clone_queue)
          .def("get_raw_queue", &TensorQueue::get_raw_queue)
          .def_pickle(
              // __getstate__
              [](const c10::intrusive_ptr<TensorQueue>& self)
                  -> c10::Dict<std::string, at::Tensor> {
                return self->serialize();
              },
              // __setstate__
              [](c10::Dict<std::string, at::Tensor> data)
                  -> c10::intrusive_ptr<TensorQueue> {
                return c10::make_intrusive<TensorQueue>(std::move(data));
              });
    }

**Step 1**: Add an ``__obj_flatten__`` method to the C++ custom class
implementation:

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
      std::tuple<std::tuple<std::string, std::vector<at::Tensor>>,
                 std::tuple<std::string, at::Tensor>>
      __obj_flatten__() {
        return std::tuple(
            std::tuple("queue", this->get_raw_queue()),
            std::tuple("init_tensor_", this->init_tensor_.clone()));
      }
      ...
    };

    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
          ...
    }

**Step 2a**: Register a fake class in Python that implements each method.

.. code-block:: python

    from typing import List

    import torch

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        def __init__(
            self,
            queue: List[torch.Tensor],
            init_tensor_: torch.Tensor,
        ) -> None:
            self.queue = queue
            self.init_tensor_ = init_tensor_

        def push(self, tensor: torch.Tensor) -> None:
            self.queue.append(tensor)

        def pop(self) -> torch.Tensor:
            if len(self.queue) > 0:
                return self.queue.pop(0)
            return self.init_tensor_

        def size(self) -> int:
            return len(self.queue)

**Step 2b**: Implement an ``__obj_unflatten__`` classmethod in Python.

.. code-block:: python

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        ...

        @classmethod
        def __obj_unflatten__(cls, flattened_tq):
            return cls(**dict(flattened_tq))

        ...
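
To see why ``cls(**dict(flattened_tq))`` works, note that the flattened states
are (name, value) pairs, so converting them to a dict yields exactly the
keyword arguments of ``__init__``. A hedged sketch, with ``t0``, ``t1``, and
``init`` as placeholder (fake) tensors:

.. code-block:: python

    flattened_tq = (("queue", [t0, t1]), ("init_tensor_", init))
    dict(flattened_tq)  # {"queue": [t0, t1], "init_tensor_": init}
    # __obj_unflatten__ therefore reconstructs the fake object as
    # FakeTensorQueue(queue=[t0, t1], init_tensor_=init).
    fake_tq = FakeTensorQueue.__obj_unflatten__(flattened_tq)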

That's it! Now we can create a module that uses this object and run it with
``torch.compile`` or ``torch.export``:

.. code-block:: python

    import torch

    torch.ops.load_library("//caffe2/test:test_torchbind_cpp_impl")
    tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))

    class Mod(torch.nn.Module):
        def forward(self, tq, x):
            tq.push(x.sin())
            tq.push(x.cos())
            popped_t = tq.pop()
            assert torch.allclose(popped_t, x.sin())
            return tq, popped_t

    tq, popped_t = torch.compile(Mod(), backend="eager", fullgraph=True)(tq, torch.randn(2, 3))
    assert tq.size() == 1

    exported_program = torch.export.export(Mod(), (tq, torch.randn(2, 3)), strict=False)
    exported_program.module()(tq, torch.randn(2, 3))
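
As a quick sanity check (not part of the diff), we can also run the same
module eagerly on a fresh queue and confirm that the behavior matches:

.. code-block:: python

    # Run Mod without torch.compile; a real TensorQueue pops FIFO,
    # so the popped tensor is x.sin() and one element remains queued.
    eager_tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))
    eager_tq, eager_out = Mod()(eager_tq, torch.randn(2, 3))
    assert eager_tq.size() == 1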

We can also implement custom ops that take custom classes as inputs. For
example, we could register a custom op ``for_each_add_(tq, tensor)``. Note
that ``TORCH_LIBRARY_FRAGMENT`` lets us add these registrations to the
``MyCustomClass`` library that was already defined above.

.. code-block:: cpp

    struct TensorQueue : torch::CustomClassHolder {
      ...
      void for_each_add_(at::Tensor inc) {
        for (auto& t : queue_) {
          t.add_(inc);
        }
      }
      ...
    };

    TORCH_LIBRARY_FRAGMENT(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def("for_each_add_", &TensorQueue::for_each_add_);

      m.def(
          "for_each_add_(__torch__.torch.classes.MyCustomClass.TensorQueue foo, Tensor inc) -> ()");
    }

    void for_each_add_(c10::intrusive_ptr<TensorQueue> tq, at::Tensor inc) {
      tq->for_each_add_(inc);
    }

    TORCH_LIBRARY_IMPL(MyCustomClass, CPU, m) {
      m.impl("for_each_add_", for_each_add_);
    }

Since the fake class is implemented in Python, the fake implementation of the
custom op must also be registered in Python:

.. code-block:: python

    @torch.library.register_fake("MyCustomClass::for_each_add_")
    def fake_for_each_add_(tq, inc):
        tq.for_each_add_(inc)

After re-compilation, we can export the custom op with:

.. code-block:: python

    class ForEachAdd(torch.nn.Module):
        def forward(self, tq: torch.ScriptObject, a: torch.Tensor) -> torch.ScriptObject:
            torch.ops.MyCustomClass.for_each_add_(tq, a)
            return tq

    mod = ForEachAdd()
    tq = empty_tensor_queue()
    qlen = 10
    for i in range(qlen):
        tq.push(torch.zeros(1))

    ep = torch.export.export(mod, (tq, torch.ones(1)), strict=False)
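
``empty_tensor_queue`` is not defined in the diff. As a minimal sketch,
assuming the same constructor used earlier in this tutorial, such a helper
could look like:

.. code-block:: python

    def empty_tensor_queue() -> torch.ScriptObject:
        # Hypothetical helper: build an empty queue whose pop() fallback
        # is a -1-filled tensor, mirroring the earlier constructor call.
        return torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))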

Why do we need to make a Fake Class?
------------------------------------

Tracing with a real custom object has several major downsides:

1. Operators on real objects can be time-consuming, e.g. the custom object
   might be reading from the network or loading data from the disk.
2. We don't want to mutate the real custom object or create side effects in
   the environment while tracing.
3. It cannot support dynamic shapes.

However, it may be difficult for users to write a fake class: the original
class might use a third-party library that determines the output shapes of
its methods, or might be complicated and written by someone else. Besides,
users may not care about the limitations listed above. In that case, please
reach out to us!
