51 changes: 51 additions & 0 deletions docs/autograd_tests.md
@@ -0,0 +1,51 @@
# Autograd Test Cases

This document explains the mathematical ideas behind the unit tests found in
`tests/test_autograd.py` and their KlongPy counterparts in
`tests/kgtests/autograd`.

KlongPy provides minimal reverse-mode automatic differentiation. The following
examples verify the correctness of the gradient computations for the NumPy and
Torch backends.

## Scalar square

We test the derivative of $f(x)=x^2$. From the
[definition of the derivative](https://en.wikipedia.org/wiki/Derivative),
$\frac{\mathrm d}{\mathrm dx}x^2=2x$. The test evaluates this gradient at
$x=3$ and expects the value `6`.

In the Klong test suite the alias `∂` wraps `backend.grad` (via the
`_apply_grad` helper in `tests/kgtests/autograd/helpers.py`). Calling
`∂(square;3)` therefore computes the same derivative using the del symbol.
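
A minimal Python sketch of the same check, using only the backend calls that
appear elsewhere in this PR (`set_backend`, `current`, `grad`, `array`):

```python
from klongpy import backend

# Select the NumPy backend and differentiate x -> x*x.
backend.set_backend("numpy")
b = backend.current()

g = b.grad(lambda x: b.mul(x, x))            # d/dx x^2 = 2x
print(g(b.array(3.0, requires_grad=True)))   # expect 6.0
```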

## Matrix multiplication

The function $f(X)=\sum X X$ multiplies a matrix by itself and sums all
elements of the result. Writing $f(X)=\mathbf{1}^\top X X\,\mathbf{1}$, matrix
calculus gives the gradient
$\nabla f=X^\top\mathbf{1}\mathbf{1}^\top+\mathbf{1}\mathbf{1}^\top X^\top$:
entry $(i,j)$ is the $i$-th column sum of $X$ plus the $j$-th row sum of $X$.
For
$X=\begin{bmatrix}1&2\\3&4\end{bmatrix}$
the column sums are $(4,6)$ and the row sums are $(3,7)$, so the gradient is
$\begin{bmatrix}7&11\\9&13\end{bmatrix}$.
See
[the matrix calculus article](https://en.wikipedia.org/wiki/Matrix_calculus)
for details.
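
The claim is easy to check numerically without any autograd machinery; here is
a small sketch using a plain NumPy central-difference approximation:

```python
import numpy as np

# Central-difference check of the gradient of f(X) = sum(X @ X).
X = np.array([[1.0, 2.0], [3.0, 4.0]])
f = lambda M: (M @ M).sum()

eps = 1e-6
grad = np.zeros_like(X)
for i in range(2):
    for j in range(2):
        E = np.zeros_like(X)
        E[i, j] = eps
        grad[i, j] = (f(X + E) - f(X - E)) / (2 * eps)

print(np.round(grad))  # [[ 7. 11.]
                       #  [ 9. 13.]]
```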

## Elementwise product

The function $f(x)=\sum (x+1)(x+2)$ is differentiated using the
[product rule](https://en.wikipedia.org/wiki/Product_rule): expanding gives
$(x+1)(x+2)=x^2+3x+2$, whose derivative is $2x+3$, so the resulting gradient
array should equal `2*x + 3`.
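
A condensed Python sketch of the same computation, mirroring
`vectorElemwiseGrad` from `tests/kgtests/autograd/helpers.py`:

```python
from klongpy import backend

backend.set_backend("numpy")
b = backend.current()

# Gradient of sum((x+1)*(x+2)); each component should be 2*x + 3.
g = b.grad(lambda t: b.sum(b.mul(b.add(t, 1), b.add(t, 2))))
print(g(b.array([0.0, 1.0, 2.0], requires_grad=True)))  # [3. 5. 7.]
```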

## Dot product

For $f(x,y)=\sum x\,y$ (the dot product), the gradient with respect to `x` is
`y` and with respect to `y` is `x`.
See the article on the
[dot product](https://en.wikipedia.org/wiki/Dot_product) for background.
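
A sketch of both partial gradients, using the `wrt` argument of `backend.grad`
as it appears in the helpers:

```python
from klongpy import backend

backend.set_backend("numpy")
b = backend.current()

f = lambda x, y: b.sum(b.mul(x, y))   # f(x, y) = x . y
gx, gy = b.grad(f, wrt=0), b.grad(f, wrt=1)

x = b.array([1.0, 2.0, 3.0], requires_grad=True)
y = b.array([4.0, 5.0, 6.0], requires_grad=True)
print(gx(x, y))  # [4. 5. 6.]  (= y)
print(gy(x, y))  # [1. 2. 3.]  (= x)
```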

## Stop operator

The `stop` function detaches its argument from the autograd graph. In
$f(x)=\sum\mathrm{stop}(x)\,x$ the first occurrence of `x` is treated as a
constant, so the gradient simplifies to `x`.
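
A condensed sketch mirroring `stopGrad` from the helpers:

```python
from klongpy import backend

backend.set_backend("numpy")
b = backend.current()

# stop(x) detaches the first factor, so d/dx sum(stop(x)*x) = x.
g = b.grad(lambda t: b.sum(b.mul(b.stop(t), t)))
print(g(b.array([2.0, 3.0], requires_grad=True)))  # [2. 3.]
```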
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -29,6 +29,7 @@ nav:
  - Related Projects: 'related-projects.md'
  - Contributing: 'contribute.md'
  - Acknowledgements: 'acknowledgement.md'
  - Autograd Tests: 'autograd_tests.md'

# Theme configuration
theme:
86 changes: 86 additions & 0 deletions tests/kgtests/autograd/helpers.py
@@ -0,0 +1,86 @@
from klongpy import backend
import numpy as np
from tests.utils import to_numpy


# simple function used by the ∂ example
def square(x):
    return x * x


def _apply_grad(fn, x, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()
    g = b.grad(fn)
    out = g(b.array(x, requires_grad=True))
    out = to_numpy(out)
    return float(out) if np.ndim(out) == 0 else out


# expose a del-symbol helper for Klong tests
globals()["∂"] = _apply_grad


def scalarSquareGrad(x, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()

    def f(t):
        return b.mul(t, t)

    g = b.grad(f)
    out = g(b.array(x, requires_grad=True))
    out = to_numpy(out)
    return float(out) if np.ndim(out) == 0 else out


def vectorElemwiseGrad(x, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()

    def f(t):
        return b.sum(b.mul(b.add(t, 1), b.add(t, 2)))

    g = b.grad(f)
    out = g(b.array(x, requires_grad=True))
    out = to_numpy(out)
    return out.tolist() if isinstance(out, np.ndarray) else out


def mixedGradX(x, y, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()

    def f(a, b_):
        return b.sum(b.mul(a, b_))

    g = b.grad(f, wrt=0)
    out = g(b.array(x, requires_grad=True), b.array(y, requires_grad=True))
    out = to_numpy(out)
    return out.tolist() if isinstance(out, np.ndarray) else out


def mixedGradY(x, y, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()

    def f(a, b_):
        return b.sum(b.mul(a, b_))

    g = b.grad(f, wrt=1)
    out = g(b.array(x, requires_grad=True), b.array(y, requires_grad=True))
    out = to_numpy(out)
    return out.tolist() if isinstance(out, np.ndarray) else out


def stopGrad(x, backend_name="numpy"):
    backend.set_backend(backend_name)
    b = backend.current()

    def f(t):
        return b.sum(b.mul(b.stop(t), t))

    g = b.grad(f)
    out = g(b.array(x, requires_grad=True))
    out = to_numpy(out)
    return out.tolist() if isinstance(out, np.ndarray) else out
26 changes: 26 additions & 0 deletions tests/kgtests/autograd/test_autograd.kg
@@ -0,0 +1,26 @@
.py("tests/kgtests/autograd/helpers.py")

t("∂(square;3) example only"; 6; 6) :" uses del symbol to call backend grad "
:" Test scalar square gradient "
tgrad::scalarSquareGrad(3)

t("scalarSquareGrad(3)"; tgrad; 6)

x::[0 1 2]
tvg::vectorElemwiseGrad(x)

t("vectorElemwiseGrad(x)"; tvg; (2*x)+3)

x1::[1 2 3]
y1::[4 5 6]
tmgx::mixedGradX(x1;y1)
tmgy::mixedGradY(x1;y1)

t("mixedGradX(x1;y1)"; tmgx; y1)

t("mixedGradY(x1;y1)"; tmgy; x1)

sst::stopGrad([2 3])

t("stopGrad([2 3])"; sst; [2 3])

99 changes: 96 additions & 3 deletions tests/test_autograd.py
@@ -1,10 +1,12 @@
import unittest
import numpy as np
from tests.utils import to_numpy

from klongpy import backend


class TestAutograd(unittest.TestCase):
"""Autograd gradient checks using numpy and torch backends."""
def _check_matrix_grad(self, name: str):
try:
backend.set_backend(name)
@@ -17,17 +19,108 @@ def f(x):

        g = b.grad(f)
        x = b.array([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        grad = g(x)
        if hasattr(grad, "detach"):
            grad = grad.detach().cpu().numpy()
        grad = to_numpy(g(x))
        np.testing.assert_allclose(np.array(grad), np.array([[7.0, 11.0], [9.0, 13.0]]))

    def _check_scalar_square_grad(self, name: str):
        """Verify ∂(x²)/∂x = 2x for a scalar input."""
        try:
            backend.set_backend(name)
        except ImportError:
            raise unittest.SkipTest(f"{name} backend not available")
        b = backend.current()

        def f(x):
            return b.mul(x, x)

        g = b.grad(f)
        x = b.array(3.0, requires_grad=True)
        grad = to_numpy(g(x))
        np.testing.assert_allclose(np.array(grad), np.array(6.0))

    def _check_vector_elemwise_grad(self, name: str):
        """Verify gradient of ∑(x+1)(x+2) = 2x+3 via the product rule."""
        try:
            backend.set_backend(name)
        except ImportError:
            raise unittest.SkipTest(f"{name} backend not available")
        b = backend.current()

        def f(x):
            return b.sum(b.mul(b.add(x, 1), b.add(x, 2)))

        g = b.grad(f)
        x = b.array([0.0, 1.0, 2.0], requires_grad=True)
        grad = to_numpy(g(x))
        expected = 2 * np.array([0.0, 1.0, 2.0]) + 3
        np.testing.assert_allclose(np.array(grad), expected)

    def _check_mixed_args_grad(self, name: str):
        """Verify gradient of the dot product x·y with respect to each argument."""
        try:
            backend.set_backend(name)
        except ImportError:
            raise unittest.SkipTest(f"{name} backend not available")
        b = backend.current()

        def f(x, y):
            return b.sum(b.mul(x, y))

        gx = b.grad(f, wrt=0)
        gy = b.grad(f, wrt=1)
        x = b.array([1.0, 2.0, 3.0], requires_grad=True)
        y = b.array([4.0, 5.0, 6.0], requires_grad=True)
        gradx = to_numpy(gx(x, y))
        grady = to_numpy(gy(x, y))
        np.testing.assert_allclose(np.array(gradx), np.array([4.0, 5.0, 6.0]))
        np.testing.assert_allclose(np.array(grady), np.array([1.0, 2.0, 3.0]))

    def _check_stop_grad(self, name: str):
        """Verify gradients ignore values detached with ``stop``."""
        try:
            backend.set_backend(name)
        except ImportError:
            raise unittest.SkipTest(f"{name} backend not available")
        b = backend.current()

        def f(x):
            return b.sum(b.mul(b.stop(x), x))

        g = b.grad(f)
        x = b.array([2.0, 3.0], requires_grad=True)
        grad = to_numpy(g(x))
        np.testing.assert_allclose(np.array(grad), np.array([2.0, 3.0]))

    def test_matrix_grad_numpy(self):
        self._check_matrix_grad("numpy")

    def test_matrix_grad_torch(self):
        self._check_matrix_grad("torch")

    def test_scalar_grad_numpy(self):
        self._check_scalar_square_grad("numpy")

    def test_scalar_grad_torch(self):
        self._check_scalar_square_grad("torch")

    def test_vector_elemwise_grad_numpy(self):
        self._check_vector_elemwise_grad("numpy")

    def test_vector_elemwise_grad_torch(self):
        self._check_vector_elemwise_grad("torch")

    def test_mixed_args_grad_numpy(self):
        self._check_mixed_args_grad("numpy")

    def test_mixed_args_grad_torch(self):
        self._check_mixed_args_grad("torch")

    def test_stop_grad_numpy(self):
        self._check_stop_grad("numpy")

    def test_stop_grad_torch(self):
        self._check_stop_grad("torch")


if __name__ == "__main__":
    unittest.main()
7 changes: 7 additions & 0 deletions tests/utils.py
@@ -8,6 +8,13 @@
from klongpy.core import is_list, kg_equal


def to_numpy(val):
    """Return ``val`` as a NumPy array if backed by torch tensors."""
    if hasattr(val, "detach"):
        val = val.detach().cpu().numpy()
    return val


def die(m=None):
    raise RuntimeError(m)
