From 32268fdc4f5b6ad2ec46481247fda6be92d6140e Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 22 Jan 2026 09:20:37 -0800 Subject: [PATCH] trace implementation --- include/affine.h | 1 + python/atoms/trace.h | 1 - python/atoms/transpose.h | 32 +++++++ python/bindings.c | 5 +- src/affine/transpose.c | 123 ++++++++++++++++++++++++++ tests/all_tests.c | 4 + tests/jacobian_tests/test_transpose.h | 47 ++++++++++ tests/wsum_hess/test_transpose.h | 40 +++++++++ 8 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 python/atoms/transpose.h create mode 100644 src/affine/transpose.c create mode 100644 tests/jacobian_tests/test_transpose.h create mode 100644 tests/wsum_hess/test_transpose.h diff --git a/include/affine.h b/include/affine.h index 5905e61..88b1f7b 100644 --- a/include/affine.h +++ b/include/affine.h @@ -21,5 +21,6 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars); expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs); expr *new_reshape(expr *child, int d1, int d2); expr *new_broadcast(expr *child, int target_d1, int target_d2); +expr *new_transpose(expr *child); #endif /* AFFINE_H */ diff --git a/python/atoms/trace.h b/python/atoms/trace.h index 9007a9a..e130aa2 100644 --- a/python/atoms/trace.h +++ b/python/atoms/trace.h @@ -1,5 +1,4 @@ #ifndef ATOM_TRACE_H - #define ATOM_TRACE_H #include "common.h" diff --git a/python/atoms/transpose.h b/python/atoms/transpose.h new file mode 100644 index 0000000..85c1dcc --- /dev/null +++ b/python/atoms/transpose.h @@ -0,0 +1,32 @@ +#ifndef PYTHON_ATOMS_TRANSPOSE_H +#define PYTHON_ATOMS_TRANSPOSE_H + +#include "common.h" + +// Python binding for the transpose atom +static PyObject *py_make_transpose(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_transpose(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create trace node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif // PYTHON_ATOMS_TRANSPOSE_H diff --git a/python/bindings.c b/python/bindings.c index 10f263d..71d702a 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -41,6 +41,7 @@ #include "atoms/tan.h" #include "atoms/tanh.h" #include "atoms/trace.h" +#include "atoms/transpose.h" #include "atoms/variable.h" #include "atoms/xexp.h" @@ -71,8 +72,8 @@ static PyMethodDef DNLPMethods[] = { {"make_exp", py_make_exp, METH_VARARGS, "Create exp node"}, {"make_index", py_make_index, METH_VARARGS, "Create index node"}, {"make_add", py_make_add, METH_VARARGS, "Create add node"}, - {"make_trace", py_make_trace, METH_VARARGS, - "Create trace node from an expr capsule (make_trace(child))"}, + {"make_trace", py_make_trace, METH_VARARGS, "Create trace node"}, + {"make_transpose", py_make_transpose, METH_VARARGS, "Create transpose node"}, {"make_hstack", py_make_hstack, METH_VARARGS, "Create hstack node from list of expr capsules and n_vars (make_hstack([e1, " "e2, ...], n_vars))"}, diff --git a/src/affine/transpose.c b/src/affine/transpose.c new file mode 100644 index 0000000..7933631 --- /dev/null +++ b/src/affine/transpose.c @@ -0,0 +1,123 @@ +#include "affine.h" +#include +#include +#include + +// Forward pass for transpose atom +static void forward(expr *node, const double *u) +{ + /* forward pass for child */ + node->left->forward(node->left, u); + + /* local forward pass */ + int d1 = node->d1; + int d2 = node->d2; + double *X = node->left->value; + double *XT = node->value; + for (int i = 0; i < d1; ++i) + { + for (int j = 0; j < d2; ++j) + { + XT[j * d1 + i] = X[i * d2 + j]; + } + } +} + +static void jacobian_init(expr *node) +{ + expr *child = node->left; + child->jacobian_init(child); + CSR_Matrix *Jc = child->jacobian; + node->jacobian = new_csr_matrix(node->size, node->n_vars, Jc->nnz); + + /* fill sparsity */ + CSR_Matrix *J = node->jacobian; + int d1 = node->d1; + int d2 = node->d2; + int nnz = 0; + J->p[0] = 0; + + /* 'k' is the old row that gets swapped to 'row'*/ + int k, len; + for (int row = 0; row < J->m; ++row) + { + k = (row / d1) + (row % d1) * d2; + len = Jc->p[k + 1] - Jc->p[k]; + memcpy(J->i + nnz, Jc->i + Jc->p[k], len * sizeof(int)); + nnz += len; + J->p[row + 1] = nnz; + } +} + +static void eval_jacobian(expr *node) +{ + expr *child = node->left; + child->eval_jacobian(child); + CSR_Matrix *Jc = child->jacobian; + CSR_Matrix *J = node->jacobian; + + int d1 = node->d1; + int d2 = node->d2; + int nnz = 0; + for (int row = 0; row < J->m; ++row) + { + int k = (row / d1) + (row % d1) * d2; + int len = Jc->p[k + 1] - Jc->p[k]; + memcpy(J->x + nnz, Jc->x + Jc->p[k], len * sizeof(double)); + nnz += len; + } +} + +static void wsum_hess_init(expr *node) +{ + /* initialize child */ + expr *x = node->left; + x->wsum_hess_init(x); + + /* same sparsity pattern as child */ + CSR_Matrix *H = node->wsum_hess; + H = new_csr_matrix(x->wsum_hess->m, node->n_vars, x->wsum_hess->nnz); + memcpy(H->p, x->wsum_hess->p, (H->m + 1) * sizeof(int)); + memcpy(H->i, x->wsum_hess->i, H->nnz * sizeof(int)); + node->wsum_hess = H; + + /* for computing Kw where K is the commutation matrix */ + node->dwork = (double *) malloc(node->size * sizeof(double)); +} +static void eval_wsum_hess(expr *node, const double *w) +{ + int d2 = node->d2; + int d1 = node->d1; + // TODO: meaybe more efficient to do this with memcpy first + + /* evaluate hessian of child at Kw */ + for (int i = 0; i < d2; ++i) + { + for (int j = 0; j < d1; ++j) + { + node->dwork[j * d2 + i] = w[i * d1 + j]; + } + } + + node->left->eval_wsum_hess(node->left, node->dwork); + + /* copy to this node's hessian */ + memcpy(node->wsum_hess->x, node->left->wsum_hess->x, + node->wsum_hess->nnz * sizeof(double)); +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +expr *new_transpose(expr *child) +{ + expr *node = (expr *) calloc(1, sizeof(expr)); + init_expr(node, child->d2, child->d1, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); + node->left = child; + expr_retain(child); + + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 55aef4f..4e7571a 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -40,6 +40,7 @@ #include "jacobian_tests/test_right_matmul.h" #include "jacobian_tests/test_sum.h" #include "jacobian_tests/test_trace.h" +#include "jacobian_tests/test_transpose.h" #include "problem/test_problem.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" @@ -70,6 +71,7 @@ #include "wsum_hess/test_right_matmul.h" #include "wsum_hess/test_sum.h" #include "wsum_hess/test_trace.h" +#include "wsum_hess/test_transpose.h" int main(void) { @@ -161,6 +163,7 @@ int main(void) mu_run_test(test_jacobian_right_matmul_log, tests_run); mu_run_test(test_jacobian_right_matmul_log_vector, tests_run); mu_run_test(test_jacobian_matmul, tests_run); + mu_run_test(test_jacobian_transpose, tests_run); printf("\n--- Weighted Sum of Hessian Tests ---\n"); mu_run_test(test_wsum_hess_log, tests_run); @@ -225,6 +228,7 @@ int main(void) mu_run_test(test_wsum_hess_trace_variable, tests_run); mu_run_test(test_wsum_hess_trace_log_variable, tests_run); mu_run_test(test_wsum_hess_trace_composite, tests_run); + mu_run_test(test_wsum_hess_transpose, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_diag_csr_mult, tests_run); diff --git a/tests/jacobian_tests/test_transpose.h b/tests/jacobian_tests/test_transpose.h new file mode 100644 index 0000000..8d71316 --- /dev/null +++ b/tests/jacobian_tests/test_transpose.h @@ -0,0 +1,47 @@ + +#ifndef TEST_TRANSPOSE_H +#define TEST_TRANSPOSE_H + +#include "affine.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_jacobian_transpose() +{ + // A = [1 2; 3 4] + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int A_p[3] = {0, 2, 4}; + int A_i[4] = {0, 1, 0, 1}; + double A_x[4] = {1, 2, 3, 4}; + memcpy(A->p, A_p, 3 * sizeof(int)); + memcpy(A->i, A_i, 4 * sizeof(int)); + memcpy(A->x, A_x, 4 * sizeof(double)); + + // X = [1 2; 3 4] (columnwise: x = [1 3 2 4]) + expr *X = new_variable(2, 2, 0, 4); + expr *AX = new_left_matmul(X, A); + expr *transpose_AX = new_transpose(AX); + double u[4] = {1, 3, 2, 4}; + transpose_AX->forward(transpose_AX, u); + transpose_AX->jacobian_init(transpose_AX); + transpose_AX->eval_jacobian(transpose_AX); + + // Jacobian of transpose_AX + double expected_x[8] = {1, 2, 1, 2, 3, 4, 3, 4}; + int expected_p[5] = {0, 2, 4, 6, 8}; + int expected_i[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + + mu_assert("jacobian values fail", + cmp_double_array(transpose_AX->jacobian->x, expected_x, 8)); + mu_assert("jacobian row ptr fail", + cmp_int_array(transpose_AX->jacobian->p, expected_p, 5)); + mu_assert("jacobian col idx fail", + cmp_int_array(transpose_AX->jacobian->i, expected_i, 8)); + free_expr(AX); + free_csr_matrix(A); + return 0; +} + +#endif // TEST_TRANSPOSE_H diff --git a/tests/wsum_hess/test_transpose.h b/tests/wsum_hess/test_transpose.h new file mode 100644 index 0000000..4e113f6 --- /dev/null +++ b/tests/wsum_hess/test_transpose.h @@ -0,0 +1,40 @@ +#ifndef TEST_WSUM_HESS_TRANSPOSE_H +#define TEST_WSUM_HESS_TRANSPOSE_H + +#include "affine.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_wsum_hess_transpose() +{ + + expr *X = new_variable(2, 2, 0, 8); + expr *Y = new_variable(2, 2, 4, 8); + + expr *XY = new_matmul(X, Y); + expr *XYT = new_transpose(XY); + + double u[8] = {1, 3, 2, 4, 5, 7, 6, 8}; + XYT->forward(XYT, u); + XYT->wsum_hess_init(XYT); + double w[4] = {1, 2, 3, 4}; + XYT->eval_wsum_hess(XYT, w); + + double expected_x[16] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 3, 1, 3, 2, 4, 2, 4}; + int expected_p[9] = {0, 2, 4, 6, 8, 10, 12, 14, 16}; + int expected_i[16] = {4, 6, 4, 6, 5, 7, 5, 7, 0, 1, 2, 3, 0, 1, 2, 3}; + + mu_assert("hess values fail", + cmp_double_array(XYT->wsum_hess->x, expected_x, 8)); + mu_assert("jacobian row ptr fail", + cmp_int_array(XYT->wsum_hess->p, expected_p, 5)); + mu_assert("jacobian col idx fail", + cmp_int_array(XYT->wsum_hess->i, expected_i, 8)); + free_expr(XYT); + + return 0; +} + +#endif // TEST_WSUM_HESS_TRANSPOSE_H