diff --git a/klongpy/backends/torch_backend.py b/klongpy/backends/torch_backend.py
index 78ddeb7..818d325 100644
--- a/klongpy/backends/torch_backend.py
+++ b/klongpy/backends/torch_backend.py
@@ -93,7 +93,11 @@ def _grad_fn(*args: Any) -> Any:
             raise RuntimeError("not differentiable") from e
         if not isinstance(out, torch.Tensor):
             raise RuntimeError("not differentiable")
-        g, = torch.autograd.grad(out, targs[wrt])
+        if out.ndim == 0:
+            grad_out = None
+        else:
+            grad_out = torch.ones_like(out)
+        g, = torch.autograd.grad(out, targs[wrt], grad_outputs=grad_out)
         return g
 
     return _grad_fn
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
new file mode 100644
index 0000000..d95201c
--- /dev/null
+++ b/tests/test_autograd.py
@@ -0,0 +1,33 @@
+import unittest
+import numpy as np
+
+from klongpy import backend
+
+
+class TestAutograd(unittest.TestCase):
+    def _check_matrix_grad(self, name: str):
+        try:
+            backend.set_backend(name)
+        except ImportError:
+            raise unittest.SkipTest(f"{name} backend not available")
+        b = backend.current()
+
+        def f(x):
+            return b.sum(b.matmul(x, x))
+
+        g = b.grad(f)
+        x = b.array([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
+        grad = g(x)
+        if hasattr(grad, "detach"):
+            grad = grad.detach().cpu().numpy()
+        np.testing.assert_allclose(np.array(grad), np.array([[7.0, 11.0], [9.0, 13.0]]))
+
+    def test_matrix_grad_numpy(self):
+        self._check_matrix_grad("numpy")
+
+    def test_matrix_grad_torch(self):
+        self._check_matrix_grad("torch")
+
+
+if __name__ == "__main__":
+    unittest.main()
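
Note on the `grad_outputs` branch: a minimal standalone sketch (not part of the patch) of why the non-scalar case seeds autograd with ones. `torch.autograd.grad` on a non-scalar output requires an explicit `grad_outputs` seed for the vector-Jacobian product, and seeding with `torch.ones_like(out)` is equivalent to differentiating `out.sum()`. The tensor values below are chosen to match the new test's expected gradient; everything else is standard PyTorch API.

```python
import torch

# Scalar output: no seed needed, matching the patch's `out.ndim == 0` branch.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
scalar_out = (x @ x).sum()
(g_scalar,) = torch.autograd.grad(scalar_out, x)

# Non-scalar output: autograd needs grad_outputs as the VJP seed;
# ones_like(out) effectively sums all output elements.
x2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
vector_out = x2 @ x2
(g_vector,) = torch.autograd.grad(
    vector_out, x2, grad_outputs=torch.ones_like(vector_out)
)

# Both paths agree with the hand-computed gradient asserted in the test:
# d/dX sum(X @ X) has entry (a, b) equal to rowsum(X)[b] + colsum(X)[a].
expected = torch.tensor([[7.0, 11.0], [9.0, 13.0]])
assert torch.allclose(g_scalar, expected)
assert torch.allclose(g_vector, expected)
```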