Supporting Custom C++ Classes in PyTorch 2
==========================================

This tutorial is a follow-on to the
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
introduces the additional steps needed to support custom C++ classes in
PyTorch 2.

Concretely, there are a few steps:

1. Add an ``__obj_flatten__`` method to the C++ custom class
   implementation to allow us to inspect its state and guard against changes.
   The method should return a tuple of (attribute_name, value) tuples
   (``tuple[tuple[str, value], ...]``).
2. Register a Python fake class using ``@torch._library.register_fake_class``.

   a. Implement "fake methods" for each of the class's C++ methods; each
      should have the same schema as its C++ counterpart.
   b. Additionally, implement an ``__obj_unflatten__`` classmethod in the
      Python fake class that tells us how to create a fake class instance
      from the flattened state returned by ``__obj_flatten__``.
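The flatten/unflatten contract in the steps above can be sketched in plain Python. This is a hypothetical stand-in that uses ordinary values instead of tensors; the ``FakeTensorQueue`` name anticipates the class defined later in this tutorial:

.. code-block:: python

    # Hypothetical stand-in illustrating the flatten/unflatten contract.
    class FakeTensorQueue:
        def __init__(self, queue, init_tensor_):
            self.queue = queue
            self.init_tensor_ = init_tensor_

        @classmethod
        def __obj_unflatten__(cls, flattened):
            # `flattened` is a tuple of (attribute_name, value) tuples,
            # as produced by the C++ __obj_flatten__ method.
            return cls(**dict(flattened))

    flattened = (("queue", [1, 2]), ("init_tensor_", -1))
    fq = FakeTensorQueue.__obj_unflatten__(flattened)
    assert fq.queue == [1, 2]
    assert fq.init_tensor_ == -1

Because the flattened state is just name/value pairs, unflattening reduces to passing them as keyword arguments to the constructor.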
Here is a breakdown of the steps in detail. Following the guide in
:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`,
we can create a thread-safe tensor queue and build it:

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
     private:
      std::deque<at::Tensor> queue_;
      std::mutex mutex_;
      at::Tensor init_tensor_;
    };

    // The torch binding code
    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          .def("push", &TensorQueue::push)
          .def("pop", &TensorQueue::pop)
          .def("top", &TensorQueue::top)
          .def("size", &TensorQueue::size)
          .def("clone_queue", &TensorQueue::clone_queue)
          .def("get_raw_queue", &TensorQueue::get_raw_queue)
          .def_pickle(
              // __getstate__
              [](const c10::intrusive_ptr<TensorQueue>& self)
                  -> c10::Dict<std::string, at::Tensor> {
                return self->serialize();
              },
              // __setstate__
              [](c10::Dict<std::string, at::Tensor> data)
                  -> c10::intrusive_ptr<TensorQueue> {
                return c10::make_intrusive<TensorQueue>(std::move(data));
              });
    }
**Step 1**: Add an ``__obj_flatten__`` method to the C++ custom class implementation:

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
      std::tuple<std::tuple<std::string, std::vector<at::Tensor>>,
                 std::tuple<std::string, at::Tensor>>
      __obj_flatten__() {
        return std::tuple(
            std::tuple("queue", this->get_raw_queue()),
            std::tuple("init_tensor_", this->init_tensor_.clone()));
      }
      ...
    };

    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
          ...
    }
**Step 2a**: Register a fake class in Python that implements each method:

.. code-block:: python

    from typing import List

    import torch

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        def __init__(
            self,
            queue: List[torch.Tensor],
            init_tensor_: torch.Tensor,
        ) -> None:
            self.queue = queue
            self.init_tensor_ = init_tensor_

        def push(self, tensor: torch.Tensor) -> None:
            self.queue.append(tensor)

        def pop(self) -> torch.Tensor:
            if len(self.queue) > 0:
                return self.queue.pop(0)
            return self.init_tensor_

        def size(self) -> int:
            return len(self.queue)
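Because the fake class is ordinary Python, its method semantics can be checked in isolation, without the compiled extension. A minimal hypothetical sketch using plain integers in place of tensors:

.. code-block:: python

    # Hypothetical stand-in for the fake class above, using plain ints in
    # place of tensors so the FIFO semantics can be checked directly.
    class FakeTensorQueue:
        def __init__(self, queue, init_tensor_):
            self.queue = queue
            self.init_tensor_ = init_tensor_

        def push(self, tensor):
            self.queue.append(tensor)

        def pop(self):
            # Matches the fake method above: an empty queue yields init_tensor_.
            if len(self.queue) > 0:
                return self.queue.pop(0)
            return self.init_tensor_

        def size(self):
            return len(self.queue)

    fq = FakeTensorQueue([], init_tensor_=-1)
    fq.push(10)
    fq.push(20)
    assert fq.pop() == 10                        # FIFO order
    assert fq.size() == 1
    assert FakeTensorQueue([], -1).pop() == -1   # empty queue falls back to init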
**Step 2b**: Implement an ``__obj_unflatten__`` classmethod in Python:

.. code-block:: python

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        ...

        @classmethod
        def __obj_unflatten__(cls, flattened_tq):
            return cls(**dict(flattened_tq))

        ...
That's it! Now we can create a module that uses this object and run it with ``torch.compile`` or ``torch.export``:

.. code-block:: python

    import torch

    torch.ops.load_library("//caffe2/test:test_torchbind_cpp_impl")
    tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))

    class Mod(torch.nn.Module):
        def forward(self, tq, x):
            tq.push(x.sin())
            tq.push(x.cos())
            popped_t = tq.pop()
            assert torch.allclose(popped_t, x.sin())
            return tq, popped_t

    tq, popped_t = torch.compile(Mod(), backend="eager", fullgraph=True)(tq, torch.randn(2, 3))
    assert tq.size() == 1

    exported_program = torch.export.export(Mod(), (tq, torch.randn(2, 3)), strict=False)
    exported_program.module()(tq, torch.randn(2, 3))
We can also implement custom ops that take custom classes as inputs. For
example, we could register a custom op ``for_each_add_(tq, tensor)``:

.. code-block:: cpp

    struct TensorQueue : torch::CustomClassHolder {
      ...
      void for_each_add_(at::Tensor inc) {
        for (auto& t : queue_) {
          t.add_(inc);
        }
      }
      ...
    };

    TORCH_LIBRARY_FRAGMENT(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def("for_each_add_", &TensorQueue::for_each_add_);

      m.def(
          "for_each_add_(__torch__.torch.classes.MyCustomClass.TensorQueue foo, Tensor inc) -> ()");
    }

    void for_each_add_(c10::intrusive_ptr<TensorQueue> tq, at::Tensor inc) {
      tq->for_each_add_(inc);
    }

    TORCH_LIBRARY_IMPL(MyCustomClass, CPU, m) {
      m.impl("for_each_add_", for_each_add_);
    }
Since the fake class is implemented in Python, the fake implementation of the
custom op must also be registered in Python (note that this assumes the fake
class also defines a ``for_each_add_`` method):

.. code-block:: python

    @torch.library.register_fake("MyCustomClass::for_each_add_")
    def fake_for_each_add_(tq, inc):
        tq.for_each_add_(inc)
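A hypothetical sketch of what that ``for_each_add_`` fake method could look like, using plain numbers in place of tensors (the real fake class would update each queued fake tensor with ``t.add_(inc)``):

.. code-block:: python

    # Hypothetical stand-in: the fake queue's for_each_add_ mirrors the C++
    # loop, adding `inc` to every queued element.
    class FakeTensorQueue:
        def __init__(self, queue):
            self.queue = queue

        def for_each_add_(self, inc):
            # Plain numbers stand in for tensors; real code would mutate
            # each tensor in place via t.add_(inc).
            self.queue = [t + inc for t in self.queue]

    fq = FakeTensorQueue([0, 0, 0])
    fq.for_each_add_(1)
    assert fq.queue == [1, 1, 1]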
After re-compilation, we can export the custom op with:

.. code-block:: python

    class ForEachAdd(torch.nn.Module):
        def forward(self, tq: torch.ScriptObject, a: torch.Tensor) -> torch.ScriptObject:
            torch.ops.MyCustomClass.for_each_add_(tq, a)
            return tq

    mod = ForEachAdd()
    # empty_tensor_queue() is a helper that constructs a TensorQueue with no
    # queued tensors, e.g. via torch.classes.MyCustomClass.TensorQueue(...).
    tq = empty_tensor_queue()
    qlen = 10
    for i in range(qlen):
        tq.push(torch.zeros(1))

    ep = torch.export.export(mod, (tq, torch.ones(1)), strict=False)
Why do we need to make a Fake Class?
------------------------------------

Tracing with a real custom object has several major downsides:

1. Operators on real objects can be time consuming, e.g. the custom object
   might be reading from the network or loading data from the disk.
2. We don't want to mutate the real custom object or create side effects in
   the environment while tracing.
3. It cannot support dynamic shapes.

However, it may be difficult for users to write a fake class: for example,
the original class might use a third-party library that determines the output
shapes of its methods, or might be complicated and written by others. Besides,
users may not care about the limitations listed above. In these cases, please
reach out to us!