From b255da25263d1e6d32d075c9d9da998f86af588d Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 17 Jan 2026 09:32:29 -0800 Subject: [PATCH 1/5] some progress --- TODO.md | 2 +- include/other.h | 3 + include/subexpr.h | 9 + python/atoms/prod_axis_zero.h | 33 ++ python/bindings.c | 3 + src/other/prod_axis_zero.c | 348 +++++++++++++++++++++ tests/all_tests.c | 8 + tests/forward_pass/test_prod_axis_zero.h | 33 ++ tests/jacobian_tests/test_prod_axis_zero.h | 45 +++ tests/wsum_hess/test_prod_axis_zero.h | 251 +++++++++++++++ 10 files changed, 734 insertions(+), 1 deletion(-) create mode 100644 python/atoms/prod_axis_zero.h create mode 100644 src/other/prod_axis_zero.c create mode 100644 tests/forward_pass/test_prod_axis_zero.h create mode 100644 tests/jacobian_tests/test_prod_axis_zero.h create mode 100644 tests/wsum_hess/test_prod_axis_zero.h diff --git a/TODO.md b/TODO.md index 8854918..f693841 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ 4. in the refactor, add consts 10. For performance reasons, is it useful to have a dense matmul with A and B as dense matrices? -11. right matmul, add broadcasting logic as in left matmul +11. right matmul, add broadcasting logic as in left matmul. Is this necessary? Going through all atoms to see that sparsity pattern is computed in initialization of jacobian: 2. 
trace - not ok diff --git a/include/other.h b/include/other.h index d44871a..8bf1d53 100644 --- a/include/other.h +++ b/include/other.h @@ -10,4 +10,7 @@ expr *new_quad_form(expr *child, CSR_Matrix *Q); /* product of all entries, without axis argument */ expr *new_prod(expr *child); +/* product of entries along axis=0 (columnwise products) */ +expr *new_prod_axis_zero(expr *child); + #endif /* OTHER_H */ diff --git a/include/subexpr.h b/include/subexpr.h index fd0a5b2..e96c40f 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -56,6 +56,15 @@ typedef struct prod_expr double prod_nonzero; /* product of non-zero elements */ } prod_expr; +/* Product of entries along axis=0 (columnwise products) */ +typedef struct prod_axis_zero_expr +{ + expr base; + int *num_of_zeros; /* num of zeros for each column */ + int *zero_index; /* stores idx of zero element per column */ + double *prod_nonzero; /* product of non-zero elements per column */ +} prod_axis_zero_expr; + /* Horizontal stack (concatenate) */ typedef struct hstack_expr { diff --git a/python/atoms/prod_axis_zero.h b/python/atoms/prod_axis_zero.h new file mode 100644 index 0000000..02b30f0 --- /dev/null +++ b/python/atoms/prod_axis_zero.h @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +#ifndef ATOM_PROD_AXIS_ZERO_H +#define ATOM_PROD_AXIS_ZERO_H + +#include "common.h" +#include "other.h" + +static PyObject *py_make_prod_axis_zero(PyObject *self, PyObject *args) +{ + (void) self; + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_prod_axis_zero(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create prod_axis_zero node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, 
EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_PROD_AXIS_ZERO_H */ diff --git a/python/bindings.c b/python/bindings.c index 55fbdc2..d83b271 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -23,6 +23,7 @@ #include "atoms/neg.h" #include "atoms/power.h" #include "atoms/prod.h" +#include "atoms/prod_axis_zero.h" #include "atoms/promote.h" #include "atoms/quad_form.h" #include "atoms/quad_over_lin.h" @@ -77,6 +78,8 @@ static PyMethodDef DNLPMethods[] = { "Create constant vector multiplication node (a ∘ f(x))"}, {"make_power", py_make_power, METH_VARARGS, "Create power node"}, {"make_prod", py_make_prod, METH_VARARGS, "Create prod node"}, + {"make_prod_axis_zero", py_make_prod_axis_zero, METH_VARARGS, + "Create prod_axis_zero node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c new file mode 100644 index 0000000..8d1807b --- /dev/null +++ b/src/other/prod_axis_zero.c @@ -0,0 +1,348 @@ +#include "other.h" +#include +#include +#include +#include + +#define IS_ZERO(x) (fabs((x)) < 1e-8) + +static void forward(expr *node, const double *u) +{ + /* forward pass of child */ + expr *x = node->left; + x->forward(x, u); + + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + int d1 = x->d1; + int d2 = x->d2; + + /* iterate over columns and compute columnwise products */ + for (int col = 0; col < d2; col++) + { + double prod_nonzero = 1.0; + int zeros = 0; + int zero_row = -1; + int start = col * d1; + int end = start + d1; + + /* iterate over rows of this column */ + for (int idx = start; idx < end; idx++) + { + if (IS_ZERO(x->value[idx])) + { + zeros++; + zero_row = idx - start; + } + else + { + prod_nonzero *= x->value[idx]; + } + } + + /* store results for this column */ + pnode->num_of_zeros[col] = zeros; + pnode->zero_index[col] = 
zero_row; + pnode->prod_nonzero[col] = prod_nonzero; + node->value[col] = (zeros > 0) ? 0.0 : prod_nonzero; + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's jacobian */ + x->jacobian_init(x); + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + node->jacobian = new_csr_matrix(node->size, node->n_vars, x->size); + + /* set row pointers (each row has d1 nnzs) */ + for (int row = 0; row < x->d2; row++) + { + node->jacobian->p[row] = row * x->d1; + } + node->jacobian->p[x->d2] = x->size; + + /* set column indices */ + for (int col = 0; col < x->d2; col++) + { + int start = col * x->d1; + for (int i = 0; i < x->d1; i++) + { + node->jacobian->i[start + i] = x->var_id + start + i; + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + + double *J_vals = node->jacobian->x; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* process each column */ + for (int col = 0; col < x->d2; col++) + { + int num_zeros = pnode->num_of_zeros[col]; + int start = col * x->d1; + + if (num_zeros == 0) + { + for (int i = 0; i < x->d1; i++) + { + J_vals[start + i] = node->value[col] / x->value[start + i]; + } + } + else if (num_zeros == 1) + { + memset(J_vals + start, 0, sizeof(double) * x->d1); + J_vals[start + pnode->zero_index[col]] = pnode->prod_nonzero[col]; + } + else + { + memset(J_vals + start, 0, sizeof(double) * x->d1); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* Hessian has block diagonal structure: d2 blocks of size d1 x d1 */ + int nnz = x->d2 * x->d1 * x->d1; + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); + CSR_Matrix *H = node->wsum_hess; + + /* 
fill row pointers for the variable's rows (block diagonal) */ + for (int i = 0; i < x->size; i++) + { + int block_idx = i / x->d1; + int row_in_block = i % x->d1; + H->p[x->var_id + i] = block_idx * x->d1 * x->d1 + row_in_block * x->d1; + } + + /* fill row pointers for rows after the variable */ + for (int i = x->var_id + x->size; i <= node->n_vars; i++) + { + H->p[i] = nnz; + } + + /* fill column indices for the d2 diagonal blocks */ + for (int block = 0; block < x->d2; block++) + { + int block_start_col = x->var_id + block * x->d1; + for (int i = 0; i < x->d1; i++) + { + for (int j = 0; j < x->d1; j++) + { + int idx = block * x->d1 * x->d1 + i * x->d1 + j; + H->i[idx] = block_start_col + j; + } + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int col, + int d1) +{ + expr *x = node->left; + double *H = node->wsum_hess->x; + int col_start = col * d1; + int block_start = col * d1 * d1; + double w_col = w[col]; + double f_col = node->value[col]; + + for (int i = 0; i < d1; i++) + { + for (int j = 0; j < d1; j++) + { + if (i == j) + { + H[block_start + i * d1 + j] = 0.0; + } + else + { + int idx_i = col_start + i; + int idx_j = col_start + j; + H[block_start + i * d1 + j] = + w_col * f_col / (x->value[idx_i] * x->value[idx_j]); + } + } + } +} + +static inline void wsum_hess_column_one_zero(expr *node, const double *w, int col, + int d1) +{ + expr *x = node->left; + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + double *H = node->wsum_hess->x; + int col_start = col * d1; + int block_start = col * d1 * d1; + + /* clear this column's block */ + memset(&H[block_start], 0, sizeof(double) * d1 * d1); + + int p = pnode->zero_index[col]; + double prod_nonzero = pnode->prod_nonzero[col]; + double w_prod = w[col] * prod_nonzero; + + /* fill row p and column p of this block */ + for (int j = 0; j < d1; j++) + { + if (j == p) continue; + + int idx_j = col_start + 
j; + double hess_val = w_prod / x->value[idx_j]; + H[block_start + p * d1 + j] = hess_val; + H[block_start + j * d1 + p] = hess_val; + } +} + +static inline void wsum_hess_column_two_zeros(expr *node, const double *w, int col, + int d1) +{ + expr *x = node->left; + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + double *H = node->wsum_hess->x; + int col_start = col * d1; + int block_start = col * d1 * d1; + + /* clear this column's block */ + memset(&H[block_start], 0, sizeof(double) * d1 * d1); + + /* find indices p and q where x[col_start + p] and x[col_start + q] are zero */ + int p = -1, q = -1; + for (int i = 0; i < d1; i++) + { + if (IS_ZERO(x->value[col_start + i])) + { + if (p == -1) + { + p = i; + } + else + { + q = i; + break; + } + } + } + assert(p != -1 && q != -1); + + double hess_val = w[col] * pnode->prod_nonzero[col]; + H[block_start + p * d1 + q] = hess_val; + H[block_start + q * d1 + p] = hess_val; +} + +static inline void wsum_hess_column_many_zeros(expr *node, const double *w, int col, + int d1) +{ + double *H = node->wsum_hess->x; + int block_start = col * d1 * d1; + + /* clear this column's block */ + memset(&H[block_start], 0, sizeof(double) * d1 * d1); + (void) w; +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + int d1 = x->d1; + int d2 = x->d2; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* process each column */ + for (int col = 0; col < d2; col++) + { + int num_zeros = pnode->num_of_zeros[col]; + + if (num_zeros == 0) + { + wsum_hess_column_no_zeros(node, w, col, d1); + } + else if (num_zeros == 1) + { + wsum_hess_column_one_zero(node, w, col, d1); + } + else if (num_zeros == 2) + { + wsum_hess_column_two_zeros(node, w, col, d1); + } + else + { + wsum_hess_column_many_zeros(node, w, col, d1); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static bool is_affine(const 
expr *node) +{ + (void) node; + return false; +} + +static void free_type_data(expr *node) +{ + prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + free(pnode->num_of_zeros); + free(pnode->zero_index); + free(pnode->prod_nonzero); +} + +expr *new_prod_axis_zero(expr *child) +{ + prod_axis_zero_expr *pnode = + (prod_axis_zero_expr *) calloc(1, sizeof(prod_axis_zero_expr)); + expr *node = &pnode->base; + + /* output is always a row vector 1 x d2 - TODO: is that correct? */ + init_expr(node, 1, child->d2, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, + free_type_data); + + /* allocate arrays to store per-column statistics */ + pnode->num_of_zeros = (int *) calloc(child->d2, sizeof(int)); + pnode->zero_index = (int *) calloc(child->d2, sizeof(int)); + pnode->prod_nonzero = (double *) calloc(child->d2, sizeof(double)); + + node->left = child; + expr_retain(child); + + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 0f8b9f6..164382b 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -14,6 +14,7 @@ #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" +#include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" #include "jacobian_tests/test_composite.h" #include "jacobian_tests/test_const_scalar_mult.h" @@ -25,6 +26,7 @@ #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_prod.h" +#include "jacobian_tests/test_prod_axis_zero.h" #include "jacobian_tests/test_promote.h" #include "jacobian_tests/test_quad_form.h" #include "jacobian_tests/test_quad_over_lin.h" @@ -53,6 +55,7 @@ #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" +#include "wsum_hess/test_prod_axis_zero.h" #include "wsum_hess/test_quad_form.h" #include "wsum_hess/test_quad_over_lin.h" 
#include "wsum_hess/test_rel_entr.h" @@ -86,6 +89,7 @@ int main(void) mu_run_test(test_broadcast_row, tests_run); mu_run_test(test_broadcast_col, tests_run); mu_run_test(test_broadcast_matrix, tests_run); + mu_run_test(test_forward_prod_axis_zero, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -118,6 +122,7 @@ int main(void) mu_run_test(test_jacobian_prod_no_zero, tests_run); mu_run_test(test_jacobian_prod_one_zero, tests_run); mu_run_test(test_jacobian_prod_two_zeros, tests_run); + mu_run_test(test_jacobian_prod_axis_zero, tests_run); mu_run_test(test_jacobian_sum_log, tests_run); mu_run_test(test_jacobian_sum_mult, tests_run); mu_run_test(test_jacobian_sum_log_axis_0, tests_run); @@ -168,6 +173,9 @@ int main(void) mu_run_test(test_wsum_hess_prod_one_zero, tests_run); mu_run_test(test_wsum_hess_prod_two_zeros, tests_run); mu_run_test(test_wsum_hess_prod_many_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_no_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_one_zero, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_mixed_zeros, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); diff --git a/tests/forward_pass/test_prod_axis_zero.h b/tests/forward_pass/test_prod_axis_zero.h new file mode 100644 index 0000000..6504502 --- /dev/null +++ b/tests/forward_pass/test_prod_axis_zero.h @@ -0,0 +1,33 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_forward_prod_axis_zero() +{ + /* Create a 2x3 constant matrix stored column-wise: + [1, 3, 5] + [2, 4, 6] + Stored as: [1, 2, 3, 4, 5, 6] + */ + double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + expr *const_node = new_constant(2, 3, 0, values); + expr *prod_node = new_prod_axis_zero(const_node); + 
prod_node->forward(prod_node, NULL); + + /* Expected: columnwise products, result is 1x3 + [1*2, 3*4, 5*6] = [2, 12, 30] + */ + double expected[3] = {2.0, 12.0, 30.0}; + + mu_assert("prod_axis_zero forward test failed", + cmp_double_array(prod_node->value, expected, 3)); + + free_expr(prod_node); + return 0; +} diff --git a/tests/jacobian_tests/test_prod_axis_zero.h b/tests/jacobian_tests/test_prod_axis_zero.h new file mode 100644 index 0000000..7b3179e --- /dev/null +++ b/tests/jacobian_tests/test_prod_axis_zero.h @@ -0,0 +1,45 @@ +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_jacobian_prod_axis_zero() +{ + /* x is 2x3 variable, global index 1, total 8 vars + * x = [1, 2, 3, 4, 5, 6] (column-major order) + * [1, 3, 5] + * [2, 4, 6] + * + * f = prod_axis_zero(x) = [2, 12, 30] + * + * Jacobian is 3x8: + * Row 0 (output[0] = 2): df[0]/dx = [2, 1, 0, 0, 0, 0, 0, 0] + * df[0]/dx[0] = 2/1 = 2, df[0]/dx[1] = 2/2 = 1 + * Row 1 (output[1] = 12): df[1]/dx = [0, 0, 4, 3, 0, 0, 0, 0] + * df[1]/dx[2] = 12/3 = 4, df[1]/dx[3] = 12/4 = 3 + * Row 2 (output[2] = 30): df[2]/dx = [0, 0, 0, 0, 6, 5, 0, 0] + * df[2]/dx[4] = 30/5 = 6, df[2]/dx[5] = 30/6 = 5 + */ + double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0}; + expr *x = new_variable(2, 3, 1, 8); + expr *p = new_prod_axis_zero(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x8 Jacobian with block diagonal structure */ + double expected_Ax[6] = {2.0, 1.0, 4.0, 3.0, 6.0, 5.0}; + int expected_Ap[4] = {0, 2, 4, 6}; + int expected_Ai[6] = {1, 2, 3, 4, 5, 6}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 6)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 6)); + + free_expr(p); + return 0; +} diff --git a/tests/wsum_hess/test_prod_axis_zero.h 
b/tests/wsum_hess/test_prod_axis_zero.h new file mode 100644 index 0000000..e81f97d --- /dev/null +++ b/tests/wsum_hess/test_prod_axis_zero.h @@ -0,0 +1,251 @@ +#include +#include + +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_wsum_hess_prod_axis_zero_no_zeros() +{ + /* x is 2x3 variable, global index 1, total 8 vars + * x = [1, 2, 3, 4, 5, 6] (column-major) + * [1, 3, 5] + * [2, 4, 6] + * f = prod_axis_zero(x) = [2, 12, 30] + * w = [1, 2, 3] + * + * Hessian is block diagonal with three 2x2 blocks: + * Block 0 (col 0): w[0] * f[0] = 1 * 2 = 2 + * H_00 = [0, 2/(1*2)] = [0, 1] + * H_01 = [2/(2*1), 0] = [1, 0] + * Block 1 (col 1): w[1] * f[1] = 2 * 12 = 24 + * H_00 = [0, 24/(3*4)] = [0, 2] + * H_01 = [24/(4*3), 0] = [2, 0] + * Block 2 (col 2): w[2] * f[2] = 3 * 30 = 90 + * H_00 = [0, 90/(5*6)] = [0, 3] + * H_01 = [90/(6*5), 0] = [3, 0] + */ + double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0}; + double w_vals[3] = {1.0, 2.0, 3.0}; + expr *x = new_variable(2, 3, 1, 8); + expr *p = new_prod_axis_zero(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + /* Block diagonal structure: 3 blocks of 2x2 = 12 nnz total + * Each block has 4 entries (2x2 dense) + */ + double expected_x[12] = {/* Block 0 */ + 0.0, 1.0, 1.0, 0.0, + /* Block 1 */ + 0.0, 2.0, 2.0, 0.0, + /* Block 2 */ + 0.0, 3.0, 3.0, 0.0}; + + /* Row pointers: variable is at global rows 1-6, so: + * p[0] = 0 (row 0: before variable) + * p[1-6] = block diagonal entries for the variable + * p[7-8] = 12 (rows 7-8: after variable) + */ + int expected_p[9] = {0, 0, 2, 4, 6, 8, 10, 12, 12}; + + /* Column indices: block diagonal structure + * Row 0: cols 1, 2 (global indices for block 0) + * Row 1: cols 1, 2 + * Row 2: cols 3, 4 (global indices for block 1) + * Row 3: cols 3, 4 + * Row 4: cols 5, 6 (global indices for block 2) + * Row 5: cols 5, 6 + */ + int expected_i[12] = {1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}; + +
mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 12)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 12)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_zero_mixed_zeros() +{ + /* x is 5x3 variable, global index 1, total 16 vars + * x = [1, 1, 1, 1, 1, 2, 0, 3, 4, 5, 1, 0, 0, 2, 3] (column-major) + * Matrix (column-major): + * [1, 2, 1] + * [1, 0, 0] + * [1, 3, 0] + * [1, 4, 2] + * [1, 5, 3] + * + * f = prod_axis_zero(x) = [1, 0, 0] + * Column 0: 5 nonzeros, prod = 1*1*1*1*1 = 1 + * Column 1: 1 zero at row 1, prod_nonzero = 2*3*4*5 = 120 + * Column 2: 2 zeros at rows 1,2, prod_nonzero = 1*2*3 = 6 + * + * w = [1, 2, 3] + * + * Block 0 (5x5): no zeros, diagonal=0, off-diagonal=w[0]*f[0]/(x[i]*x[j]) = + * 1/(1*1) = 1 + * Block 1 (5x5): 1 zero at row 1, nonzeros only in row/col 1 + * (1,0): 2*120/2 = 120, (1,2): 2*120/3 = 80, (1,3): 2*120/4 = 60, (1,4): + * 2*120/5 = 48 (mirrored in column 1). Block 2 (5x5): 2 zeros at rows 1,2, + * only (1,2) and (2,1) are nonzero: 3*6 = 18 + */ + double u_vals[16] = {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.0, + 3.0, 4.0, 5.0, 1.0, 0.0, 0.0, 2.0, 3.0}; + /* This maps to: x[var_id+0..4]=column 0, x[var_id+5..9]=column 1, + * x[var_id+10..14]=column 2 Column 0: [1, 1, 1, 1, 1] - all nonzero Column 1: + * [2, 0, 3, 4, 5] - one zero at index 1 Column 2: [1, 0, 0, 2, 3] - two zeros at + * indices 1, 2 + */ + double w_vals[3] = {1.0, 2.0, 3.0}; + expr *x = new_variable(5, 3, 1, 16); + expr *p = new_prod_axis_zero(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + /* Block 0: 5x5 all off-diagonal = 1.0, total = 25 entries (indices 0-24) + * Block 1: 5x5 with 1 zero, total = 25 entries (indices 25-49) + * Block 2: 5x5 with 2 zeros, total = 25 entries (indices 50-74) + * Total nnz = 75 + */ + double expected_x[75]; + memset(expected_x, 0, sizeof(expected_x)); + + /*
Block 0 (indices 0-24): all off-diagonal = 1.0 */ + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + if (i != j) + { + expected_x[i * 5 + j] = 1.0; + } + } + } + + /* Block 1 (indices 25-49): 1 zero at row 1 (index p=1), w[1]*prod_nonzero[1] = + * 2*120 = 240 Column 1 is x[5..9] = [2, 0, 3, 4, 5] (rows 0-4) Only row p=1 and + * column p=1 have nonzeros H[i,p] = H[p,i] = w_prod / x[i] for i != p H[0,1] = + * H[1,0] = 240/2 = 120 H[2,1] = H[1,2] = 240/3 = 80 H[3,1] = H[1,3] = 240/4 = 60 + * H[4,1] = H[1,4] = 240/5 = 48 + */ + /* Row 0 of block 1: only (0,1) is nonzero */ + expected_x[25 + 0 * 5 + 1] = 120.0; + /* Row 1 of block 1: (1,0), (1,2), (1,3), (1,4) are nonzero */ + expected_x[25 + 1 * 5 + 0] = 120.0; + expected_x[25 + 1 * 5 + 2] = 80.0; + expected_x[25 + 1 * 5 + 3] = 60.0; + expected_x[25 + 1 * 5 + 4] = 48.0; + /* Row 2 of block 1: only (2,1) is nonzero */ + expected_x[25 + 2 * 5 + 1] = 80.0; + /* Row 3 of block 1: only (3,1) is nonzero */ + expected_x[25 + 3 * 5 + 1] = 60.0; + /* Row 4 of block 1: only (4,1) is nonzero */ + expected_x[25 + 4 * 5 + 1] = 48.0; + + /* Block 2 (indices 50-74): 2 zeros at rows p=1, q=2, w[2]*prod_nonzero[2] = 3*6 + * = 18 Column 2 is x[10..14] = [1, 0, 0, 2, 3] (rows 0-4) With exactly 2 zeros, + * only H[p,q] and H[q,p] are nonzero H[1,2] = H[2,1] = 18 + */ + for (int i = 0; i < 5; i++) + for (int j = 0; j < 5; j++) expected_x[50 + i * 5 + j] = 0.0; + expected_x[50 + 1 * 5 + 2] = 18.0; /* H[1,2] */ + expected_x[50 + 2 * 5 + 1] = 18.0; /* H[2,1] */ + + int expected_p[17] = {0, 0, 5, 10, 15, 20, 25, 30, 35, + 40, 45, 50, 55, 60, 65, 70, 75}; + + /* Column indices: block diagonal structure + * Block 0 (rows 1-5): each row has cols [1, 2, 3, 4, 5] + * Block 1 (rows 6-10): each row has cols [6, 7, 8, 9, 10] + * Block 2 (rows 11-15): each row has cols [11, 12, 13, 14, 15] + */ + int expected_i[75]; + for (int block = 0; block < 3; block++) + { + int offset = block * 25; + int col_start = block * 5 + 1; + for (int 
row = 0; row < 5; row++) + { + for (int col = 0; col < 5; col++) + { + expected_i[offset + row * 5 + col] = col_start + col; + } + } + } + + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 17)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 75)); + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 75)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_zero_one_zero() +{ + /* Test with a column that has exactly 1 zero + * x is 2x2 variable, global index 1, total 5 vars + * x = [1.0, 1.0, 2.0, 0.0] (column-major) + * Matrix (column-major): + * [1, 2] + * [1, 0] + * + * f = prod_axis_zero(x) = [1, 0] + * Column 0: no zeros, prod = 1 + * Column 1: 1 zero at row 1, prod_nonzero = 2 + * w = [1, 2] + */ + double u_vals[5] = {0.0, 1.0, 1.0, 2.0, 0.0}; + double w_vals[2] = {1.0, 2.0}; + expr *x = new_variable(2, 2, 1, 5); + expr *p = new_prod_axis_zero(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + /* Block 0 (no zeros): w[0]*f[0] = 1 + * (0,1) = 1/(1*1) = 1, (1,0) = 1 + * Block 1 (1 zero at row 1): w[1]*prod_nonzero[1] = 2*2 = 4 + * (1,0) = 4/2 = 2 (symmetric) + */ + double expected_x[8]; + memset(expected_x, 0, sizeof(expected_x)); + + /* Block 0 */ + expected_x[0 * 2 + 1] = 1.0; /* (0,1) */ + expected_x[1 * 2 + 0] = 1.0; /* (1,0) */ + + /* Block 1: one zero at row 1 */ + expected_x[4 + 1 * 2 + 0] = 2.0; /* (1,0) */ + expected_x[4 + 0 * 2 + 1] = 2.0; /* (0,1) symmetric */ + + /* Row pointers: variable is at global rows 1-2, so: + * p[0] = 0 (row 0: before variable) + * p[1-2] = block 0 (rows 1-2, each has 2 entries) + * p[3-4] = block 1 (rows 3-4, each has 2 entries) + * p[5] = 8 (after variable) + */ + int expected_p[6] = {0, 0, 2, 4, 6, 8}; + + /* Column indices: block diagonal structure + * Block 0: cols 1, 2 (global indices for block 0) + * Block 1: cols 3, 4 (global indices for block 1) + */ + int expected_i[8] = {1, 2, 1, 2, 3, 4, 3, 4}; + + 
mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 8)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 6)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 8)); + + free_expr(p); + return 0; +} From 6921c074df41953d86f58004434d9c8dac95d699 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 17 Jan 2026 17:07:31 -0800 Subject: [PATCH 2/5] prod axis one some progress --- include/other.h | 3 + include/subexpr.h | 9 + src/other/prod_axis_one.c | 347 ++++++++++++++++++++++ tests/all_tests.c | 9 + tests/forward_pass/test_prod_axis_one.h | 35 +++ tests/jacobian_tests/test_prod_axis_one.h | 94 ++++++ tests/wsum_hess/test_prod_axis_one.h | 228 ++++++++++++++ 7 files changed, 725 insertions(+) create mode 100644 src/other/prod_axis_one.c create mode 100644 tests/forward_pass/test_prod_axis_one.h create mode 100644 tests/jacobian_tests/test_prod_axis_one.h create mode 100644 tests/wsum_hess/test_prod_axis_one.h diff --git a/include/other.h b/include/other.h index 8bf1d53..e523552 100644 --- a/include/other.h +++ b/include/other.h @@ -13,4 +13,7 @@ expr *new_prod(expr *child); /* product of entries along axis=0 (columnwise products) */ expr *new_prod_axis_zero(expr *child); +/* product of entries along axis=1 (rowwise products) */ +expr *new_prod_axis_one(expr *child); + #endif /* OTHER_H */ diff --git a/include/subexpr.h b/include/subexpr.h index e96c40f..01762f5 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -65,6 +65,15 @@ typedef struct prod_axis_zero_expr double *prod_nonzero; /* product of non-zero elements per column */ } prod_axis_zero_expr; +/* Product of entries along axis=1 (rowwise products) */ +typedef struct prod_axis_one_expr +{ + expr base; + int *num_of_zeros; /* num of zeros for each row */ + int *zero_index; /* stores idx of zero element per row */ + double *prod_nonzero; /* product of non-zero elements per row */ +} prod_axis_one_expr; + /* Horizontal stack (concatenate) */ typedef struct 
hstack_expr { diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c new file mode 100644 index 0000000..9aaff4f --- /dev/null +++ b/src/other/prod_axis_one.c @@ -0,0 +1,347 @@ +#include "other.h" +#include +#include +#include +#include + +#define IS_ZERO(x) (fabs((x)) < 1e-8) + +static void forward(expr *node, const double *u) +{ + /* forward pass of child */ + expr *x = node->left; + x->forward(x, u); + + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + int d1 = x->d1; + int d2 = x->d2; + + /* initialize per-row statistics */ + for (int i = 0; i < d1; i++) + { + pnode->num_of_zeros[i] = 0; + pnode->zero_index[i] = -1; + pnode->prod_nonzero[i] = 1.0; + } + + /* iterate over columns */ + for (int col = 0; col < d2; col++) + { + int start = col * d1; + int end = start + d1; + + for (int idx = start; idx < end; idx++) + { + int row = idx - start; + if (IS_ZERO(x->value[idx])) + { + pnode->num_of_zeros[row]++; + pnode->zero_index[row] = col; + } + else + { + pnode->prod_nonzero[row] *= x->value[idx]; + } + } + } + + /* compute output values */ + for (int i = 0; i < d1; i++) + { + node->value[i] = (pnode->num_of_zeros[i] > 0) ? 
0.0 : pnode->prod_nonzero[i]; + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's jacobian */ + x->jacobian_init(x); + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + node->jacobian = new_csr_matrix(node->size, node->n_vars, x->size); + + /* set row pointers (each row has d2 nnzs) */ + for (int row = 0; row < x->d1; row++) + { + node->jacobian->p[row] = row * x->d2; + } + node->jacobian->p[x->d1] = x->size; + + /* set column indices */ + for (int row = 0; row < x->d1; row++) + { + int start = row * x->d2; + for (int col = 0; col < x->d2; col++) + { + node->jacobian->i[start + col] = x->var_id + col * x->d1 + row; + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + + double *J_vals = node->jacobian->x; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* process each row */ + for (int row = 0; row < x->d1; row++) + { + int num_zeros = pnode->num_of_zeros[row]; + int start = row * x->d2; + + if (num_zeros == 0) + { + for (int col = 0; col < x->d2; col++) + { + int idx = col * x->d1 + row; + J_vals[start + col] = node->value[row] / x->value[idx]; + } + } + else if (num_zeros == 1) + { + memset(J_vals + start, 0, sizeof(double) * x->d2); + J_vals[start + pnode->zero_index[row]] = pnode->prod_nonzero[row]; + } + else + { + memset(J_vals + start, 0, sizeof(double) * x->d2); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* Hessian has block diagonal structure: d1 blocks of size d2 x d2 */ + int nnz = x->d1 * x->d2 * x->d2; + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); + CSR_Matrix *H = node->wsum_hess; + + /* fill row pointers for the 
variable's rows (block diagonal) */ + for (int i = 0; i < x->size; i++) + { + int block_idx = i % x->d1; /* row of the matrix */ + int row_in_block = i / x->d1; /* column index within that row-block */ + H->p[x->var_id + i] = block_idx * x->d2 * x->d2 + row_in_block * x->d2; + } + + /* fill row pointers for rows after the variable */ + for (int i = x->var_id + x->size; i <= node->n_vars; i++) + { + H->p[i] = nnz; + } + + /* set column indices for each block */ + for (int block = 0; block < x->d1; block++) + { + int block_start = block * x->d2 * x->d2; + int col_start = block; /* global column offset for this row-block */ + for (int row = 0; row < x->d2; row++) + { + for (int col = 0; col < x->d2; col++) + { + /* variable indices: row-block fixed, column-major layout */ + int var_col = + col * x->d1 + + block; /* global variable index offset from var_id */ + H->i[block_start + row * x->d2 + col] = x->var_id + var_col; + } + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static inline void wsum_hess_row_no_zeros(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + double *H = node->wsum_hess->x; + + int block_start = row * d2 * d2; + double scale = w[row] * node->value[row]; + + /* fill block for this row */ + for (int r = 0; r < d2; r++) + { + int idx_r = r * node->left->d1 + row; /* column-major index into x */ + for (int c = 0; c < d2; c++) + { + int idx_c = c * node->left->d1 + row; + double val = + (r == c) ? 
0.0 : scale / (x->value[idx_r] * x->value[idx_c]); + H[block_start + r * d2 + c] = val; + } + } + (void) pnode; /* suppress unused warning */ +} + +static inline void wsum_hess_row_one_zero(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + double *H = node->wsum_hess->x; + + int block_start = row * d2 * d2; + int p = pnode->zero_index[row]; /* zero column */ + double w_prod = w[row] * pnode->prod_nonzero[row]; + + /* clear block */ + memset(&H[block_start], 0, sizeof(double) * d2 * d2); + + /* fill row p and column p */ + int col_offset = p * node->left->d1 + row; + for (int c = 0; c < d2; c++) + { + if (c == p) continue; + int idx_c = c * node->left->d1 + row; + double hess_val = w_prod / x->value[idx_c]; + H[block_start + p * d2 + c] = hess_val; + H[block_start + c * d2 + p] = hess_val; + } + (void) col_offset; /* suppress unused */ +} + +static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + double *H = node->wsum_hess->x; + + int block_start = row * d2 * d2; + + /* clear block */ + memset(&H[block_start], 0, sizeof(double) * d2 * d2); + + /* find indices p and q where row has zeros */ + int p = -1, q = -1; + for (int c = 0; c < d2; c++) + { + int idx = c * node->left->d1 + row; + if (IS_ZERO(x->value[idx])) + { + if (p == -1) + { + p = c; + } + else + { + q = c; + break; + } + } + } + assert(p != -1 && q != -1); + + double hess_val = w[row] * pnode->prod_nonzero[row]; + H[block_start + p * d2 + q] = hess_val; + H[block_start + q * d2 + p] = hess_val; +} + +static inline void wsum_hess_row_many_zeros(expr *node, int row, int d2) +{ + double *H = node->wsum_hess->x; + int block_start = row * d2 * d2; + memset(&H[block_start], 0, sizeof(double) * d2 * d2); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + 
prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + for (int row = 0; row < x->d1; row++) + { + int num_zeros = pnode->num_of_zeros[row]; + + if (num_zeros == 0) + { + wsum_hess_row_no_zeros(node, w, row, x->d2); + } + else if (num_zeros == 1) + { + wsum_hess_row_one_zero(node, w, row, x->d2); + } + else if (num_zeros == 2) + { + wsum_hess_row_two_zeros(node, w, row, x->d2); + } + else + { + wsum_hess_row_many_zeros(node, row, x->d2); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static bool is_affine(const expr *node) +{ + (void) node; + return false; +} + +static void free_type_data(expr *node) +{ + prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + free(pnode->num_of_zeros); + free(pnode->zero_index); + free(pnode->prod_nonzero); +} + +expr *new_prod_axis_one(expr *child) +{ + prod_axis_one_expr *pnode = + (prod_axis_one_expr *) calloc(1, sizeof(prod_axis_one_expr)); + expr *node = &pnode->base; + + /* output is always a row vector 1 x d1 (one product per row) */ + init_expr(node, 1, child->d1, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, + free_type_data); + + /* allocate arrays to store per-row statistics */ + pnode->num_of_zeros = (int *) calloc(child->d1, sizeof(int)); + pnode->zero_index = (int *) calloc(child->d1, sizeof(int)); + pnode->prod_nonzero = (double *) calloc(child->d1, sizeof(double)); + + node->left = child; + expr_retain(child); + + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 164382b..c5c0605 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -14,6 +14,7 @@ #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" +#include "forward_pass/test_prod_axis_one.h" #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" #include 
"jacobian_tests/test_composite.h" @@ -26,6 +27,7 @@ #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_prod.h" +#include "jacobian_tests/test_prod_axis_one.h" #include "jacobian_tests/test_prod_axis_zero.h" #include "jacobian_tests/test_promote.h" #include "jacobian_tests/test_quad_form.h" @@ -55,6 +57,7 @@ #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" +#include "wsum_hess/test_prod_axis_one.h" #include "wsum_hess/test_prod_axis_zero.h" #include "wsum_hess/test_quad_form.h" #include "wsum_hess/test_quad_over_lin.h" @@ -90,6 +93,7 @@ int main(void) mu_run_test(test_broadcast_col, tests_run); mu_run_test(test_broadcast_matrix, tests_run); mu_run_test(test_forward_prod_axis_zero, tests_run); + mu_run_test(test_forward_prod_axis_one, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -123,6 +127,8 @@ int main(void) mu_run_test(test_jacobian_prod_one_zero, tests_run); mu_run_test(test_jacobian_prod_two_zeros, tests_run); mu_run_test(test_jacobian_prod_axis_zero, tests_run); + mu_run_test(test_jacobian_prod_axis_one, tests_run); + mu_run_test(test_jacobian_prod_axis_one_one_zero, tests_run); mu_run_test(test_jacobian_sum_log, tests_run); mu_run_test(test_jacobian_sum_mult, tests_run); mu_run_test(test_jacobian_sum_log_axis_0, tests_run); @@ -176,6 +182,9 @@ int main(void) mu_run_test(test_wsum_hess_prod_axis_zero_no_zeros, tests_run); mu_run_test(test_wsum_hess_prod_axis_zero_one_zero, tests_run); mu_run_test(test_wsum_hess_prod_axis_zero_mixed_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_no_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_one_zero, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_mixed_zeros, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); 
diff --git a/tests/forward_pass/test_prod_axis_one.h b/tests/forward_pass/test_prod_axis_one.h new file mode 100644 index 0000000..6f4b3bb --- /dev/null +++ b/tests/forward_pass/test_prod_axis_one.h @@ -0,0 +1,35 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_forward_prod_axis_one() +{ + /* Create a 2x3 constant matrix stored column-wise: + [1, 3, 5] + [2, 4, 6] + Stored as: [1, 2, 3, 4, 5, 6] + */ + double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + expr *const_node = new_constant(2, 3, 0, values); + expr *prod_node = new_prod_axis_one(const_node); + prod_node->forward(prod_node, NULL); + + /* Expected: rowwise products, result is 1x2 (row vector) + Row 0: 1*3*5 = 15 + Row 1: 2*4*6 = 48 + Result: [15, 48] + */ + double expected[2] = {15.0, 48.0}; + + mu_assert("prod_axis_one forward test failed", + cmp_double_array(prod_node->value, expected, 2)); + + free_expr(prod_node); + return 0; +} diff --git a/tests/jacobian_tests/test_prod_axis_one.h b/tests/jacobian_tests/test_prod_axis_one.h new file mode 100644 index 0000000..77993fb --- /dev/null +++ b/tests/jacobian_tests/test_prod_axis_one.h @@ -0,0 +1,94 @@ +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_jacobian_prod_axis_one() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 5, 6, 7, 8, 9] (column-major order) + * [1, 4, 7] + * [2, 5, 8] + * [3, 6, 9] + * + * f = prod_axis_one(x) = [28, 80, 162] (row vector) + * Row 0: 1*4*7 = 28 + * Row 1: 2*5*8 = 80 + * Row 2: 3*6*9 = 162 + * + * Jacobian is 3x10 (3 outputs, 10 total vars): + * Row 0 (output[0] = 28): df[0]/dx = [0, 28, 0, 0, 7, 0, 0, 4, 0, 0] + * df[0]/dx[0] = 28/1 = 28, df[0]/dx[3] = 28/4 = 7, df[0]/dx[6] = 28/7 = 4 + * Row 1 (output[1] = 80): df[1]/dx = [0, 0, 40, 0, 0, 16, 0, 0, 10, 0] + * df[1]/dx[1] = 80/2 = 40, 
df[1]/dx[4] = 80/5 = 16, df[1]/dx[7] = 80/8 = 10 + * Row 2 (output[2] = 162): df[2]/dx = [0, 0, 0, 54, 0, 0, 27, 0, 0, 18] + * df[2]/dx[2] = 162/3 = 54, df[2]/dx[5] = 162/6 = 27, df[2]/dx[8] = 162/9 = 18 + * + * (Note: global indices are offset by var_id=1) + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x10 Jacobian with row-strided structure */ + double expected_Ax[9] = {28.0, 7.0, 4.0, 40.0, 16.0, 10.0, 54.0, 27.0, 18.0}; + int expected_Ap[4] = {0, 3, 6, 9}; + int expected_Ai[9] = {1, 4, 7, 2, 5, 8, 3, 6, 9}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 9)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 9)); + + free_expr(p); + return 0; +} + +const char *test_jacobian_prod_axis_one_one_zero() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 0, 6, 7, 8, 9] (column-major order) + * [1, 4, 7] + * [2, 0, 8] + * [3, 6, 9] + * + * f = prod_axis_one(x) = [28, 0, 162] (row vector) + * Row 0: 1*4*7 = 28 + * Row 1: 2*0*8 = 0 (one zero at column 1) + * Row 2: 3*6*9 = 162 + * + * Jacobian is 3x10: + * Row 0 (output[0] = 28, no zeros): df[0]/dx = [0, 28, 0, 0, 7, 0, 0, 4, 0, 0] + * df[0]/dx[0] = 28/1 = 28, df[0]/dx[3] = 28/4 = 7, df[0]/dx[6] = 28/7 = 4 + * Row 1 (output[1] = 0, one zero): df[1]/dx = [0, 0, 0, 0, 0, 16, 0, 0, 0, 0] + * Only derivative w.r.t. 
zero element: df[1]/dx[4] = prod_nonzero = 2*8 = 16 + * Row 2 (output[2] = 162, no zeros): df[2]/dx = [0, 0, 0, 54, 0, 0, 27, 0, 0, + * 18] df[2]/dx[2] = 162/3 = 54, df[2]/dx[5] = 162/6 = 27, df[2]/dx[8] = 162/9 = + * 18 + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0, 9.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x10 Jacobian with row-strided structure */ + double expected_Ax[9] = {28.0, 7.0, 4.0, 0.0, 16.0, 0.0, 54.0, 27.0, 18.0}; + int expected_Ap[4] = {0, 3, 6, 9}; + int expected_Ai[9] = {1, 4, 7, 2, 5, 8, 3, 6, 9}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 9)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 9)); + + free_expr(p); + return 0; +} diff --git a/tests/wsum_hess/test_prod_axis_one.h b/tests/wsum_hess/test_prod_axis_one.h new file mode 100644 index 0000000..21e2d81 --- /dev/null +++ b/tests/wsum_hess/test_prod_axis_one.h @@ -0,0 +1,228 @@ +#include +#include + +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_wsum_hess_prod_axis_one_no_zeros() +{ + /* x is 2x3 variable, global index 1, total 8 vars + * x = [1, 2, 3, 4, 5, 6] (column-major) + * [1, 3, 5] + * [2, 4, 6] + * f = prod_axis_one(x) = [15, 48] + * w = [1, 2] + * + * Hessian is block diagonal with two 3x3 blocks (one per row): + * Row 0 block (scale = 1 * 15 = 15): + * cols (c0=1, c1=3, c2=5) + * off-diagonals: (0,1)=5, (0,2)=3, (1,0)=5, (1,2)=1, (2,0)=3, (2,1)=1; diag=0 + * Row 1 block (scale = 2 * 48 = 96): + * cols (c0=2, c1=4, c2=6) + * off-diagonals: (0,1)=12, (0,2)=8, (1,0)=12, (1,2)=4, (2,0)=8, (2,1)=4; + * diag=0 + */ + double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0}; + double w_vals[2] = {1.0, 2.0}; + expr *x = new_variable(2, 3, 1, 8); + 
expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[18] = {/* Block 0 */ + 0.0, 5.0, 3.0, 5.0, 0.0, 1.0, 3.0, 1.0, 0.0, + /* Block 1 */ + 0.0, 12.0, 8.0, 12.0, 0.0, 4.0, 8.0, 4.0, 0.0}; + + /* Row pointers (per implementation ordering) */ + int expected_p[9] = {0, 0, 9, 3, 12, 6, 15, 18, 18}; + + /* Column indices (block diagonal, repeated column sets per block) */ + int expected_i[18] = {/* Block 0 */ + 1, 3, 5, 1, 3, 5, 1, 3, 5, + /* Block 1 */ + 2, 4, 6, 2, 4, 6, 2, 4, 6}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 18)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 18)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_one_one_zero() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 0, 6, 7, 8, 9] (column-major) + * [1, 4, 7] + * [2, 0, 8] + * [3, 6, 9] + * f = prod_axis_one(x) = [28, 0, 162] + * w = [1, 2, 3] + * + * Blocks are 3x3 (one per row): + * Row 0 (no zeros, scale=1*28=28): + * cols (1,4,7): off-diag (0,1)=7, (0,2)=4, (1,0)=7, (1,2)=2, (2,0)=4, (2,1)=2 + * Row 1 (one zero at col 1, prod_nonzero=16, scale=2*16=32): + * only row/col 1 nonzero: (1,0)=32/2=16, (1,2)=32/8=4 (symmetric) + * Row 2 (no zeros, scale=3*162=486): + * cols (3,6,9): off-diag (0,1)=27, (0,2)=18, (1,0)=27, (1,2)=9, (2,0)=18, + * (2,1)=9 + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0, 9.0}; + double w_vals[3] = {1.0, 2.0, 3.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[27]; + memset(expected_x, 0, sizeof(expected_x)); + + /* Block 0 (row 0) */ + expected_x[0 * 3 + 1] = 7.0; /* (0,1) */ + expected_x[0 * 3 + 2] = 4.0; /* (0,2) */ + expected_x[1 * 3 + 0] = 7.0; /* (1,0) 
*/ + expected_x[1 * 3 + 2] = 1.0; /* (1,2) = 28/(4*7) */ + expected_x[2 * 3 + 0] = 4.0; /* (2,0) */ + expected_x[2 * 3 + 1] = 1.0; /* (2,1) = 28/(7*4) */ + + /* Block 1 (row 1, one zero at col 1) */ + int block1 = 9; + expected_x[block1 + 1 * 3 + 0] = 16.0; /* (1,0) */ + expected_x[block1 + 0 * 3 + 1] = 16.0; /* (0,1) symmetric */ + expected_x[block1 + 1 * 3 + 2] = 4.0; /* (1,2) */ + expected_x[block1 + 2 * 3 + 1] = 4.0; /* (2,1) symmetric */ + + /* Block 2 (row 2) */ + int block2 = 18; + expected_x[block2 + 0 * 3 + 1] = 27.0; + expected_x[block2 + 0 * 3 + 2] = 18.0; + expected_x[block2 + 1 * 3 + 0] = 27.0; + expected_x[block2 + 1 * 3 + 2] = 9.0; + expected_x[block2 + 2 * 3 + 0] = 18.0; + expected_x[block2 + 2 * 3 + 1] = 9.0; + + /* Row pointers (per implementation ordering) */ + int expected_p[11] = {0, 0, 9, 18, 3, 12, 21, 6, 15, 24, 27}; + + /* Column indices (block diagonal) */ + int expected_i[27] = {/* Block 0 */ + 1, 4, 7, 1, 4, 7, 1, 4, 7, + /* Block 1 */ + 2, 5, 8, 2, 5, 8, 2, 5, 8, + /* Block 2 */ + 3, 6, 9, 3, 6, 9, 3, 6, 9}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 27)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 11)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 27)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_one_mixed_zeros() +{ + /* x is 5x3 variable, global index 1, total 16 vars + * Rows (axis=1 products): + * r0: [1, 2, 1] -> no zeros, prod=2 + * r1: [1, 0, 0] -> two zeros (cols 1,2), prod_nonzero=1 + * r2: [1, 3, 0] -> one zero (col 2), prod_nonzero=3 + * r3: [1, 4, 2] -> no zeros, prod=8 + * r4: [1, 5, 3] -> no zeros, prod=15 + * w = [1, 2, 3, 4, 5] + * Blocks are 3x3 (one per row, block diagonal): + */ + double u_vals[16] = {0.0, + /* col 0 (rows 0-4): 1,1,1,1,1 */ + 1.0, 1.0, 1.0, 1.0, 1.0, + /* col 1: 2,0,3,4,5 */ + 2.0, 0.0, 3.0, 4.0, 5.0, + /* col 2: 1,0,0,2,3 */ + 1.0, 0.0, 0.0, 2.0, 3.0}; + /* Actually store column-major: + * 
col0: [1,1,1,1,1] + * col1: [2,0,3,4,5] + * col2: [1,0,0,2,3] + */ + double w_vals[5] = {1.0, 2.0, 3.0, 4.0, 5.0}; + expr *x = new_variable(5, 3, 1, 16); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[45]; + memset(expected_x, 0, sizeof(expected_x)); + + /* Block offsets */ + int b0 = 0, b1 = 9, b2 = 18, b3 = 27, b4 = 36; + + /* Row 0 block (no zeros, scale = 1 * 2 = 2), x = [1,2,1] */ + expected_x[b0 + 0 * 3 + 1] = 1.0; /* 2/(1*2) */ + expected_x[b0 + 0 * 3 + 2] = 2.0; /* 2/(1*1) */ + expected_x[b0 + 1 * 3 + 0] = 1.0; /* symmetric */ + expected_x[b0 + 1 * 3 + 2] = 1.0; /* 2/(2*1) */ + expected_x[b0 + 2 * 3 + 0] = 2.0; /* symmetric */ + expected_x[b0 + 2 * 3 + 1] = 1.0; /* symmetric */ + + /* Row 1 block (two zeros at cols 1,2), hess = w*prod_nonzero = 2*1 = 2 */ + expected_x[b1 + 1 * 3 + 2] = 2.0; + expected_x[b1 + 2 * 3 + 1] = 2.0; + + /* Row 2 block (one zero at col 2), w_prod = 3 * 3 = 9, x = [1,3,0] */ + expected_x[b2 + 2 * 3 + 0] = 9.0; /* 9/1 */ + expected_x[b2 + 0 * 3 + 2] = 9.0; /* symmetric */ + expected_x[b2 + 2 * 3 + 1] = 3.0; /* 9/3 */ + expected_x[b2 + 1 * 3 + 2] = 3.0; /* symmetric */ + + /* Row 3 block (no zeros, scale = 4 * 8 = 32), x = [1,4,2] */ + expected_x[b3 + 0 * 3 + 1] = 8.0; /* 32/(1*4) */ + expected_x[b3 + 0 * 3 + 2] = 16.0; /* 32/(1*2) */ + expected_x[b3 + 1 * 3 + 0] = 8.0; /* symmetric */ + expected_x[b3 + 1 * 3 + 2] = 4.0; /* 32/(4*2) */ + expected_x[b3 + 2 * 3 + 0] = 16.0; /* symmetric */ + expected_x[b3 + 2 * 3 + 1] = 4.0; /* symmetric */ + + /* Row 4 block (no zeros, scale = 5 * 15 = 75), x = [1,5,3] */ + expected_x[b4 + 0 * 3 + 1] = 15.0; /* 75/(1*5) */ + expected_x[b4 + 0 * 3 + 2] = 25.0; /* 75/(1*3) */ + expected_x[b4 + 1 * 3 + 0] = 15.0; /* symmetric */ + expected_x[b4 + 1 * 3 + 2] = 5.0; /* 75/(5*3) */ + expected_x[b4 + 2 * 3 + 0] = 25.0; /* symmetric */ + expected_x[b4 + 2 * 3 + 1] = 5.0; /* symmetric */ + + /* Row pointers (per 
implementation ordering) */ + int expected_p[17] = {0, 0, 9, 18, 27, 36, 3, 12, 21, + 30, 39, 6, 15, 24, 33, 42, 45}; + + /* Column indices (block diagonal) */ + int expected_i[45]; + for (int block = 0; block < 5; block++) + { + int offset = block * 9; + int base = 1 + block; /* var_id + row index */ + for (int r = 0; r < 3; r++) + { + expected_i[offset + r * 3 + 0] = base; + expected_i[offset + r * 3 + 1] = base + 5; + expected_i[offset + r * 3 + 2] = base + 10; + } + } + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 45)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 17)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 45)); + + free_expr(p); + return 0; +} From a6c513b22fc23c388c4fa46250f8e015b2443e91 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 18 Jan 2026 06:08:57 -0800 Subject: [PATCH 3/5] more progress --- python/atoms/prod_axis_one.h | 33 +++ python/bindings.c | 3 + src/other/prod_axis_one.c | 183 ++++++++++----- src/problem.c | 6 + tests/all_tests.c | 1 + tests/wsum_hess/test_prod_axis_one.h | 336 ++++++++++++++++++++------- 6 files changed, 417 insertions(+), 145 deletions(-) create mode 100644 python/atoms/prod_axis_one.h diff --git a/python/atoms/prod_axis_one.h b/python/atoms/prod_axis_one.h new file mode 100644 index 0000000..db4d2ea --- /dev/null +++ b/python/atoms/prod_axis_one.h @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +#ifndef ATOM_PROD_AXIS_ONE_H +#define ATOM_PROD_AXIS_ONE_H + +#include "common.h" +#include "other.h" + +static PyObject *py_make_prod_axis_one(PyObject *self, PyObject *args) +{ + (void) self; + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_prod_axis_one(child); + if (!node) + { + 
PyErr_SetString(PyExc_RuntimeError, "failed to create prod_axis_one node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_PROD_AXIS_ONE_H */ diff --git a/python/bindings.c b/python/bindings.c index d83b271..0cd703e 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -23,6 +23,7 @@ #include "atoms/neg.h" #include "atoms/power.h" #include "atoms/prod.h" +#include "atoms/prod_axis_one.h" #include "atoms/prod_axis_zero.h" #include "atoms/promote.h" #include "atoms/quad_form.h" @@ -80,6 +81,8 @@ static PyMethodDef DNLPMethods[] = { {"make_prod", py_make_prod, METH_VARARGS, "Create prod node"}, {"make_prod_axis_zero", py_make_prod_axis_zero, METH_VARARGS, "Create prod_axis_zero node"}, + {"make_prod_axis_one", py_make_prod_axis_one, METH_VARARGS, + "Create prod_axis_one node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c index 9aaff4f..05889de 100644 --- a/src/other/prod_axis_one.c +++ b/src/other/prod_axis_one.c @@ -1,6 +1,7 @@ #include "other.h" #include #include +#include #include #include @@ -135,17 +136,31 @@ static void wsum_hess_init(expr *node) /* if x is a variable */ if (x->var_id != NOT_A_VARIABLE) { - /* Hessian has block diagonal structure: d1 blocks of size d2 x d2 */ + /* Hessian structure: for each row i of the input matrix, + * the Hessian of the product (w[i] * prod(row i)) has entries + * at the column indices corresponding to the columns in that row. + * Each row i has d2 non-zero columns, and we have d2 x d2 interactions. 
+ * Total nnz = d1 * d2 * d2 + */ int nnz = x->d1 * x->d2 * x->d2; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); CSR_Matrix *H = node->wsum_hess; - /* fill row pointers for the variable's rows (block diagonal) */ + /* Fill row pointers and column indices */ + int nnz_per_row = x->d2; for (int i = 0; i < x->size; i++) { - int block_idx = i % x->d1; /* row of the matrix */ - int row_in_block = i / x->d1; /* column index within that row-block */ - H->p[x->var_id + i] = block_idx * x->d2 * x->d2 + row_in_block * x->d2; + int matrix_row = i % x->d1; + H->p[x->var_id + i] = i * nnz_per_row; + + /* For this variable (at matrix position [matrix_row, ...]), + * the Hessian has entries with all columns in the same matrix_row */ + for (int j = 0; j < x->d2; j++) + { + /* Global column index for column j of this row */ + int global_col = x->var_id + matrix_row + j * x->d1; + H->i[i * nnz_per_row + j] = global_col; + } } /* fill row pointers for rows after the variable */ @@ -153,29 +168,25 @@ static void wsum_hess_init(expr *node) { H->p[i] = nnz; } - - /* set column indices for each block */ - for (int block = 0; block < x->d1; block++) - { - int block_start = block * x->d2 * x->d2; - int col_start = block; /* global column offset for this row-block */ - for (int row = 0; row < x->d2; row++) - { - for (int col = 0; col < x->d2; col++) - { - /* variable indices: row-block fixed, column-major layout */ - int var_col = - col * x->d1 + - block; /* global variable index offset from var_id */ - H->i[block_start + row * x->d2 + col] = x->var_id + var_col; - } - } - } } else { assert(false && "child must be a variable"); } + + /* print row pointers */ + // for (int i = 0; i <= node->n_vars; i++) + //{ + // printf("H->p[%d] = %d\n", i, node->wsum_hess->p[i]); + // } + // + ///* print column indices */ + // for (int i = 0; i < node->wsum_hess->nnz; i++) + //{ + // printf("H->i[%d] = %d\n", i, node->wsum_hess->i[i]); + // } + + printf("done wsum_hess_init 
prod_axis_one\n"); } static inline void wsum_hess_row_no_zeros(expr *node, const double *w, int row, @@ -183,21 +194,33 @@ static inline void wsum_hess_row_no_zeros(expr *node, const double *w, int row, { expr *x = node->left; prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; - double *H = node->wsum_hess->x; + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; - int block_start = row * d2 * d2; double scale = w[row] * node->value[row]; + int idx_i = row; /* Start with the row offset */ - /* fill block for this row */ - for (int r = 0; r < d2; r++) + /* For each variable in this row, fill in Hessian entries */ + for (int i = 0; i < d2; i++) { - int idx_r = r * node->left->d1 + row; /* column-major index into x */ - for (int c = 0; c < d2; c++) + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + idx_i = i * x->d1 + row; /* Element at matrix [row, i] */ + + /* Row var_i has nnz entries with columns from the same matrix row */ + for (int j = 0; j < d2; j++) { - int idx_c = c * node->left->d1 + row; - double val = - (r == c) ? 
0.0 : scale / (x->value[idx_r] * x->value[idx_c]); - H[block_start + r * d2 + c] = val; + if (i == j) + { + /* Diagonal entries are 0 */ + H_vals[row_start + j] = 0.0; + } + else + { + /* Off-diagonal: scale / (x[i] * x[j]) */ + int idx_j = j * x->d1 + row; /* Element at matrix [row, j] */ + H_vals[row_start + j] = scale / (x->value[idx_i] * x->value[idx_j]); + } } } (void) pnode; /* suppress unused warning */ @@ -208,26 +231,40 @@ static inline void wsum_hess_row_one_zero(expr *node, const double *w, int row, { expr *x = node->left; prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; - double *H = node->wsum_hess->x; + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; - int block_start = row * d2 * d2; - int p = pnode->zero_index[row]; /* zero column */ + int p = pnode->zero_index[row]; /* zero column index */ double w_prod = w[row] * pnode->prod_nonzero[row]; - /* clear block */ - memset(&H[block_start], 0, sizeof(double) * d2 * d2); - - /* fill row p and column p */ - int col_offset = p * node->left->d1 + row; - for (int c = 0; c < d2; c++) + /* For each variable in this row */ + for (int i = 0; i < d2; i++) { - if (c == p) continue; - int idx_c = c * node->left->d1 + row; - double hess_val = w_prod / x->value[idx_c]; - H[block_start + p * d2 + c] = hess_val; - H[block_start + c * d2 + p] = hess_val; + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + + /* Row var_i has nnz entries with columns from the same matrix row */ + for (int j = 0; j < d2; j++) + { + if (i == p && j != p) + { + /* Row p (zero row), column j (nonzero) */ + int idx_j = j * x->d1 + row; + H_vals[row_start + j] = w_prod / x->value[idx_j]; + } + else if (j == p && i != p) + { + /* Row i (nonzero), column p (zero) */ + int idx_i = i * x->d1 + row; + H_vals[row_start + j] = w_prod / x->value[idx_i]; + } + else + { + /* All other entries are zero */ + H_vals[row_start + j] = 0.0; + } + } } - (void) col_offset; /* suppress unused */ } static inline void 
wsum_hess_row_two_zeros(expr *node, const double *w, int row, @@ -235,12 +272,8 @@ static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, { expr *x = node->left; prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; - double *H = node->wsum_hess->x; - - int block_start = row * d2 * d2; - - /* clear block */ - memset(&H[block_start], 0, sizeof(double) * d2 * d2); + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; /* find indices p and q where row has zeros */ int p = -1, q = -1; @@ -263,15 +296,47 @@ static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, assert(p != -1 && q != -1); double hess_val = w[row] * pnode->prod_nonzero[row]; - H[block_start + p * d2 + q] = hess_val; - H[block_start + q * d2 + p] = hess_val; + + /* For each variable in this row */ + for (int i = 0; i < d2; i++) + { + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + + /* Row var_i has nnz entries with columns from the same matrix row */ + for (int j = 0; j < d2; j++) + { + /* Only (p,q) and (q,p) are nonzero */ + if ((i == p && j == q) || (i == q && j == p)) + { + H_vals[row_start + j] = hess_val; + } + else + { + H_vals[row_start + j] = 0.0; + } + } + } } static inline void wsum_hess_row_many_zeros(expr *node, int row, int d2) { - double *H = node->wsum_hess->x; - int block_start = row * d2 * d2; - memset(&H[block_start], 0, sizeof(double) * d2 * d2); + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; + expr *x = node->left; + + /* For each variable in this row, zero out all entries */ + for (int i = 0; i < d2; i++) + { + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + + /* Each variable has d2 entries */ + for (int j = 0; j < d2; j++) + { + H_vals[row_start + j] = 0.0; + } + } } static void eval_wsum_hess(expr *node, const double *w) diff --git a/src/problem.c b/src/problem.c index e5b4764..d6e752b 100644 --- a/src/problem.c +++ b/src/problem.c @@ -56,6 +56,7 @@ void 
problem_init_derivatives(problem *prob) // ------------------------------------------------------------------------------- // Jacobian structure // ------------------------------------------------------------------------------- + printf("Initializing Jacobian structure...\n"); prob->objective->jacobian_init(prob->objective); int nnz = 0; for (int i = 0; i < prob->n_constraints; i++) @@ -66,6 +67,7 @@ void problem_init_derivatives(problem *prob) } prob->jacobian = new_csr_matrix(prob->total_constraint_size, prob->n_vars, nnz); + printf("jacobian structure initialized\n"); /* set sparsity pattern of jacobian */ CSR_Matrix *H = prob->jacobian; @@ -90,6 +92,7 @@ void problem_init_derivatives(problem *prob) // ------------------------------------------------------------------------------- // Lagrange Hessian structure // ------------------------------------------------------------------------------- + printf("Initializing Lagrange Hessian structure...\n"); prob->objective->wsum_hess_init(prob->objective); nnz = prob->objective->wsum_hess->nnz; @@ -103,11 +106,14 @@ void problem_init_derivatives(problem *prob) memset(prob->lagrange_hessian->x, 0, nnz * sizeof(double)); /* affine shortcut */ prob->hess_idx_map = (int *) malloc(nnz * sizeof(int)); int *iwork = (int *) malloc(MAX(nnz, prob->n_vars) * sizeof(int)); + printf("Lagrange Hessian structure initialized\n"); problem_lagrange_hess_fill_sparsity(prob, iwork); + printf("Lagrange Hessian sparsity filled\n"); free(iwork); clock_gettime(CLOCK_MONOTONIC, &timer.end); prob->stats.time_init_derivatives += GET_ELAPSED_SECONDS(timer); + printf("Derivative initialization complete.\n"); } static void problem_lagrange_hess_fill_sparsity(problem *prob, int *iwork) diff --git a/tests/all_tests.c b/tests/all_tests.c index c5c0605..5de1b00 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -185,6 +185,7 @@ int main(void) mu_run_test(test_wsum_hess_prod_axis_one_no_zeros, tests_run); 
mu_run_test(test_wsum_hess_prod_axis_one_one_zero, tests_run); mu_run_test(test_wsum_hess_prod_axis_one_mixed_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_2x2, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); diff --git a/tests/wsum_hess/test_prod_axis_one.h b/tests/wsum_hess/test_prod_axis_one.h index 21e2d81..7acbec1 100644 --- a/tests/wsum_hess/test_prod_axis_one.h +++ b/tests/wsum_hess/test_prod_axis_one.h @@ -33,19 +33,35 @@ const char *test_wsum_hess_prod_axis_one_no_zeros() p->wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); - double expected_x[18] = {/* Block 0 */ - 0.0, 5.0, 3.0, 5.0, 0.0, 1.0, 3.0, 1.0, 0.0, - /* Block 1 */ - 0.0, 12.0, 8.0, 12.0, 0.0, 4.0, 8.0, 4.0, 0.0}; + double expected_x[18] = {/* Var 1 (row 0 of matrix): [0, 5, 3] */ + 0.0, 5.0, 3.0, + /* Var 2 (row 1 of matrix): [0, 12, 8] */ + 0.0, 12.0, 8.0, + /* Var 3 (row 0 of matrix): [5, 0, 1] */ + 5.0, 0.0, 1.0, + /* Var 4 (row 1 of matrix): [12, 0, 4] */ + 12.0, 0.0, 4.0, + /* Var 5 (row 0 of matrix): [3, 1, 0] */ + 3.0, 1.0, 0.0, + /* Var 6 (row 1 of matrix): [8, 4, 0] */ + 8.0, 4.0, 0.0}; - /* Row pointers (per implementation ordering) */ - int expected_p[9] = {0, 0, 9, 3, 12, 6, 15, 18, 18}; + /* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[9] = {0, 0, 3, 6, 9, 12, 15, 18, 18}; - /* Column indices (block diagonal, repeated column sets per block) */ - int expected_i[18] = {/* Block 0 */ - 1, 3, 5, 1, 3, 5, 1, 3, 5, - /* Block 1 */ - 2, 4, 6, 2, 4, 6, 2, 4, 6}; + /* Column indices (each row of the matrix interacts with its own columns) */ + int expected_i[18] = {/* Var 1 (row 0): cols 1,3,5 */ + 1, 3, 5, + /* Var 2 (row 1): cols 2,4,6 */ + 2, 4, 6, + /* Var 3 (row 0): cols 1,3,5 */ + 1, 3, 5, + /* Var 4 (row 1): cols 2,4,6 */ + 2, 4, 6, + /* Var 5 (row 0): cols 1,3,5 */ + 1, 3, 5, + /* Var 6 (row 1): cols 2,4,6 */ + 2, 
4, 6}; mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 18)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); @@ -86,40 +102,73 @@ const char *test_wsum_hess_prod_axis_one_one_zero() double expected_x[27]; memset(expected_x, 0, sizeof(expected_x)); - /* Block 0 (row 0) */ - expected_x[0 * 3 + 1] = 7.0; /* (0,1) */ - expected_x[0 * 3 + 2] = 4.0; /* (0,2) */ - expected_x[1 * 3 + 0] = 7.0; /* (1,0) */ - expected_x[1 * 3 + 2] = 1.0; /* (1,2) = 28/(4*7) */ - expected_x[2 * 3 + 0] = 4.0; /* (2,0) */ - expected_x[2 * 3 + 1] = 1.0; /* (2,1) = 28/(7*4) */ - - /* Block 1 (row 1, one zero at col 1) */ - int block1 = 9; - expected_x[block1 + 1 * 3 + 0] = 16.0; /* (1,0) */ - expected_x[block1 + 0 * 3 + 1] = 16.0; /* (0,1) symmetric */ - expected_x[block1 + 1 * 3 + 2] = 4.0; /* (1,2) */ - expected_x[block1 + 2 * 3 + 1] = 4.0; /* (2,1) symmetric */ - - /* Block 2 (row 2) */ - int block2 = 18; - expected_x[block2 + 0 * 3 + 1] = 27.0; - expected_x[block2 + 0 * 3 + 2] = 18.0; - expected_x[block2 + 1 * 3 + 0] = 27.0; - expected_x[block2 + 1 * 3 + 2] = 9.0; - expected_x[block2 + 2 * 3 + 0] = 18.0; - expected_x[block2 + 2 * 3 + 1] = 9.0; - - /* Row pointers (per implementation ordering) */ - int expected_p[11] = {0, 0, 9, 18, 3, 12, 21, 6, 15, 24, 27}; - - /* Column indices (block diagonal) */ - int expected_i[27] = {/* Block 0 */ - 1, 4, 7, 1, 4, 7, 1, 4, 7, - /* Block 1 */ - 2, 5, 8, 2, 5, 8, 2, 5, 8, - /* Block 2 */ - 3, 6, 9, 3, 6, 9, 3, 6, 9}; + /* Var 1 (row 0): [0, 7, 4] */ + expected_x[0] = 0.0; + expected_x[1] = 7.0; + expected_x[2] = 4.0; + + /* Var 2 (row 1, one zero at col 1): [0, 16, 0] */ + expected_x[3] = 0.0; + expected_x[4] = 16.0; + expected_x[5] = 0.0; + + /* Var 3 (row 2): [0, 27, 18] */ + expected_x[6] = 0.0; + expected_x[7] = 27.0; + expected_x[8] = 18.0; + + /* Var 4 (row 0): [7, 0, 1] */ + expected_x[9] = 7.0; + expected_x[10] = 0.0; + expected_x[11] = 1.0; + + /* Var 5 (row 1, one zero at col 1): [16, 0, 4] */ + 
expected_x[12] = 16.0; + expected_x[13] = 0.0; + expected_x[14] = 4.0; + + /* Var 6 (row 2): [27, 0, 9] */ + expected_x[15] = 27.0; + expected_x[16] = 0.0; + expected_x[17] = 9.0; + + /* Var 7 (row 0): [4, 1, 0] */ + expected_x[18] = 4.0; + expected_x[19] = 1.0; + expected_x[20] = 0.0; + + /* Var 8 (row 1, one zero at col 1): [0, 4, 0] */ + expected_x[21] = 0.0; + expected_x[22] = 4.0; + expected_x[23] = 0.0; + + /* Var 9 (row 2): [18, 9, 0] */ + expected_x[24] = 18.0; + expected_x[25] = 9.0; + expected_x[26] = 0.0; + + /* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[11] = {0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27}; + + /* Column indices (each row of the matrix interacts with its own columns) */ + int expected_i[27] = {/* Var 1 (row 0): cols 1,4,7 */ + 1, 4, 7, + /* Var 2 (row 1): cols 2,5,8 */ + 2, 5, 8, + /* Var 3 (row 2): cols 3,6,9 */ + 3, 6, 9, + /* Var 4 (row 0): cols 1,4,7 */ + 1, 4, 7, + /* Var 5 (row 1): cols 2,5,8 */ + 2, 5, 8, + /* Var 6 (row 2): cols 3,6,9 */ + 3, 6, 9, + /* Var 7 (row 0): cols 1,4,7 */ + 1, 4, 7, + /* Var 8 (row 1): cols 2,5,8 */ + 2, 5, 8, + /* Var 9 (row 2): cols 3,6,9 */ + 3, 6, 9}; mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 27)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 11)); @@ -164,59 +213,120 @@ const char *test_wsum_hess_prod_axis_one_mixed_zeros() double expected_x[45]; memset(expected_x, 0, sizeof(expected_x)); - /* Block offsets */ - int b0 = 0, b1 = 9, b2 = 18, b3 = 27, b4 = 36; + /* For a 5x3 matrix with var_id=1: + * CSR row pointers: p[i] = (i-1)*3 for i in [1,15] + * Variables are indexed sequentially: + * Var 1 (matrix [0,0]): p[1]=0 + * Var 2 (matrix [1,0]): p[2]=3 + * Var 3 (matrix [2,0]): p[3]=6 + * Var 4 (matrix [3,0]): p[4]=9 + * Var 5 (matrix [4,0]): p[5]=12 + * Var 6 (matrix [0,1]): p[6]=15 + * Var 7 (matrix [1,1]): p[7]=18 + * Var 8 (matrix [2,1]): p[8]=21 + * Var 9 (matrix [3,1]): p[9]=24 + * Var 10 (matrix [4,1]): p[10]=27 + * 
Var 11 (matrix [0,2]): p[11]=30 + * Var 12 (matrix [1,2]): p[12]=33 + * Var 13 (matrix [2,2]): p[13]=36 + * Var 14 (matrix [3,2]): p[14]=39 + * Var 15 (matrix [4,2]): p[15]=42 + */ /* Row 0 block (no zeros, scale = 1 * 2 = 2), x = [1,2,1] */ - expected_x[b0 + 0 * 3 + 1] = 1.0; /* 2/(1*2) */ - expected_x[b0 + 0 * 3 + 2] = 2.0; /* 2/(1*1) */ - expected_x[b0 + 1 * 3 + 0] = 1.0; /* symmetric */ - expected_x[b0 + 1 * 3 + 2] = 1.0; /* 2/(2*1) */ - expected_x[b0 + 2 * 3 + 0] = 2.0; /* symmetric */ - expected_x[b0 + 2 * 3 + 1] = 1.0; /* symmetric */ + /* Var 1 (matrix [0,0]): [0, 1, 2] */ + expected_x[0] = 0.0; + expected_x[1] = 1.0; /* 2/(1*2) */ + expected_x[2] = 2.0; /* 2/(1*1) */ + + /* Var 6 (matrix [0,1]): [1, 0, 1] */ + expected_x[15] = 1.0; /* 2/(2*1) */ + expected_x[16] = 0.0; + expected_x[17] = 1.0; /* 2/(2*1) */ + + /* Var 11 (matrix [0,2]): [2, 1, 0] */ + expected_x[30] = 2.0; /* 2/(1*1) */ + expected_x[31] = 1.0; /* 2/(1*2) */ + expected_x[32] = 0.0; /* Row 1 block (two zeros at cols 1,2), hess = w*prod_nonzero = 2*1 = 2 */ - expected_x[b1 + 1 * 3 + 2] = 2.0; - expected_x[b1 + 2 * 3 + 1] = 2.0; + /* Var 2 (matrix [1,0]): [0, 0, 0] */ + expected_x[3] = 0.0; + expected_x[4] = 0.0; + expected_x[5] = 0.0; + + /* Var 7 (matrix [1,1]): [0, 0, 2] */ + expected_x[18] = 0.0; + expected_x[19] = 0.0; + expected_x[20] = 2.0; /* (1,2) */ + + /* Var 12 (matrix [1,2]): [0, 2, 0] */ + expected_x[33] = 0.0; + expected_x[34] = 2.0; /* (2,1) */ + expected_x[35] = 0.0; /* Row 2 block (one zero at col 2), w_prod = 3 * 3 = 9, x = [1,3,0] */ - expected_x[b2 + 2 * 3 + 0] = 9.0; /* 9/1 */ - expected_x[b2 + 0 * 3 + 2] = 9.0; /* symmetric */ - expected_x[b2 + 2 * 3 + 1] = 3.0; /* 9/3 */ - expected_x[b2 + 1 * 3 + 2] = 3.0; /* symmetric */ + /* Var 3 (matrix [2,0]): [0, 0, 9] */ + expected_x[6] = 0.0; + expected_x[7] = 0.0; + expected_x[8] = 9.0; /* (0,2): w_prod/x[0] */ + + /* Var 8 (matrix [2,1]): [0, 0, 3] */ + expected_x[21] = 0.0; + expected_x[22] = 0.0; + expected_x[23] = 3.0; /* 
(1,2): w_prod/x[1] */ + + /* Var 13 (matrix [2,2]): [9, 3, 0] */ + expected_x[36] = 9.0; /* (2,0): w_prod/x[0] */ + expected_x[37] = 3.0; /* (2,1): w_prod/x[1] */ + expected_x[38] = 0.0; /* Row 3 block (no zeros, scale = 4 * 8 = 32), x = [1,4,2] */ - expected_x[b3 + 0 * 3 + 1] = 8.0; /* 32/(1*4) */ - expected_x[b3 + 0 * 3 + 2] = 16.0; /* 32/(1*2) */ - expected_x[b3 + 1 * 3 + 0] = 8.0; /* symmetric */ - expected_x[b3 + 1 * 3 + 2] = 4.0; /* 32/(4*2) */ - expected_x[b3 + 2 * 3 + 0] = 16.0; /* symmetric */ - expected_x[b3 + 2 * 3 + 1] = 4.0; /* symmetric */ + /* Var 4 (matrix [3,0]): [0, 8, 16] */ + expected_x[9] = 0.0; + expected_x[10] = 8.0; /* 32/(1*4) */ + expected_x[11] = 16.0; /* 32/(1*2) */ + + /* Var 9 (matrix [3,1]): [8, 0, 4] */ + expected_x[24] = 8.0; /* 32/(4*1) */ + expected_x[25] = 0.0; + expected_x[26] = 4.0; /* 32/(4*2) */ + + /* Var 14 (matrix [3,2]): [16, 4, 0] */ + expected_x[39] = 16.0; /* 32/(2*1) */ + expected_x[40] = 4.0; /* 32/(2*4) */ + expected_x[41] = 0.0; /* Row 4 block (no zeros, scale = 5 * 15 = 75), x = [1,5,3] */ - expected_x[b4 + 0 * 3 + 1] = 15.0; /* 75/(1*5) */ - expected_x[b4 + 0 * 3 + 2] = 25.0; /* 75/(1*3) */ - expected_x[b4 + 1 * 3 + 0] = 15.0; /* symmetric */ - expected_x[b4 + 1 * 3 + 2] = 5.0; /* 75/(5*3) */ - expected_x[b4 + 2 * 3 + 0] = 25.0; /* symmetric */ - expected_x[b4 + 2 * 3 + 1] = 5.0; /* symmetric */ - - /* Row pointers (per implementation ordering) */ - int expected_p[17] = {0, 0, 9, 18, 27, 36, 3, 12, 21, - 30, 39, 6, 15, 24, 33, 42, 45}; - - /* Column indices (block diagonal) */ + /* Var 5 (matrix [4,0]): [0, 15, 25] */ + expected_x[12] = 0.0; + expected_x[13] = 15.0; /* 75/(1*5) */ + expected_x[14] = 25.0; /* 75/(1*3) */ + + /* Var 10 (matrix [4,1]): [15, 0, 5] */ + expected_x[27] = 15.0; /* 75/(5*1) */ + expected_x[28] = 0.0; + expected_x[29] = 5.0; /* 75/(5*3) */ + + /* Var 15 (matrix [4,2]): [25, 5, 0] */ + expected_x[42] = 25.0; /* 75/(3*1) */ + expected_x[43] = 5.0; /* 75/(3*5) */ + expected_x[44] = 0.0; + + 
/* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[17] = {0, 0, 3, 6, 9, 12, 15, 18, 21, + 24, 27, 30, 33, 36, 39, 42, 45}; + + /* Column indices (each row of the matrix interacts with its own columns) */ int expected_i[45]; - for (int block = 0; block < 5; block++) + for (int var_idx = 0; var_idx < 15; var_idx++) { - int offset = block * 9; - int base = 1 + block; /* var_id + row index */ - for (int r = 0; r < 3; r++) - { - expected_i[offset + r * 3 + 0] = base; - expected_i[offset + r * 3 + 1] = base + 5; - expected_i[offset + r * 3 + 2] = base + 10; - } + int matrix_row = var_idx % 5; /* which row of the 5x3 matrix */ + int nnz_start = var_idx * 3; + /* All columns from matrix_row */ + expected_i[nnz_start + 0] = 1 + matrix_row + 0 * 5; + expected_i[nnz_start + 1] = 1 + matrix_row + 1 * 5; + expected_i[nnz_start + 2] = 1 + matrix_row + 2 * 5; } mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 45)); @@ -226,3 +336,57 @@ const char *test_wsum_hess_prod_axis_one_mixed_zeros() free_expr(p); return 0; } +const char *test_wsum_hess_prod_axis_one_2x2() +{ + /* x is 2x2 variable, global index 0, total 4 vars + * x = [2, 1, 3, 2] (column-major) + * [2, 3] + * [1, 2] + * f = prod_axis_one(x) = [2*3, 1*2] = [6, 2] + * w = [1, 1] + * + * For a 2x2 matrix with var_id=0: + * Var 0 (matrix [0,0]): p[0]=0, stores columns [0,1] + * Var 1 (matrix [1,0]): p[1]=2, stores columns [0,1] + * Var 2 (matrix [0,1]): p[2]=4, stores columns [0,1] + * Var 3 (matrix [1,1]): p[3]=6, stores columns [0,1] + * + * Row 0 Hessian (no zeros, x=[2,3], scale=1*6=6): + * (0,0)=0, (0,1)=1 -> stored at Var 0: [0, 1] + * (1,0)=1, (1,1)=0 -> stored at Var 2: [1, 0] + * + * Row 1 Hessian (no zeros, x=[1,2], scale=1*2=2): + * (0,0)=0, (0,1)=1 -> stored at Var 1: [0, 1] + * (1,0)=1, (1,1)=0 -> stored at Var 3: [1, 0] + */ + double u_vals[4] = {2.0, 1.0, 3.0, 2.0}; + double w_vals[2] = {1.0, 1.0}; + expr *x = new_variable(2, 2, 0, 4); + expr *p = 
new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + /* Expected sparse structure (nnz = 8) */ + double expected_x[8] = {0.0, 1.0, /* Var 0 */ + 0.0, 1.0, /* Var 1 */ + 1.0, 0.0, /* Var 2 */ + 1.0, 0.0}; /* Var 3 */ + + /* Row pointers (each row has 2 nnz) */ + int expected_p[5] = {0, 2, 4, 6, 8}; + + /* Column indices (each variable stores columns for its matrix row) */ + int expected_i[8] = {0, 2, /* Var 0 (row 0) */ + 1, 3, /* Var 1 (row 1) */ + 0, 2, /* Var 2 (row 0) */ + 1, 3}; /* Var 3 (row 1) */ + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 8)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 5)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 8)); + + free_expr(p); + return 0; +} From e211fcb23c90549486c364446dc6dd041a143a5e Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 18 Jan 2026 09:00:06 -0800 Subject: [PATCH 4/5] exclude structural zero --- include/subexpr.h | 22 +- python/atoms/prod_axis_one.h | 1 - src/other/prod_axis_one.c | 166 +++++------- src/other/prod_axis_zero.c | 21 +- src/problem.c | 6 - tests/wsum_hess/test_prod_axis_one.h | 378 +++++++++++++-------------- 6 files changed, 271 insertions(+), 323 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index 01762f5..d8e56a5 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -56,23 +56,15 @@ typedef struct prod_expr double prod_nonzero; /* product of non-zero elements */ } prod_expr; -/* Product of entries along axis=0 (columnwise products) */ -typedef struct prod_axis_zero_expr +/* Product of entries along axis=0 (columnwise products) or axis = 1 (rowwise + * products) */ +typedef struct prod_axis { expr base; - int *num_of_zeros; /* num of zeros for each column */ - int *zero_index; /* stores idx of zero element per column */ - double *prod_nonzero; /* product of non-zero elements per column */ -} prod_axis_zero_expr; - -/* Product of entries along 
axis=1 (rowwise products) */ -typedef struct prod_axis_one_expr -{ - expr base; - int *num_of_zeros; /* num of zeros for each row */ - int *zero_index; /* stores idx of zero element per row */ - double *prod_nonzero; /* product of non-zero elements per row */ -} prod_axis_one_expr; + int *num_of_zeros; /* num of zeros for each column / row depending on the axis*/ + int *zero_index; /* stores idx of zero element per column / row */ + double *prod_nonzero; /* product of non-zero elements per column / row */ +} prod_axis; /* Horizontal stack (concatenate) */ typedef struct hstack_expr diff --git a/python/atoms/prod_axis_one.h b/python/atoms/prod_axis_one.h index db4d2ea..d9c3b8a 100644 --- a/python/atoms/prod_axis_one.h +++ b/python/atoms/prod_axis_one.h @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: Apache-2.0 #ifndef ATOM_PROD_AXIS_ONE_H #define ATOM_PROD_AXIS_ONE_H diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c index 05889de..600a20a 100644 --- a/src/other/prod_axis_one.c +++ b/src/other/prod_axis_one.c @@ -13,7 +13,7 @@ static void forward(expr *node, const double *u) expr *x = node->left; x->forward(x, u); - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; int d1 = x->d1; int d2 = x->d2; @@ -91,7 +91,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *J_vals = node->jacobian->x; @@ -136,30 +136,32 @@ static void wsum_hess_init(expr *node) /* if x is a variable */ if (x->var_id != NOT_A_VARIABLE) { - /* Hessian structure: for each row i of the input matrix, - * the Hessian of the product (w[i] * prod(row i)) has entries - * at the column indices corresponding to the columns in that row. - * Each row i has d2 non-zero columns, and we have d2 x d2 interactions. 
- * Total nnz = d1 * d2 * d2 - */ - int nnz = x->d1 * x->d2 * x->d2; + /* each row i has d2-1 non-zero entries, with column indices corresponding to + the columns in that row (except the diagonal element). */ + int nnz = x->d1 * x->d2 * (x->d2 - 1); node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); CSR_Matrix *H = node->wsum_hess; - /* Fill row pointers and column indices */ - int nnz_per_row = x->d2; + /* fill sparsity pattern */ + int nnz_per_row = x->d2 - 1; for (int i = 0; i < x->size; i++) { - int matrix_row = i % x->d1; - H->p[x->var_id + i] = i * nnz_per_row; - - /* For this variable (at matrix position [matrix_row, ...]), - * the Hessian has entries with all columns in the same matrix_row */ + int row = i % x->d1; + int col = i / x->d1; + int start = i * nnz_per_row; + H->p[x->var_id + i] = start; + + /* for this variable (at X[row, col]), the Hessian has entries with all + * columns in the same row, excluding the diagonal (col itself) */ + int offset = 0; + int col_offset = x->var_id + row; for (int j = 0; j < x->d2; j++) { - /* Global column index for column j of this row */ - int global_col = x->var_id + matrix_row + j * x->d1; - H->i[i * nnz_per_row + j] = global_col; + if (j != col) + { + H->i[start + offset] = col_offset + j * x->d1; + offset++; + } } } @@ -173,67 +175,41 @@ static void wsum_hess_init(expr *node) { assert(false && "child must be a variable"); } - - /* print row pointers */ - // for (int i = 0; i <= node->n_vars; i++) - //{ - // printf("H->p[%d] = %d\n", i, node->wsum_hess->p[i]); - // } - // - ///* print column indices */ - // for (int i = 0; i < node->wsum_hess->nnz; i++) - //{ - // printf("H->i[%d] = %d\n", i, node->wsum_hess->i[i]); - // } - - printf("done wsum_hess_init prod_axis_one\n"); } static inline void wsum_hess_row_no_zeros(expr *node, const double *w, int row, int d2) { expr *x = node->left; - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; CSR_Matrix *H = node->wsum_hess; - double *H_vals = 
H->x; - double scale = w[row] * node->value[row]; - int idx_i = row; /* Start with the row offset */ - /* For each variable in this row, fill in Hessian entries */ + /* for each variable xk in this row, fill in Hessian entries + // TODO: more cache-friendly solution is possible? */ for (int i = 0; i < d2; i++) { - int var_i = x->var_id + row + i * x->d1; - int row_start = H->p[var_i]; - idx_i = i * x->d1 + row; /* Element at matrix [row, i] */ - - /* Row var_i has nnz entries with columns from the same matrix row */ + int k = x->var_id + row + i * x->d1; + int idx_k = i * x->d1 + row; + int offset = 0; for (int j = 0; j < d2; j++) { - if (i == j) + if (j != i) { - /* Diagonal entries are 0 */ - H_vals[row_start + j] = 0.0; - } - else - { - /* Off-diagonal: scale / (x[i] * x[j]) */ - int idx_j = j * x->d1 + row; /* Element at matrix [row, j] */ - H_vals[row_start + j] = scale / (x->value[idx_i] * x->value[idx_j]); + int idx_j = j * x->d1 + row; + H->x[H->p[k] + offset] = scale / (x->value[idx_k] * x->value[idx_j]); + offset++; } } } - (void) pnode; /* suppress unused warning */ } static inline void wsum_hess_row_one_zero(expr *node, const double *w, int row, int d2) { expr *x = node->left; - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; CSR_Matrix *H = node->wsum_hess; double *H_vals = H->x; - int p = pnode->zero_index[row]; /* zero column index */ double w_prod = w[row] * pnode->prod_nonzero[row]; @@ -243,25 +219,30 @@ static inline void wsum_hess_row_one_zero(expr *node, const double *w, int row, int var_i = x->var_id + row + i * x->d1; int row_start = H->p[var_i]; - /* Row var_i has nnz entries with columns from the same matrix row */ + /* Row var_i has d2-1 entries (excluding diagonal) */ + int offset = 0; for (int j = 0; j < d2; j++) { - if (i == p && j != p) - { - /* Row p (zero row), column j (nonzero) */ - int idx_j = j * x->d1 + row; - H_vals[row_start + j] = w_prod / x->value[idx_j]; - } - else if (j == p 
&& i != p) - { - /* Row i (nonzero), column p (zero) */ - int idx_i = i * x->d1 + row; - H_vals[row_start + j] = w_prod / x->value[idx_i]; - } - else + if (j != i) { - /* All other entries are zero */ - H_vals[row_start + j] = 0.0; + if (i == p && j != p) + { + /* row p (zero row), column j (nonzero) */ + int idx_j = j * x->d1 + row; + H_vals[row_start + offset] = w_prod / x->value[idx_j]; + } + else if (j == p && i != p) + { + /* row i (nonzero), column p (zero) */ + int idx_i = i * x->d1 + row; + H_vals[row_start + offset] = w_prod / x->value[idx_i]; + } + else + { + /* all other entries are zero */ + H_vals[row_start + offset] = 0.0; + } + offset++; } } } @@ -271,7 +252,7 @@ static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, int d2) { expr *x = node->left; - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; CSR_Matrix *H = node->wsum_hess; double *H_vals = H->x; @@ -279,8 +260,7 @@ static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, int p = -1, q = -1; for (int c = 0; c < d2; c++) { - int idx = c * node->left->d1 + row; - if (IS_ZERO(x->value[idx])) + if (IS_ZERO(x->value[c * node->left->d1 + row])) { if (p == -1) { @@ -303,17 +283,22 @@ static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, int var_i = x->var_id + row + i * x->d1; int row_start = H->p[var_i]; - /* Row var_i has nnz entries with columns from the same matrix row */ + /* row var_i has d2-1 entries (excluding diagonal) */ + int offset = 0; for (int j = 0; j < d2; j++) { - /* Only (p,q) and (q,p) are nonzero */ - if ((i == p && j == q) || (i == q && j == p)) - { - H_vals[row_start + j] = hess_val; - } - else + if (j != i) { - H_vals[row_start + j] = 0.0; + /* only (p,q) and (q,p) are nonzero */ + if ((i == p && j == q) || (i == q && j == p)) + { + H_vals[row_start + offset] = hess_val; + } + else + { + H_vals[row_start + offset] = 0.0; + } + offset++; } } } @@ -325,24 
+310,18 @@ static inline void wsum_hess_row_many_zeros(expr *node, int row, int d2) double *H_vals = H->x; expr *x = node->left; - /* For each variable in this row, zero out all entries */ + /* for each variable in this row, zero out all entries */ for (int i = 0; i < d2; i++) { int var_i = x->var_id + row + i * x->d1; - int row_start = H->p[var_i]; - - /* Each variable has d2 entries */ - for (int j = 0; j < d2; j++) - { - H_vals[row_start + j] = 0.0; - } + memset(H_vals + H->p[var_i], 0, sizeof(double) * (d2 - 1)); } } static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; /* if x is a variable */ if (x->var_id != NOT_A_VARIABLE) @@ -383,7 +362,7 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - prod_axis_one_expr *pnode = (prod_axis_one_expr *) node; + prod_axis *pnode = (prod_axis *) node; free(pnode->num_of_zeros); free(pnode->zero_index); free(pnode->prod_nonzero); @@ -391,8 +370,7 @@ static void free_type_data(expr *node) expr *new_prod_axis_one(expr *child) { - prod_axis_one_expr *pnode = - (prod_axis_one_expr *) calloc(1, sizeof(prod_axis_one_expr)); + prod_axis *pnode = (prod_axis *) calloc(1, sizeof(prod_axis)); expr *node = &pnode->base; /* output is always a row vector 1 x d1 (one product per row) */ diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c index 8d1807b..191f6ff 100644 --- a/src/other/prod_axis_zero.c +++ b/src/other/prod_axis_zero.c @@ -12,7 +12,7 @@ static void forward(expr *node, const double *u) expr *x = node->left; x->forward(x, u); - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; int d1 = x->d1; int d2 = x->d2; @@ -85,7 +85,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = 
(prod_axis *) node; double *J_vals = node->jacobian->x; @@ -175,8 +175,7 @@ static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int co double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; - double w_col = w[col]; - double f_col = node->value[col]; + double scale = w[col] * node->value[col]; for (int i = 0; i < d1; i++) { @@ -191,7 +190,7 @@ static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int co int idx_i = col_start + i; int idx_j = col_start + j; H[block_start + i * d1 + j] = - w_col * f_col / (x->value[idx_i] * x->value[idx_j]); + scale / (x->value[idx_i] * x->value[idx_j]); } } } @@ -201,7 +200,7 @@ static inline void wsum_hess_column_one_zero(expr *node, const double *w, int co int d1) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; @@ -229,7 +228,7 @@ static inline void wsum_hess_column_two_zeros(expr *node, const double *w, int c int d1) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; @@ -275,7 +274,7 @@ static inline void wsum_hess_column_many_zeros(expr *node, const double *w, int static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; int d1 = x->d1; int d2 = x->d2; @@ -319,16 +318,16 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; free(pnode->num_of_zeros); free(pnode->zero_index); free(pnode->prod_nonzero); } +/* TODO: refactor to remove diagonal entry as nonzero since it's 
always zero */ expr *new_prod_axis_zero(expr *child) { - prod_axis_zero_expr *pnode = - (prod_axis_zero_expr *) calloc(1, sizeof(prod_axis_zero_expr)); + prod_axis *pnode = (prod_axis *) calloc(1, sizeof(prod_axis)); expr *node = &pnode->base; /* output is always a row vector 1 x d2 - TODO: is that correct? */ diff --git a/src/problem.c b/src/problem.c index d6e752b..e5b4764 100644 --- a/src/problem.c +++ b/src/problem.c @@ -56,7 +56,6 @@ void problem_init_derivatives(problem *prob) // ------------------------------------------------------------------------------- // Jacobian structure // ------------------------------------------------------------------------------- - printf("Initializing Jacobian structure...\n"); prob->objective->jacobian_init(prob->objective); int nnz = 0; for (int i = 0; i < prob->n_constraints; i++) @@ -67,7 +66,6 @@ void problem_init_derivatives(problem *prob) } prob->jacobian = new_csr_matrix(prob->total_constraint_size, prob->n_vars, nnz); - printf("jacobian structure initialized\n"); /* set sparsity pattern of jacobian */ CSR_Matrix *H = prob->jacobian; @@ -92,7 +90,6 @@ void problem_init_derivatives(problem *prob) // ------------------------------------------------------------------------------- // Lagrange Hessian structure // ------------------------------------------------------------------------------- - printf("Initializing Lagrange Hessian structure...\n"); prob->objective->wsum_hess_init(prob->objective); nnz = prob->objective->wsum_hess->nnz; @@ -106,14 +103,11 @@ void problem_init_derivatives(problem *prob) memset(prob->lagrange_hessian->x, 0, nnz * sizeof(double)); /* affine shortcut */ prob->hess_idx_map = (int *) malloc(nnz * sizeof(int)); int *iwork = (int *) malloc(MAX(nnz, prob->n_vars) * sizeof(int)); - printf("Lagrange Hessian structure initialized\n"); problem_lagrange_hess_fill_sparsity(prob, iwork); - printf("Lagrange Hessian sparsity filled\n"); free(iwork); clock_gettime(CLOCK_MONOTONIC, &timer.end); 
prob->stats.time_init_derivatives += GET_ELAPSED_SECONDS(timer); - printf("Derivative initialization complete.\n"); } static void problem_lagrange_hess_fill_sparsity(problem *prob, int *iwork) diff --git a/tests/wsum_hess/test_prod_axis_one.h b/tests/wsum_hess/test_prod_axis_one.h index 7acbec1..18524e9 100644 --- a/tests/wsum_hess/test_prod_axis_one.h +++ b/tests/wsum_hess/test_prod_axis_one.h @@ -33,39 +33,40 @@ const char *test_wsum_hess_prod_axis_one_no_zeros() p->wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); - double expected_x[18] = {/* Var 1 (row 0 of matrix): [0, 5, 3] */ - 0.0, 5.0, 3.0, - /* Var 2 (row 1 of matrix): [0, 12, 8] */ - 0.0, 12.0, 8.0, - /* Var 3 (row 0 of matrix): [5, 0, 1] */ - 5.0, 0.0, 1.0, - /* Var 4 (row 1 of matrix): [12, 0, 4] */ - 12.0, 0.0, 4.0, - /* Var 5 (row 0 of matrix): [3, 1, 0] */ - 3.0, 1.0, 0.0, - /* Var 6 (row 1 of matrix): [8, 4, 0] */ - 8.0, 4.0, 0.0}; + double expected_x[12] = {/* Var 1 (row 0, col 0): [5, 3] (excludes col 0) */ + 5.0, 3.0, + /* Var 2 (row 1, col 0): [12, 8] (excludes col 0) */ + 12.0, 8.0, + /* Var 3 (row 0, col 1): [5, 1] (excludes col 1) */ + 5.0, 1.0, + /* Var 4 (row 1, col 1): [12, 4] (excludes col 1) */ + 12.0, 4.0, + /* Var 5 (row 0, col 2): [3, 1] (excludes col 2) */ + 3.0, 1.0, + /* Var 6 (row 1, col 2): [8, 4] (excludes col 2) */ + 8.0, 4.0}; /* Row pointers (monotonically increasing for valid CSR format) */ - int expected_p[9] = {0, 0, 3, 6, 9, 12, 15, 18, 18}; - - /* Column indices (each row of the matrix interacts with its own columns) */ - int expected_i[18] = {/* Var 1 (row 0): cols 1,3,5 */ - 1, 3, 5, - /* Var 2 (row 1): cols 2,4,6 */ - 2, 4, 6, - /* Var 3 (row 0): cols 1,3,5 */ - 1, 3, 5, - /* Var 4 (row 1): cols 2,4,6 */ - 2, 4, 6, - /* Var 5 (row 0): cols 1,3,5 */ - 1, 3, 5, - /* Var 6 (row 1): cols 2,4,6 */ - 2, 4, 6}; - - mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 18)); + int expected_p[9] = {0, 0, 2, 4, 6, 8, 10, 12, 12}; + + /* Column indices (each row 
of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[12] = {/* Var 1 (row 0, col 0): cols 3,5 (excludes 1) */ + 3, 5, + /* Var 2 (row 1, col 0): cols 4,6 (excludes 2) */ + 4, 6, + /* Var 3 (row 0, col 1): cols 1,5 (excludes 3) */ + 1, 5, + /* Var 4 (row 1, col 1): cols 2,6 (excludes 4) */ + 2, 6, + /* Var 5 (row 0, col 2): cols 1,3 (excludes 5) */ + 1, 3, + /* Var 6 (row 1, col 2): cols 2,4 (excludes 6) */ + 2, 4}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 12)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); - mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 18)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 12)); free_expr(p); return 0; @@ -99,80 +100,72 @@ const char *test_wsum_hess_prod_axis_one_one_zero() p->wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); - double expected_x[27]; + double expected_x[18]; memset(expected_x, 0, sizeof(expected_x)); - /* Var 1 (row 0): [0, 7, 4] */ - expected_x[0] = 0.0; - expected_x[1] = 7.0; - expected_x[2] = 4.0; + /* Var 1 (row 0, col 0): [7, 4] (excludes col 0) */ + expected_x[0] = 7.0; + expected_x[1] = 4.0; - /* Var 2 (row 1, one zero at col 1): [0, 16, 0] */ + /* Var 2 (row 1, col 0): [16, 0] (excludes col 0, one zero at col 1) */ + expected_x[2] = 16.0; expected_x[3] = 0.0; - expected_x[4] = 16.0; - expected_x[5] = 0.0; - - /* Var 3 (row 2): [0, 27, 18] */ - expected_x[6] = 0.0; - expected_x[7] = 27.0; - expected_x[8] = 18.0; - - /* Var 4 (row 0): [7, 0, 1] */ - expected_x[9] = 7.0; - expected_x[10] = 0.0; - expected_x[11] = 1.0; - - /* Var 5 (row 1, one zero at col 1): [16, 0, 4] */ - expected_x[12] = 16.0; - expected_x[13] = 0.0; - expected_x[14] = 4.0; - - /* Var 6 (row 2): [27, 0, 9] */ - expected_x[15] = 27.0; - expected_x[16] = 0.0; - expected_x[17] = 9.0; - /* Var 7 (row 0): [4, 1, 0] */ - expected_x[18] = 4.0; - expected_x[19] = 1.0; - expected_x[20] = 0.0; + /* Var 3 (row 2, col 0): 
[27, 18] (excludes col 0) */ + expected_x[4] = 27.0; + expected_x[5] = 18.0; + + /* Var 4 (row 0, col 1): [7, 1] (excludes col 1) */ + expected_x[6] = 7.0; + expected_x[7] = 1.0; - /* Var 8 (row 1, one zero at col 1): [0, 4, 0] */ - expected_x[21] = 0.0; - expected_x[22] = 4.0; - expected_x[23] = 0.0; + /* Var 5 (row 1, col 1): [16, 4] (excludes col 1, one zero at col 1) */ + expected_x[8] = 16.0; + expected_x[9] = 4.0; - /* Var 9 (row 2): [18, 9, 0] */ - expected_x[24] = 18.0; - expected_x[25] = 9.0; - expected_x[26] = 0.0; + /* Var 6 (row 2, col 1): [27, 9] (excludes col 1) */ + expected_x[10] = 27.0; + expected_x[11] = 9.0; + + /* Var 7 (row 0, col 2): [4, 1] (excludes col 2) */ + expected_x[12] = 4.0; + expected_x[13] = 1.0; + + /* Var 8 (row 1, col 2): [0, 4] (excludes col 2, one zero at col 1) */ + expected_x[14] = 0.0; + expected_x[15] = 4.0; + + /* Var 9 (row 2, col 2): [18, 9] (excludes col 2) */ + expected_x[16] = 18.0; + expected_x[17] = 9.0; /* Row pointers (monotonically increasing for valid CSR format) */ - int expected_p[11] = {0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27}; - - /* Column indices (each row of the matrix interacts with its own columns) */ - int expected_i[27] = {/* Var 1 (row 0): cols 1,4,7 */ - 1, 4, 7, - /* Var 2 (row 1): cols 2,5,8 */ - 2, 5, 8, - /* Var 3 (row 2): cols 3,6,9 */ - 3, 6, 9, - /* Var 4 (row 0): cols 1,4,7 */ - 1, 4, 7, - /* Var 5 (row 1): cols 2,5,8 */ - 2, 5, 8, - /* Var 6 (row 2): cols 3,6,9 */ - 3, 6, 9, - /* Var 7 (row 0): cols 1,4,7 */ - 1, 4, 7, - /* Var 8 (row 1): cols 2,5,8 */ - 2, 5, 8, - /* Var 9 (row 2): cols 3,6,9 */ - 3, 6, 9}; - - mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 27)); + int expected_p[11] = {0, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18}; + + /* Column indices (each row of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[18] = {/* Var 1 (row 0, col 0): cols 4,7 (excludes 1) */ + 4, 7, + /* Var 2 (row 1, col 0): cols 5,8 (excludes 2) */ + 5, 8, + 
/* Var 3 (row 2, col 0): cols 6,9 (excludes 3) */ + 6, 9, + /* Var 4 (row 0, col 1): cols 1,7 (excludes 4) */ + 1, 7, + /* Var 5 (row 1, col 1): cols 2,8 (excludes 5) */ + 2, 8, + /* Var 6 (row 2, col 1): cols 3,9 (excludes 6) */ + 3, 9, + /* Var 7 (row 0, col 2): cols 1,4 (excludes 7) */ + 1, 4, + /* Var 8 (row 1, col 2): cols 2,5 (excludes 8) */ + 2, 5, + /* Var 9 (row 2, col 2): cols 3,6 (excludes 9) */ + 3, 6}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 18)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 11)); - mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 27)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 18)); free_expr(p); return 0; @@ -210,128 +203,120 @@ const char *test_wsum_hess_prod_axis_one_mixed_zeros() p->wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); - double expected_x[45]; + double expected_x[30]; memset(expected_x, 0, sizeof(expected_x)); - /* For a 5x3 matrix with var_id=1: - * CSR row pointers: p[i] = (i-1)*3 for i in [1,15] - * Variables are indexed sequentially: + /* For a 5x3 matrix with var_id=1, each row has 2 nnz (d2-1): + * CSR row pointers: p[i] = (i-1)*2 for i in [1,15] * Var 1 (matrix [0,0]): p[1]=0 - * Var 2 (matrix [1,0]): p[2]=3 - * Var 3 (matrix [2,0]): p[3]=6 - * Var 4 (matrix [3,0]): p[4]=9 - * Var 5 (matrix [4,0]): p[5]=12 - * Var 6 (matrix [0,1]): p[6]=15 - * Var 7 (matrix [1,1]): p[7]=18 - * Var 8 (matrix [2,1]): p[8]=21 - * Var 9 (matrix [3,1]): p[9]=24 - * Var 10 (matrix [4,1]): p[10]=27 - * Var 11 (matrix [0,2]): p[11]=30 - * Var 12 (matrix [1,2]): p[12]=33 - * Var 13 (matrix [2,2]): p[13]=36 - * Var 14 (matrix [3,2]): p[14]=39 - * Var 15 (matrix [4,2]): p[15]=42 + * Var 2 (matrix [1,0]): p[2]=2 + * Var 3 (matrix [2,0]): p[3]=4 + * Var 4 (matrix [3,0]): p[4]=6 + * Var 5 (matrix [4,0]): p[5]=8 + * Var 6 (matrix [0,1]): p[6]=10 + * Var 7 (matrix [1,1]): p[7]=12 + * Var 8 (matrix [2,1]): p[8]=14 + * Var 9 (matrix [3,1]): p[9]=16 + 
* Var 10 (matrix [4,1]): p[10]=18 + * Var 11 (matrix [0,2]): p[11]=20 + * Var 12 (matrix [1,2]): p[12]=22 + * Var 13 (matrix [2,2]): p[13]=24 + * Var 14 (matrix [3,2]): p[14]=26 + * Var 15 (matrix [4,2]): p[15]=28 */ /* Row 0 block (no zeros, scale = 1 * 2 = 2), x = [1,2,1] */ - /* Var 1 (matrix [0,0]): [0, 1, 2] */ - expected_x[0] = 0.0; - expected_x[1] = 1.0; /* 2/(1*2) */ - expected_x[2] = 2.0; /* 2/(1*1) */ + /* Var 1 (matrix [0,0], col 0): [1, 2] (excludes col 0) */ + expected_x[0] = 1.0; /* 2/(1*2) */ + expected_x[1] = 2.0; /* 2/(1*1) */ - /* Var 6 (matrix [0,1]): [1, 0, 1] */ - expected_x[15] = 1.0; /* 2/(2*1) */ - expected_x[16] = 0.0; - expected_x[17] = 1.0; /* 2/(2*1) */ + /* Var 6 (matrix [0,1], col 1): [1, 1] (excludes col 1) */ + expected_x[10] = 1.0; /* 2/(2*1) */ + expected_x[11] = 1.0; /* 2/(2*1) */ - /* Var 11 (matrix [0,2]): [2, 1, 0] */ - expected_x[30] = 2.0; /* 2/(1*1) */ - expected_x[31] = 1.0; /* 2/(1*2) */ - expected_x[32] = 0.0; + /* Var 11 (matrix [0,2], col 2): [2, 1] (excludes col 2) */ + expected_x[20] = 2.0; /* 2/(1*1) */ + expected_x[21] = 1.0; /* 2/(1*2) */ /* Row 1 block (two zeros at cols 1,2), hess = w*prod_nonzero = 2*1 = 2 */ - /* Var 2 (matrix [1,0]): [0, 0, 0] */ + /* Var 2 (matrix [1,0], col 0): [0, 0] (excludes col 0) */ + expected_x[2] = 0.0; expected_x[3] = 0.0; - expected_x[4] = 0.0; - expected_x[5] = 0.0; - /* Var 7 (matrix [1,1]): [0, 0, 2] */ - expected_x[18] = 0.0; - expected_x[19] = 0.0; - expected_x[20] = 2.0; /* (1,2) */ + /* Var 7 (matrix [1,1], col 1): [0, 2] (excludes col 1) */ + expected_x[12] = 0.0; + expected_x[13] = 2.0; /* (1,2) */ - /* Var 12 (matrix [1,2]): [0, 2, 0] */ - expected_x[33] = 0.0; - expected_x[34] = 2.0; /* (2,1) */ - expected_x[35] = 0.0; + /* Var 12 (matrix [1,2], col 2): [0, 2] (excludes col 2) */ + expected_x[22] = 0.0; + expected_x[23] = 2.0; /* (2,1) */ /* Row 2 block (one zero at col 2), w_prod = 3 * 3 = 9, x = [1,3,0] */ - /* Var 3 (matrix [2,0]): [0, 0, 9] */ - expected_x[6] = 0.0; - 
expected_x[7] = 0.0; - expected_x[8] = 9.0; /* (0,2): w_prod/x[0] */ + /* Var 3 (matrix [2,0], col 0): [0, 9] (excludes col 0) */ + expected_x[4] = 0.0; + expected_x[5] = 9.0; /* (0,2): w_prod/x[0] */ - /* Var 8 (matrix [2,1]): [0, 0, 3] */ - expected_x[21] = 0.0; - expected_x[22] = 0.0; - expected_x[23] = 3.0; /* (1,2): w_prod/x[1] */ + /* Var 8 (matrix [2,1], col 1): [0, 3] (excludes col 1) */ + expected_x[14] = 0.0; + expected_x[15] = 3.0; /* (1,2): w_prod/x[1] */ - /* Var 13 (matrix [2,2]): [9, 3, 0] */ - expected_x[36] = 9.0; /* (2,0): w_prod/x[0] */ - expected_x[37] = 3.0; /* (2,1): w_prod/x[1] */ - expected_x[38] = 0.0; + /* Var 13 (matrix [2,2], col 2): [9, 3] (excludes col 2) */ + expected_x[24] = 9.0; /* (2,0): w_prod/x[0] */ + expected_x[25] = 3.0; /* (2,1): w_prod/x[1] */ /* Row 3 block (no zeros, scale = 4 * 8 = 32), x = [1,4,2] */ - /* Var 4 (matrix [3,0]): [0, 8, 16] */ - expected_x[9] = 0.0; - expected_x[10] = 8.0; /* 32/(1*4) */ - expected_x[11] = 16.0; /* 32/(1*2) */ + /* Var 4 (matrix [3,0], col 0): [8, 16] (excludes col 0) */ + expected_x[6] = 8.0; /* 32/(1*4) */ + expected_x[7] = 16.0; /* 32/(1*2) */ - /* Var 9 (matrix [3,1]): [8, 0, 4] */ - expected_x[24] = 8.0; /* 32/(4*1) */ - expected_x[25] = 0.0; - expected_x[26] = 4.0; /* 32/(4*2) */ + /* Var 9 (matrix [3,1], col 1): [8, 4] (excludes col 1) */ + expected_x[16] = 8.0; /* 32/(4*1) */ + expected_x[17] = 4.0; /* 32/(4*2) */ - /* Var 14 (matrix [3,2]): [16, 4, 0] */ - expected_x[39] = 16.0; /* 32/(2*1) */ - expected_x[40] = 4.0; /* 32/(2*4) */ - expected_x[41] = 0.0; + /* Var 14 (matrix [3,2], col 2): [16, 4] (excludes col 2) */ + expected_x[26] = 16.0; /* 32/(2*1) */ + expected_x[27] = 4.0; /* 32/(2*4) */ /* Row 4 block (no zeros, scale = 5 * 15 = 75), x = [1,5,3] */ - /* Var 5 (matrix [4,0]): [0, 15, 25] */ - expected_x[12] = 0.0; - expected_x[13] = 15.0; /* 75/(1*5) */ - expected_x[14] = 25.0; /* 75/(1*3) */ + /* Var 5 (matrix [4,0], col 0): [15, 25] (excludes col 0) */ + expected_x[8] = 
15.0; /* 75/(1*5) */ + expected_x[9] = 25.0; /* 75/(1*3) */ - /* Var 10 (matrix [4,1]): [15, 0, 5] */ - expected_x[27] = 15.0; /* 75/(5*1) */ - expected_x[28] = 0.0; - expected_x[29] = 5.0; /* 75/(5*3) */ + /* Var 10 (matrix [4,1], col 1): [15, 5] (excludes col 1) */ + expected_x[18] = 15.0; /* 75/(5*1) */ + expected_x[19] = 5.0; /* 75/(5*3) */ - /* Var 15 (matrix [4,2]): [25, 5, 0] */ - expected_x[42] = 25.0; /* 75/(3*1) */ - expected_x[43] = 5.0; /* 75/(3*5) */ - expected_x[44] = 0.0; + /* Var 15 (matrix [4,2], col 2): [25, 5] (excludes col 2) */ + expected_x[28] = 25.0; /* 75/(3*1) */ + expected_x[29] = 5.0; /* 75/(3*5) */ /* Row pointers (monotonically increasing for valid CSR format) */ - int expected_p[17] = {0, 0, 3, 6, 9, 12, 15, 18, 21, - 24, 27, 30, 33, 36, 39, 42, 45}; + int expected_p[17] = {0, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; - /* Column indices (each row of the matrix interacts with its own columns) */ - int expected_i[45]; + /* Column indices (each row of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[30]; for (int var_idx = 0; var_idx < 15; var_idx++) { int matrix_row = var_idx % 5; /* which row of the 5x3 matrix */ - int nnz_start = var_idx * 3; - /* All columns from matrix_row */ - expected_i[nnz_start + 0] = 1 + matrix_row + 0 * 5; - expected_i[nnz_start + 1] = 1 + matrix_row + 1 * 5; - expected_i[nnz_start + 2] = 1 + matrix_row + 2 * 5; + int col_i = var_idx / 5; /* which column this variable is in */ + int nnz_start = var_idx * 2; + /* All columns from matrix_row, excluding col_i */ + int offset = 0; + for (int j = 0; j < 3; j++) + { + if (j != col_i) + { + expected_i[nnz_start + offset] = 1 + matrix_row + j * 5; + offset++; + } + } } - mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 45)); + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 30)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 17)); - mu_assert("cols 
fail", cmp_int_array(p->wsum_hess->i, expected_i, 45)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 30)); free_expr(p); return 0; @@ -368,24 +353,25 @@ const char *test_wsum_hess_prod_axis_one_2x2() p->wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); - /* Expected sparse structure (nnz = 8) */ - double expected_x[8] = {0.0, 1.0, /* Var 0 */ - 0.0, 1.0, /* Var 1 */ - 1.0, 0.0, /* Var 2 */ - 1.0, 0.0}; /* Var 3 */ + /* Expected sparse structure (nnz = 4, each row has 1 nnz) */ + double expected_x[4] = {1.0, /* Var 0 (excludes col 0) */ + 1.0, /* Var 1 (excludes col 0) */ + 1.0, /* Var 2 (excludes col 1) */ + 1.0}; /* Var 3 (excludes col 1) */ - /* Row pointers (each row has 2 nnz) */ - int expected_p[5] = {0, 2, 4, 6, 8}; + /* Row pointers (each row has 1 nnz) */ + int expected_p[5] = {0, 1, 2, 3, 4}; - /* Column indices (each variable stores columns for its matrix row) */ - int expected_i[8] = {0, 2, /* Var 0 (row 0) */ - 1, 3, /* Var 1 (row 1) */ - 0, 2, /* Var 2 (row 0) */ - 1, 3}; /* Var 3 (row 1) */ + /* Column indices (each variable stores columns for its matrix row, excluding + * diagonal) */ + int expected_i[4] = {2, /* Var 0 (row 0, col 0): only col 1 */ + 3, /* Var 1 (row 1, col 0): only col 1 */ + 0, /* Var 2 (row 0, col 1): only col 0 */ + 1}; /* Var 3 (row 1, col 1): only col 0 */ - mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 8)); + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 4)); mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 5)); - mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 8)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 4)); free_expr(p); return 0; From 6e5aed184fe9fbe6fcdbf2a307c0607bc79f166b Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 18 Jan 2026 10:07:13 -0800 Subject: [PATCH 5/5] error handling --- src/bivariate/quad_over_lin.c | 10 ++++++++-- src/other/prod.c | 19 +++++++++++++++---- 
src/other/prod_axis_one.c | 16 ++++++++++++---- src/other/prod_axis_zero.c | 19 +++++++++++++++---- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/bivariate/quad_over_lin.c b/src/bivariate/quad_over_lin.c index 35639ff..d3d5920 100644 --- a/src/bivariate/quad_over_lin.c +++ b/src/bivariate/quad_over_lin.c @@ -3,7 +3,9 @@ #include "utils/CSC_Matrix.h" #include #include +#include #include +#include // ------------------------------------------------------------------------------ // Implementation of quad-over-lin. The second argument will always be a variable @@ -231,7 +233,9 @@ static void wsum_hess_init(expr *node) else { /* TODO: implement */ - assert(false && "not implemented"); + fprintf(stderr, "Error in quad_over_lin wsum_hess_init: non-variable child " + "not implemented\n"); + exit(1); } } @@ -288,7 +292,9 @@ static void eval_wsum_hess(expr *node, const double *w) else { /* TODO: implement */ - assert(false && "not implemented"); + fprintf(stderr, "Error in quad_over_lin eval_wsum_hess: non-variable child " + "not implemented\n"); + exit(1); } } diff --git a/src/other/prod.c b/src/other/prod.c index 2d62672..0dfbfbd 100644 --- a/src/other/prod.c +++ b/src/other/prod.c @@ -1,6 +1,7 @@ #include "other.h" #include #include +#include #include #include @@ -62,7 +63,9 @@ static void jacobian_init(expr *node) } else { - assert(false && "not implemented"); + fprintf(stderr, + "Error in prod jacobian_init: non-variable child not implemented\n"); + exit(1); } } @@ -94,7 +97,9 @@ static void eval_jacobian(expr *node) } else { - assert(false && "not implemented"); + fprintf(stderr, + "Error in prod eval_jacobian: non-variable child not implemented\n"); + exit(1); } } @@ -133,7 +138,10 @@ static void wsum_hess_init(expr *node) } else { - assert(false && "not implemented"); + fprintf( + stderr, + "Error in prod wsum_hess_init: non-variable child not implemented\n"); + exit(1); } } @@ -164,7 +172,10 @@ static void eval_wsum_hess(expr *node, const 
double *w) } else { - assert(false && "not implemented"); + fprintf( + stderr, + "Error in prod eval_wsum_hess: non-variable child not implemented\n"); + exit(1); } } diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c index 600a20a..cd43229 100644 --- a/src/other/prod_axis_one.c +++ b/src/other/prod_axis_one.c @@ -84,7 +84,9 @@ static void jacobian_init(expr *node) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_one jacobian_init: child must be a variable\n"); + exit(1); } } @@ -125,7 +127,9 @@ static void eval_jacobian(expr *node) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_one eval_jacobian: child must be a variable\n"); + exit(1); } } @@ -173,7 +177,9 @@ static void wsum_hess_init(expr *node) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_one wsum_hess_init: child must be a variable\n"); + exit(1); } } @@ -350,7 +356,9 @@ static void eval_wsum_hess(expr *node, const double *w) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_one eval_wsum_hess: child must be a variable\n"); + exit(1); } } diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c index 191f6ff..ad85b6e 100644 --- a/src/other/prod_axis_zero.c +++ b/src/other/prod_axis_zero.c @@ -1,6 +1,7 @@ #include "other.h" #include #include +#include #include #include @@ -78,7 +79,9 @@ static void jacobian_init(expr *node) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_zero jacobian_init: child must be a variable\n"); + exit(1); } } @@ -118,7 +121,9 @@ static void eval_jacobian(expr *node) } else { - assert(false && "child must be a variable"); + fprintf(stderr, + "Error in prod_axis_zero eval_jacobian: child must be a variable\n"); + exit(1); } } @@ -164,7 +169,10 @@ static void wsum_hess_init(expr *node) } else { - assert(false && "child must 
be a variable"); + fprintf( + stderr, + "Error in prod_axis_zero wsum_hess_init: child must be a variable\n"); + exit(1); } } @@ -306,7 +314,10 @@ static void eval_wsum_hess(expr *node, const double *w) } else { - assert(false && "child must be a variable"); + fprintf( + stderr, + "Error in prod_axis_zero eval_wsum_hess: child must be a variable\n"); + exit(1); } }