191 changes: 189 additions & 2 deletions klongpy/autograd.py
@@ -1,12 +1,177 @@
import numbers
import numpy as np
from .core import KGLambda, KGCall, KGSym, KGFn


class AutodiffNotSupported(Exception):
"""Raised when forward-mode autodiff cannot handle an operation."""


_UNARY_DERIVATIVES = {
np.negative: lambda x: -1.0,
np.positive: lambda x: 1.0,
np.sin: np.cos,
np.cos: lambda x: -np.sin(x),
np.tan: lambda x: 1.0 / (np.cos(x) ** 2),
np.exp: np.exp,
np.expm1: np.exp,
np.log: lambda x: 1.0 / x,
np.log1p: lambda x: 1.0 / (1.0 + x),
np.sqrt: lambda x: 0.5 / np.sqrt(x),
np.square: lambda x: 2.0 * x,
np.reciprocal: lambda x: -1.0 / (x ** 2),
np.sinh: np.cosh,
np.cosh: np.sinh,
np.tanh: lambda x: 1.0 / (np.cosh(x) ** 2),
np.abs: np.sign,
np.log10: lambda x: 1.0 / (x * np.log(10.0)),
np.log2: lambda x: 1.0 / (x * np.log(2.0)),
}

_BINARY_DERIVATIVES = {
np.add: lambda x, y: (1.0, 1.0),
np.subtract: lambda x, y: (1.0, -1.0),
np.multiply: lambda x, y: (y, x),
np.divide: lambda x, y: (1.0 / y, -x / (y * y)),
np.true_divide: lambda x, y: (1.0 / y, -x / (y * y)),
}

_POWER_UFUNCS = {np.power, np.float_power}
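# These tables drive the chain rule in Dual.__array_ufunc__: a unary ufunc f
# contributes grad * f'(x), a binary ufunc f(x, y) contributes
# gx * df/dx + gy * df/dy, and the power ufuncs are special-cased because
# their exponent term d/dy x**y = x**y * ln(x) only exists for x > 0.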


class Dual:
"""Simple dual number for forward-mode autodiff on scalar inputs."""

__slots__ = ("value", "grad")
__array_priority__ = 1000

def __init__(self, value, grad=0.0):
self.value = float(value)
self.grad = float(grad)

@staticmethod
def _coerce(other):
if isinstance(other, Dual):
return other
if np.isarray(other):
raise AutodiffNotSupported("array operands are not supported by Dual")
if isinstance(other, numbers.Real):
return Dual(other, 0.0)
try:
return Dual(float(other), 0.0)
except (TypeError, ValueError) as exc:
raise AutodiffNotSupported(
f"unsupported operand of type {type(other)!r}"
) from exc

def __array__(self, dtype=None):
return np.array(self.value, dtype=dtype)

def __float__(self):
return float(self.value)

def __neg__(self):
return Dual(-self.value, -self.grad)

def __pos__(self):
return Dual(+self.value, +self.grad)

def __add__(self, other):
other = self._coerce(other)
return Dual(self.value + other.value, self.grad + other.grad)

def __radd__(self, other):
return self.__add__(other)

def __sub__(self, other):
other = self._coerce(other)
return Dual(self.value - other.value, self.grad - other.grad)

def __rsub__(self, other):
other = self._coerce(other)
return Dual(other.value - self.value, other.grad - self.grad)

def __mul__(self, other):
other = self._coerce(other)
grad = self.grad * other.value + other.grad * self.value
return Dual(self.value * other.value, grad)

def __rmul__(self, other):
return self.__mul__(other)

def __truediv__(self, other):
other = self._coerce(other)
if other.value == 0.0:
raise AutodiffNotSupported("division by zero is not supported")
grad = (self.grad * other.value - other.grad * self.value) / (other.value ** 2)
return Dual(self.value / other.value, grad)

def __rtruediv__(self, other):
other = self._coerce(other)
return other.__truediv__(self)

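    # d/dx f**g = g * f**(g-1) * f' + f**g * ln(f) * g'. The ln(f) term is
    # only defined for f > 0, hence the guard below when the exponent itself
    # carries a gradient.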
def __pow__(self, other):
other = self._coerce(other)
result = self.value ** other.value
grad = self.grad * other.value * (self.value ** (other.value - 1))
if other.grad != 0.0:
if self.value <= 0.0:
raise AutodiffNotSupported(
"differentiating w.r.t. the exponent requires a positive base"
)
grad += other.grad * result * np.log(self.value)
return Dual(result, grad)

def __rpow__(self, other):
other = self._coerce(other)
return other.__pow__(self)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != "__call__":
raise AutodiffNotSupported(f"ufunc method {method!r} is not supported")
if kwargs.get("out") is not None:
raise AutodiffNotSupported("ufunc out parameter is not supported")
if "where" in kwargs and not np.all(kwargs["where"]):
raise AutodiffNotSupported("where argument is not supported")

values = [x.value if isinstance(x, Dual) else x for x in inputs]
grads = [x.grad if isinstance(x, Dual) else 0.0 for x in inputs]

try:
result = getattr(ufunc, method)(*values, **kwargs)
except TypeError as exc:
raise AutodiffNotSupported from exc

try:
if ufunc in _UNARY_DERIVATIVES:
derivative = _UNARY_DERIVATIVES[ufunc](values[0])
grad = grads[0] * derivative
elif ufunc in _BINARY_DERIVATIVES:
dfdx, dfdy = _BINARY_DERIVATIVES[ufunc](values[0], values[1])
grad = grads[0] * dfdx + grads[1] * dfdy
elif ufunc in _POWER_UFUNCS:
base, exponent = values
dfdx = exponent * (base ** (exponent - 1))
grad = grads[0] * dfdx
if grads[1] != 0.0:
if base <= 0.0:
raise AutodiffNotSupported(
"differentiating power w.r.t. exponent requires positive base"
)
grad += grads[1] * (base ** exponent) * np.log(base)
else:
raise AutodiffNotSupported(f"ufunc {ufunc.__name__} is not supported")
except ZeroDivisionError as exc:
raise AutodiffNotSupported from exc

return Dual(result, grad)
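
# Illustrative sketch (not part of the module's API): seeding a Dual and
# evaluating through NumPy ufuncs yields exact forward-mode derivatives.
#
#     d = Dual(2.0, 1.0)       # seed dx/dx = 1
#     y = np.sin(d) + d * d    # ufuncs dispatch through __array_ufunc__
#     y.value                  # sin(2.0) + 4.0
#     y.grad                   # cos(2.0) + 4.0, i.e. d/dx[sin x + x**2] at x=2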


def numeric_grad(func, x, eps=1e-6):
"""Compute numeric gradient of scalar-valued function."""
x = np.asarray(x, dtype=float)
grad = np.zeros_like(x, dtype=float)
it = np.nditer(x, flags=["multi_index"], op_flags=["readwrite"])
while not it.finished:
idx = it.multi_index
orig = float(x[idx])
@@ -20,8 +185,26 @@ def numeric_grad(func, x, eps=1e-6):
return grad
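
# Finite differences are only approximate (truncation plus cancellation
# error in the eps step), which is why grad_of_fn below prefers the exact
# forward-mode path and keeps this as the fallback.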


def autodiff_grad(func, x):
"""Compute gradient using forward-mode autodiff for scalar ``x``."""
if np.ndim(x) != 0:
raise AutodiffNotSupported("autodiff only supports scalar inputs")
value = float(np.asarray(x))
dual = Dual(value, 1.0)
try:
result = func(dual)
except AutodiffNotSupported:
raise
except Exception as exc:
raise AutodiffNotSupported from exc
if isinstance(result, Dual):
return np.asarray(result.grad, dtype=float)
raise AutodiffNotSupported("function did not return a Dual value")


def grad_of_fn(klong, fn, x):
"""Return gradient of Klong or Python function ``fn`` at ``x``."""

def call_fn(v):
if isinstance(fn, (KGSym, KGLambda)):
return klong.call(KGCall(fn, [v], 1))
@@ -31,4 +214,8 @@ def call_fn(v):
return klong.call(KGCall(fn.a, [v], fn.arity))
else:
return fn(v)

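    # Prefer exact forward-mode autodiff; anything Dual cannot express
    # (array inputs, unsupported ufuncs, non-scalar results, ...) falls back
    # to finite differences.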
try:
return autodiff_grad(call_fn, x)
except AutodiffNotSupported:
return numeric_grad(call_fn, x)
4 changes: 2 additions & 2 deletions klongpy/dyads.py
@@ -1,5 +1,5 @@
from .core import *
from .autograd import grad_of_fn
import sys


@@ -993,7 +993,7 @@ def func(v):
finally:
klong[a] = orig
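        # func is a plain Python callable, so grad_of_fn routes it through
        # exact forward-mode autodiff first and falls back to numeric_grad
        # only if the Dual probe raises AutodiffNotSupported.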

return grad_of_fn(klong, func, orig)
else:
return grad_of_fn(klong, b, a)

66 changes: 34 additions & 32 deletions klongpy/monads.py
@@ -273,39 +273,41 @@ def eval_monad_range(a):

"""
if isinstance(a, str):
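        # Deduplicate characters while preserving first-occurrence order
        # (np.unique would sort them): the range of "hello" is "helo".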
seen = set()
ordered = []
for ch in a:
if ch not in seen:
seen.add(ch)
ordered.append(ch)
return ''.join(ordered)

if np.isarray(a):
arr = a
elif is_list(a):
arr = kg_asarray(a)
else:
return a

if arr.ndim == 0:
return arr

if arr.dtype != 'O':
if arr.ndim == 1:
_, idx = np.unique(arr, return_index=True)
else:
_, idx = np.unique(arr, axis=0, return_index=True)
idx.sort()
return arr[idx]
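    # Object/jagged arrays: O(n^2) pairwise kg_equal comparison, which
    # distinguishes 10 from "10" where a str()-keyed dedup would not.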

unique_values = []
for item in arr:
if not any(kg_equal(item, existing) for existing in unique_values):
unique_values.append(item)

if len(unique_values) == len(arr):
return arr

return np.asarray(unique_values, dtype=object)


def eval_monad_reciprocal(a):
36 changes: 36 additions & 0 deletions tests/test_autograd.py
@@ -18,6 +18,42 @@ def test_scalar_grad(self):
r = klong('g(3.14)')
self.assertTrue(np.isclose(r, 2*3.14 + np.cos(3.14), atol=1e-3))

def test_scalar_grad_high_precision(self):
"""autodiff should provide very accurate gradients for scalars."""
klong = KlongInterpreter()
klong['exp'] = lambda x: np.exp(x)
klong('g::∇{exp(x)}')
r = klong('g(50)')
        self.assertTrue(np.isclose(r, np.exp(50.0), rtol=1e-12, atol=0.0))

def test_scalar_grad_extended_ufuncs(self):
klong = KlongInterpreter()
klong['log'] = lambda x: np.log(x)
klong['sqrt'] = lambda x: np.sqrt(x)
klong['tanh'] = lambda x: np.tanh(x)
klong('g::∇{log(x)+sqrt(x)+tanh(x)}')
value = 2.5
r = klong(f'g({value})')
expected = (1.0 / value) + (0.5 / np.sqrt(value)) + (1.0 / np.cosh(value) ** 2)
self.assertTrue(np.isclose(r, expected, atol=1e-9))

def test_scalar_grad_python_operators(self):
klong = KlongInterpreter()
klong('g::∇{(x*x*x) - 3*x + 5}')
value = 4.0
r = klong(f'g({value})')
expected = 3 * (value ** 2) - 3
self.assertTrue(np.isclose(r, expected, atol=1e-9))

def test_scalar_grad_numeric_fallback(self):
klong = KlongInterpreter()
klong['relu'] = lambda x: np.maximum(x, 0.0)
klong('g::∇{relu(x)}')
pos = klong('g(3.0)')
neg = klong('g(-3.0)')
self.assertTrue(np.isclose(pos, 1.0, atol=1e-6))
self.assertTrue(np.isclose(neg, 0.0, atol=1e-6))
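        # np.maximum has no entry in the Dual derivative tables, so grad_of_fn
        # falls back to numeric_grad here; the looser atol reflects that.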

@unittest.skipUnless(TORCH_AVAILABLE, "torch required")
def test_scalar_grad_torch(self):
klong = KlongInterpreter()
4 changes: 4 additions & 0 deletions tests/test_extra_suite.py
@@ -139,6 +139,10 @@ def test_range_over_nested_arrays(self):
self.assert_eval_cmp('?[[[0 0] [0 0] [1 1]] [1 1] [1 1] 3 3]', '[[[0 0] [0 0] [1 1]] [1 1] 3]')
self.assert_eval_cmp('?[[0 0] [1 0] [2 0] [3 0] [4 1] [4 2] [4 3] [3 4] [2 4] [3 3] [4 3] [3 2] [2 2] [1 2]]', '[[0 0] [1 0] [2 0] [3 0] [4 1] [4 2] [4 3] [3 4] [2 4] [3 3] [3 2] [2 2] [1 2]]')

def test_range_distinguishes_types(self):
self.assert_eval_cmp('?[10 "10"]', '[10 "10"]')
self.assert_eval_cmp('?[:foo ":foo"]', '[:foo ":foo"]')

def test_sum_over_nested_arrays(self):
"""
sum over nested arrays should reduce