diff --git a/CLAUDE.md b/CLAUDE.md index 544eccb..43b2bfd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,17 +11,26 @@ DNLP-diff-engine is a C library with Python bindings that provides automatic dif ### Python Package (Recommended) ```bash -# Install in development mode (builds C library + Python bindings) -pip install -e . +# Install in development mode with uv (recommended) +uv pip install -e ".[test]" -# Run all Python tests +# Or with pip +pip install -e ".[test]" + +# Run all Python tests (tests are in python/tests/) pytest # Run specific test file -pytest tests/python/test_unconstrained.py +pytest python/tests/test_unconstrained.py # Run specific test -pytest tests/python/test_unconstrained.py::test_sum_log +pytest python/tests/test_unconstrained.py::test_sum_log + +# Lint with ruff +ruff check src/ + +# Auto-fix lint issues +ruff check --fix src/ ``` ### Standalone C Library @@ -35,16 +44,6 @@ cmake --build build ./build/all_tests ``` -### Legacy Python Build (without pip) - -```bash -# Build Python bindings manually (from project root) -cd python && cmake -B build -S . && cmake --build build && cd .. - -# Run tests from python/ directory -cd python && python tests/test_problem_native.py -``` - ## Architecture ### Expression Tree System @@ -59,10 +58,10 @@ The core abstraction is the `expr` struct (in `include/expr.h`) representing a n Atoms are organized by mathematical properties in `src/`: -- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op` -- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, trigonometric, hyperbolic -- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr` -- **`other/`** - Special atoms not fitting above categories +- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op`, `index` +- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, `xexp`, trigonometric (`sin`, `cos`, `tan`), hyperbolic (`sinh`, `tanh`, `asinh`, `atanh`) +- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr`, `const_scalar_mult`, `const_vector_mult`, `left_matmul`, `right_matmul` +- **`other/`** - Special atoms: `quad_form`, `prod` Each atom implements its own `forward`, `jacobian_init`, `eval_jacobian`, and `eval_wsum_hess` functions following a consistent pattern. @@ -87,7 +86,8 @@ The Python package `dnlp_diff_engine` (in `src/dnlp_diff_engine/`) provides: **High-level API** (`__init__.py`): - `C_problem` class wraps the C problem struct - `convert_problem()` builds expression trees from CVXPY Problem objects -- Atoms are mapped via `ATOM_CONVERTERS` dictionary +- Atoms are mapped via `ATOM_CONVERTERS` dictionary (maps CVXPY atom names → converter functions) +- Special converters handle: matrix multiplication (`_convert_matmul`), multiply with constants (`_convert_multiply`), indexing, reshape (Fortran order only) **Low-level C extension** (`_core` module, built from `python/bindings.c`): - Atom constructors: `make_variable`, `make_constant`, `make_log`, `make_exp`, `make_add`, etc. @@ -108,14 +108,15 @@ Hessian computes weighted sum: `obj_w * H_obj + sum(lambda_i * H_constraint_i)` ## Key Directories -- `include/` - Header files defining public API -- `src/` - C implementation files -- `src/dnlp_diff_engine/` - Python package (installed via pip) -- `python/` - Python bindings C code and binding headers +- `include/` - Header files defining public API (`expr.h`, `problem.h`, atom headers) +- `src/` - C implementation files organized by atom category +- `src/dnlp_diff_engine/` - Python package with high-level API +- `python/` - Python bindings C code (`bindings.c`) - `python/atoms/` - Python binding headers for each atom type - `python/problem/` - Python binding headers for problem interface +- `python/tests/` - Python integration tests (run via pytest) - `tests/` - C tests using minunit framework -- `tests/python/` - Python tests (run via pytest) +- `tests/forward_pass/` - Forward evaluation tests (C) - `tests/jacobian_tests/` - Jacobian correctness tests (C) - `tests/wsum_hess/` - Hessian correctness tests (C) diff --git a/include/affine.h b/include/affine.h index 654557b..c1dd13b 100644 --- a/include/affine.h +++ b/include/affine.h @@ -18,4 +18,6 @@ expr *new_trace(expr *child); expr *new_constant(int d1, int d2, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); +expr *new_index(expr *child, const int *indices, int n_idxs); + #endif /* AFFINE_H */ diff --git a/include/subexpr.h b/include/subexpr.h index 8239a7d..97a31bf 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -109,4 +109,13 @@ typedef struct const_vector_mult_expr double *a; /* length equals node->size */ } const_vector_mult_expr; +/* Index/slicing: y = child[indices] where indices is a list of flat positions */ +typedef struct index_expr +{ + expr base; + int *indices; /* Flattened indices to select (owned, copied) */ + int n_idxs; /* Number of selected elements */ + bool has_duplicates; /* True if indices have duplicates (affects Hessian path) */ +} index_expr; + #endif /* SUBEXPR_H */ diff --git a/python/atoms/index.h b/python/atoms/index.h new file mode 100644 index 0000000..035ac56 --- /dev/null +++ b/python/atoms/index.h @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_INDEX_H +#define ATOM_INDEX_H + +#include "affine.h" +#include "common.h" + +/* Index/slicing: y = child[indices] where indices is a list of flattened positions + */ +static PyObject *py_make_index(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + PyObject *indices_obj; + + if (!PyArg_ParseTuple(args, "OO", &child_capsule, &indices_obj)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + /* Convert indices array to int32 */ + PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF( + indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + + if (!indices_array) + { + return NULL; + } + + int n_idxs = (int) PyArray_SIZE(indices_array); + int *indices_data = (int *) PyArray_DATA(indices_array); + + expr *node = new_index(child, indices_data, n_idxs); + + Py_DECREF(indices_array); + + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create index node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_INDEX_H */ diff --git a/python/atoms/quad_over_lin.h b/python/atoms/quad_over_lin.h new file mode 100644 index 0000000..4cb3b91 --- /dev/null +++ b/python/atoms/quad_over_lin.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_QUAD_OVER_LIN_H +#define ATOM_QUAD_OVER_LIN_H + +#include "bivariate.h" +#include "common.h" + +/* quad_over_lin: y = sum(x^2) / z where x is left, z is right (scalar) */ +static PyObject *py_make_quad_over_lin(PyObject *self, PyObject *args) +{ + (void)self; + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_quad_over_lin(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create quad_over_lin node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_QUAD_OVER_LIN_H */ diff --git a/python/atoms/rel_entr.h b/python/atoms/rel_entr.h new file mode 100644 index 0000000..e233d17 --- /dev/null +++ b/python/atoms/rel_entr.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_REL_ENTR_H +#define ATOM_REL_ENTR_H + +#include "bivariate.h" +#include "common.h" + +/* rel_entr: rel_entr(x, y) = x * log(x/y) elementwise */ +static PyObject *py_make_rel_entr(PyObject *self, PyObject *args) +{ + (void)self; + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_rel_entr_vector_args(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create rel_entr node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_REL_ENTR_H */ diff --git a/python/bindings.c b/python/bindings.c index 3045e2b..1042701 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -12,6 +12,7 @@ #include "atoms/cos.h" #include "atoms/entr.h" #include "atoms/exp.h" +#include "atoms/index.h" #include "atoms/left_matmul.h" #include "atoms/linear.h" #include "atoms/log.h" @@ -21,6 +22,8 @@ #include "atoms/power.h" #include "atoms/promote.h" #include "atoms/quad_form.h" +#include "atoms/quad_over_lin.h" +#include "atoms/rel_entr.h" #include "atoms/right_matmul.h" #include "atoms/sin.h" #include "atoms/sinh.h" @@ -55,6 +58,7 @@ static PyMethodDef DNLPMethods[] = { {"make_linear", py_make_linear, METH_VARARGS, "Create linear op node"}, {"make_log", py_make_log, METH_VARARGS, "Create log node"}, {"make_exp", py_make_exp, METH_VARARGS, "Create exp node"}, + {"make_index", py_make_index, METH_VARARGS, "Create index node"}, {"make_add", py_make_add, METH_VARARGS, "Create add node"}, {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, {"make_neg", py_make_neg, METH_VARARGS, "Create neg node"}, @@ -82,6 +86,10 @@ static PyMethodDef DNLPMethods[] = { "Create right matmul node (f(x) @ A)"}, {"make_quad_form", py_make_quad_form, METH_VARARGS, "Create quadratic form node (x' * Q * x)"}, + {"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS, + "Create quad_over_lin node (sum(x^2) / y)"}, + {"make_rel_entr", py_make_rel_entr, METH_VARARGS, + "Create rel_entr node: x * log(x/y) elementwise"}, {"make_problem", py_make_problem, METH_VARARGS, "Create problem from objective and constraints"}, {"problem_init_derivatives", py_problem_init_derivatives, METH_VARARGS, diff --git a/src/affine/index.c b/src/affine/index.c new file mode 100644 index 0000000..c407daa --- /dev/null +++ b/src/affine/index.c @@ -0,0 +1,173 @@ +#include "affine.h" +#include "subexpr.h" +#include +#include + +/* Index/slicing: y = child[indices] where indices is a list of flat positions */ + +/* Check if indices array contains duplicates using a bitmap. + * Returns true if duplicates exist, false otherwise. */ +static bool check_for_duplicates(const int *indices, int n_idxs, int max_idx) +{ + bool *seen = (bool *) calloc(max_idx, sizeof(bool)); + bool has_dup = false; + for (int i = 0; i < n_idxs && !has_dup; i++) + { + if (seen[indices[i]]) + { + has_dup = true; + } + seen[indices[i]] = true; + } + free(seen); + return has_dup; +} + +static void forward(expr *node, const double *u) +{ + expr *x = node->left; + index_expr *idx = (index_expr *) node; + + /* child's forward pass */ + x->forward(x, u); + + /* gather selected elements */ + for (int i = 0; i < idx->n_idxs; i++) + { + node->value[i] = x->value[idx->indices[i]]; + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + index_expr *idx = (index_expr *) node; + x->jacobian_init(x); + + CSR_Matrix *Jx = x->jacobian; + CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz); + + /* set sparsity pattern */ + J->p[0] = 0; + for (int i = 0; i < idx->n_idxs; i++) + { + int row = idx->indices[i]; + int len = Jx->p[row + 1] - Jx->p[row]; + memcpy(J->i + J->p[i], Jx->i + Jx->p[row], len * sizeof(int)); + J->p[i + 1] = J->p[i] + len; + } + + node->jacobian = J; +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + index_expr *idx = (index_expr *) node; + x->eval_jacobian(x); + + CSR_Matrix *J = node->jacobian; + CSR_Matrix *Jx = x->jacobian; + + for (int i = 0; i < idx->n_idxs; i++) + { + int row = idx->indices[i]; + int len = J->p[i + 1] - J->p[i]; + memcpy(J->x + J->p[i], Jx->x + Jx->p[row], len * sizeof(double)); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's wsum_hess */ + x->wsum_hess_init(x); + + /* for setting weight vector to evaluate hessian of child */ + node->dwork = (double *) calloc(x->size, sizeof(double)); + + /* in the implementation of eval_wsum_hess we evaluate the + child's hessian with a weight vector that has w[i] = 0 + if i is not included in idx->indices. This can lead to + many numerical zeros in child->wsum_hess that are actually + structural zeros, but we do not try to exploit that sparsity + right now. */ + CSR_Matrix *H_child = x->wsum_hess; + node->wsum_hess = new_csr_matrix(H_child->m, H_child->n, H_child->nnz); + memcpy(node->wsum_hess->p, H_child->p, (H_child->m + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, H_child->i, H_child->nnz * sizeof(int)); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *child = node->left; + index_expr *idx = (index_expr *) node; + + if (idx->has_duplicates) + { + /* zero and accumulate for repeated indices */ + memset(node->dwork, 0, child->size * sizeof(double)); + for (int i = 0; i < idx->n_idxs; i++) + { + node->dwork[idx->indices[i]] += w[i]; + } + } + else + { + /* direct write (no memset needed, no accumulation) */ + for (int i = 0; i < idx->n_idxs; i++) + { + node->dwork[idx->indices[i]] = w[i]; + } + } + + /* delegate to child */ + child->eval_wsum_hess(child, node->dwork); + + /* copy values from child */ + memcpy(node->wsum_hess->x, child->wsum_hess->x, + child->wsum_hess->nnz * sizeof(double)); +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +static void free_type_data(expr *node) +{ + index_expr *idx = (index_expr *) node; + if (idx->indices) + { + free(idx->indices); + idx->indices = NULL; + } +} + +expr *new_index(expr *child, const int *indices, int n_idxs) +{ + /* allocate type-specific struct */ + index_expr *idx = (index_expr *) calloc(1, sizeof(index_expr)); + expr *node = &idx->base; + + /* output shape is (n_idxs, 1) - flattened */ + init_expr(node, n_idxs, 1, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); + + node->left = child; + expr_retain(child); + + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; + + /* copy indices */ + idx->indices = (int *) malloc(n_idxs * sizeof(int)); + memcpy(idx->indices, indices, n_idxs * sizeof(int)); + idx->n_idxs = n_idxs; + + /* detect duplicates for Hessian optimization */ + idx->has_duplicates = check_for_duplicates(indices, n_idxs, child->size); + + return node; +} diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index 3a10c67..a8cda90 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -100,6 +100,50 @@ def _convert_multiply(expr, children): return _diffengine.make_multiply(children[0], children[1]) +def _extract_flat_indices_from_index(expr): + """Extract flattened indices from CVXPY index expression.""" + parent_shape = expr.args[0].shape + indices_per_dim = [np.arange(s.start, s.stop, s.step) for s in expr.key] + + if len(indices_per_dim) == 1: + return indices_per_dim[0].astype(np.int32) + elif len(indices_per_dim) == 2: + # Fortran order: idx = row + col * n_rows + return np.add.outer( + indices_per_dim[0], indices_per_dim[1] * parent_shape[0] + ).flatten(order="F").astype(np.int32) + else: + raise NotImplementedError("index with >2 dimensions not supported") + + +def _extract_flat_indices_from_special_index(expr): + """Extract flattened indices from CVXPY special_index expression.""" + return np.reshape(expr._select_mat, expr._select_mat.size, order="F").astype(np.int32) + + +def _convert_reshape(expr, children): + """Convert reshape - only Fortran order is supported. + + For Fortran order, reshape is a no-op since the underlying data layout + is unchanged. We just pass through the child expression. + + Note: Only order='F' (Fortran/column-major) is supported. This is the + default and most common case in CVXPY. C order would require permutation. + """ + if expr.order != 'F': + raise NotImplementedError( + f"reshape with order='{expr.order}' not supported. " + "Only order='F' (Fortran) is currently supported." + ) + # Pass through - data layout is unchanged with Fortran order + return children[0] + + +def _convert_rel_entr(_expr, children): + """Convert rel_entr(x, y) = x * log(x/y) elementwise.""" + return _diffengine.make_rel_entr(children[0], children[1]) + + def _convert_quad_form(expr, children): """Convert quadratic form x.T @ P @ x.""" @@ -145,6 +189,10 @@ def _convert_quad_form(expr, children): # Bivariate "multiply": _convert_multiply, "QuadForm": _convert_quad_form, + "quad_over_lin": lambda _expr, children: _diffengine.make_quad_over_lin( + children[0], children[1] + ), + "rel_entr": _convert_rel_entr, # Matrix multiplication "MulExpression": _convert_matmul, # Elementwise univariate with parameter @@ -162,6 +210,15 @@ def _convert_quad_form(expr, children): "entr": lambda _expr, children: _diffengine.make_entr(children[0]), "logistic": lambda _expr, children: _diffengine.make_logistic(children[0]), "xexp": lambda _expr, children: _diffengine.make_xexp(children[0]), + # Indexing/slicing + "index": lambda expr, children: _diffengine.make_index( + children[0], _extract_flat_indices_from_index(expr) + ), + "special_index": lambda expr, children: _diffengine.make_index( + children[0], _extract_flat_indices_from_special_index(expr) + ), + # Reshape (Fortran order only - pass-through since data layout unchanged) + "reshape": _convert_reshape, } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 04badb8..1fb68a3 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -27,7 +27,6 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { - CSR_Matrix *Q = ((quad_form_expr *) node)->Q; assert(node->left->var_id != NOT_A_VARIABLE); assert(node->left->d2 == 1); expr *x = node->left; diff --git a/tests/all_tests.c b/tests/all_tests.c index 5d65653..9c7692d 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -18,6 +18,7 @@ #include "jacobian_tests/test_const_vector_mult.h" #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" +#include "jacobian_tests/test_index.h" #include "jacobian_tests/test_left_matmul.h" #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_neg.h" @@ -43,6 +44,7 @@ #include "wsum_hess/test_const_scalar_mult.h" #include "wsum_hess/test_const_vector_mult.h" #include "wsum_hess/test_hstack.h" +#include "wsum_hess/test_index.h" #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" @@ -111,6 +113,12 @@ int main(void) mu_run_test(test_jacobian_sum_log_axis_1, tests_run); mu_run_test(test_jacobian_hstack_vectors, tests_run); mu_run_test(test_jacobian_hstack_matrix, tests_run); + mu_run_test(test_index_forward_simple, tests_run); + mu_run_test(test_index_forward_repeated, tests_run); + mu_run_test(test_index_jacobian_of_variable, tests_run); + mu_run_test(test_index_jacobian_of_log, tests_run); + mu_run_test(test_index_jacobian_repeated, tests_run); + mu_run_test(test_sum_of_index, tests_run); mu_run_test(test_promote_scalar_jacobian, tests_run); mu_run_test(test_promote_scalar_to_matrix_jacobian, tests_run); mu_run_test(test_wsum_hess_multiply_1, tests_run); @@ -150,6 +158,9 @@ int main(void) mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); mu_run_test(test_wsum_hess_hstack, tests_run); mu_run_test(test_wsum_hess_hstack_matrix, tests_run); + mu_run_test(test_wsum_hess_index_log, tests_run); + mu_run_test(test_wsum_hess_index_repeated, tests_run); + mu_run_test(test_wsum_hess_sum_index_log, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_yx, tests_run); mu_run_test(test_wsum_hess_quad_form, tests_run); diff --git a/tests/jacobian_tests/test_index.h b/tests/jacobian_tests/test_index.h new file mode 100644 index 0000000..7f3c905 --- /dev/null +++ b/tests/jacobian_tests/test_index.h @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_index_forward_simple(void) +{ + /* x = [1, 2, 3], select indices [0, 2] -> [1, 3] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + + double expected[2] = {1.0, 3.0}; + mu_assert("index forward simple", cmp_double_array(idx->value, expected, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_forward_repeated(void) +{ + /* x = [1, 2, 3], select [0, 0, 2] -> [1, 1, 3] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[3] = {0, 0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 3); + idx->forward(idx, u); + + double expected[3] = {1.0, 1.0, 3.0}; + mu_assert("index forward repeated", cmp_double_array(idx->value, expected, 3)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_of_variable(void) +{ + /* x[0, 2] where x is variable of size 3 + * Jacobian should select rows 0 and 2 of identity */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* Jacobian is 2x3 with pattern: row 0 selects col 0, row 1 selects col 2 */ + double expected_x[2] = {1.0, 1.0}; + int expected_p[3] = {0, 1, 2}; /* CSR row ptrs */ + int expected_i[2] = {0, 2}; /* column indices */ + + mu_assert("index jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index jac p", cmp_int_array(idx->jacobian->p, expected_p, 3)); + mu_assert("index jac i", cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_of_log(void) +{ + /* log(x)[0, 2] - jacobian should be selected rows of diag(1/x) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* d/dx log(x) = diag(1/x), then select rows 0 and 2 + * Row 0: 1/1 = 1.0 at col 0 + * Row 1: 1/4 = 0.25 at col 2 */ + double expected_x[2] = {1.0, 0.25}; + int expected_i[2] = {0, 2}; + + mu_assert("index of log jac vals", + cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index of log jac cols", + cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_repeated(void) +{ + /* x[0, 0] where x is size 3 - both outputs depend on x[0] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 0}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* Both rows select column 0 */ + double expected_x[2] = {1.0, 1.0}; + int expected_p[3] = {0, 1, 2}; + int expected_i[2] = {0, 0}; /* Both reference col 0 */ + + mu_assert("index repeated jac vals", + cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index repeated row ptr", + cmp_int_array(idx->jacobian->p, expected_p, 4)); + mu_assert("index repeated jac i", + cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_sum_of_index(void) +{ + /* sum(x[0, 2]) = x[0] + x[2] + * Gradient should be sparse with entries at cols 0 and 2 */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + expr *s = new_sum(idx, -1); /* sum all */ + + s->forward(s, u); + s->jacobian_init(s); + s->eval_jacobian(s); + + /* Gradient: [1, 0, 1] in sparse form */ + double expected_x[2] = {1.0, 1.0}; + int expected_i[2] = {0, 2}; + + mu_assert("sum of index vals", cmp_double_array(s->jacobian->x, expected_x, 2)); + mu_assert("sum of index cols", cmp_int_array(s->jacobian->i, expected_i, 2)); + + free_expr(s); + return 0; +} diff --git a/tests/wsum_hess/test_index.h b/tests/wsum_hess/test_index.h new file mode 100644 index 0000000..0a2fb9c --- /dev/null +++ b/tests/wsum_hess/test_index.h @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_wsum_hess_index_log(void) +{ + /* log(x)[0, 2] with w = [1, 2] + * Hessian of log is diag(-1/x^2) + * After scatter: parent_w = [1, 0, 2] + * Hessian: diag(-1/x^2) weighted by [1, 0, 2] = diag([-1, 0, -0.125]) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + double w[2] = {1.0, 2.0}; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->wsum_hess_init(idx); + idx->eval_wsum_hess(idx, w); + + /* Expected diagonal values: + * H[0,0] = -1 * 1/1^2 = -1.0 + * H[1,1] = 0 (no weight) + * H[2,2] = -2 * 1/16 = -0.125 */ + double expected_x[3] = {-1.0, 0.0, -0.125}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("index log hess vals", cmp_double_array(idx->wsum_hess->x, expected_x, 3)); + mu_assert("index log hess p", cmp_int_array(idx->wsum_hess->p, expected_p, 4)); + mu_assert("index log hess i", cmp_int_array(idx->wsum_hess->i, expected_i, 3)); + + free_expr(idx); + return 0; +} + +const char *test_wsum_hess_index_repeated(void) +{ + /* log(x)[0, 0] with w = [1, 2] - weights should accumulate + * parent_w = [3, 0, 0] (1+2 at position 0) */ + double u[3] = {2.0, 3.0, 4.0}; + int indices[2] = {0, 0}; + double w[2] = {1.0, 2.0}; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->wsum_hess_init(idx); + idx->eval_wsum_hess(idx, w); + + /* Hessian of log at x=2 is -1/4 + * weighted by 3 (accumulated) -> -3/4 = -0.75 + * Other positions have 0 weight */ + double expected_x[3] = {-0.75, 0.0, 0.0}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("index repeated hess vals", cmp_double_array(idx->wsum_hess->x, expected_x, 3)); + mu_assert("index repeated hess p", cmp_int_array(idx->wsum_hess->p, expected_p, 4)); + mu_assert("index repeated hess i", cmp_int_array(idx->wsum_hess->i, expected_i, 3)); + + free_expr(idx); + return 0; +} + +const char *test_wsum_hess_sum_index_log(void) +{ + /* sum(log(x)[0, 2]) with w = 1.0 + * This tests the chain: sum -> index -> log -> variable + * Gradient: [1/x[0], 0, 1/x[2]] + * Hessian: diag([-1/x[0]^2, 0, -1/x[2]^2]) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + double w = 1.0; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + expr *sum_node = new_sum(idx, -1); + + sum_node->forward(sum_node, u); + sum_node->jacobian_init(sum_node); + sum_node->wsum_hess_init(sum_node); + sum_node->eval_wsum_hess(sum_node, &w); + + /* Expected diagonal values: + * H[0,0] = -1 * 1/1^2 = -1.0 + * H[1,1] = 0 (index 1 not selected) + * H[2,2] = -1 * 1/16 = -0.0625 */ + double expected_x[3] = {-1.0, 0.0, -0.0625}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("sum index log hess vals", + cmp_double_array(sum_node->wsum_hess->x, expected_x, 3)); + mu_assert("sum index log hess p", cmp_int_array(sum_node->wsum_hess->p, expected_p, 4)); + mu_assert("sum index log hess i", cmp_int_array(sum_node->wsum_hess->i, expected_i, 3)); + + free_expr(sum_node); + return 0; +}