From b255da25263d1e6d32d075c9d9da998f86af588d Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 17 Jan 2026 09:32:29 -0800 Subject: [PATCH] some progress --- TODO.md | 2 +- include/other.h | 3 + include/subexpr.h | 9 + python/atoms/prod_axis_zero.h | 33 ++ python/bindings.c | 3 + src/other/prod_axis_zero.c | 348 +++++++++++++++++++++ tests/all_tests.c | 8 + tests/forward_pass/test_prod_axis_zero.h | 33 ++ tests/jacobian_tests/test_prod_axis_zero.h | 45 +++ tests/wsum_hess/test_prod_axis_zero.h | 251 +++++++++++++++ 10 files changed, 734 insertions(+), 1 deletion(-) create mode 100644 python/atoms/prod_axis_zero.h create mode 100644 src/other/prod_axis_zero.c create mode 100644 tests/forward_pass/test_prod_axis_zero.h create mode 100644 tests/jacobian_tests/test_prod_axis_zero.h create mode 100644 tests/wsum_hess/test_prod_axis_zero.h diff --git a/TODO.md b/TODO.md index 8854918..f693841 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ 4. in the refactor, add consts 10. For performance reasons, is it useful to have a dense matmul with A and B as dense matrices? -11. right matmul, add broadcasting logic as in left matmul +11. right matmul, add broadcasting logic as in left matmul. Is this necessary? Going through all atoms to see that sparsity pattern is computed in initialization of jacobian: 2. 
trace - not ok
diff --git a/include/other.h b/include/other.h
index d44871a..8bf1d53 100644
--- a/include/other.h
+++ b/include/other.h
@@ -10,4 +10,7 @@ expr *new_quad_form(expr *child, CSR_Matrix *Q);
 /* product of all entries, without axis argument */
 expr *new_prod(expr *child);
 
+/* product of entries along axis=0 (columnwise products) */
+expr *new_prod_axis_zero(expr *child);
+
 #endif /* OTHER_H */
diff --git a/include/subexpr.h b/include/subexpr.h
index fd0a5b2..e96c40f 100644
--- a/include/subexpr.h
+++ b/include/subexpr.h
@@ -56,6 +56,15 @@ typedef struct prod_expr
     double prod_nonzero; /* product of non-zero elements */
 } prod_expr;
 
+/* Product of entries along axis=0 (columnwise products) */
+typedef struct prod_axis_zero_expr
+{
+    expr base;
+    int *num_of_zeros;    /* num of zeros for each column */
+    int *zero_index;      /* stores idx of zero element per column */
+    double *prod_nonzero; /* product of non-zero elements per column */
+} prod_axis_zero_expr;
+
 /* Horizontal stack (concatenate) */
 typedef struct hstack_expr
 {
diff --git a/python/atoms/prod_axis_zero.h b/python/atoms/prod_axis_zero.h
new file mode 100644
index 0000000..02b30f0
--- /dev/null
+++ b/python/atoms/prod_axis_zero.h
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: Apache-2.0
+#ifndef ATOM_PROD_AXIS_ZERO_H
+#define ATOM_PROD_AXIS_ZERO_H
+
+#include "common.h"
+#include "other.h"
+
+static PyObject *py_make_prod_axis_zero(PyObject *self, PyObject *args)
+{
+    (void) self;
+    PyObject *child_capsule;
+    if (!PyArg_ParseTuple(args, "O", &child_capsule))
+    {
+        return NULL;
+    }
+    expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
+    if (!child)
+    {
+        PyErr_SetString(PyExc_ValueError, "invalid child capsule");
+        return NULL;
+    }
+
+    expr *node = new_prod_axis_zero(child);
+    if (!node)
+    {
+        PyErr_SetString(PyExc_RuntimeError, "failed to create prod_axis_zero node");
+        return NULL;
+    }
+    expr_retain(node); /* Capsule owns a reference */
+    return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
+}
+
+#endif /* ATOM_PROD_AXIS_ZERO_H */
diff --git a/python/bindings.c b/python/bindings.c
index 55fbdc2..d83b271 100644
--- a/python/bindings.c
+++ b/python/bindings.c
@@ -23,6 +23,7 @@
 #include "atoms/neg.h"
 #include "atoms/power.h"
 #include "atoms/prod.h"
+#include "atoms/prod_axis_zero.h"
 #include "atoms/promote.h"
 #include "atoms/quad_form.h"
 #include "atoms/quad_over_lin.h"
@@ -77,6 +78,8 @@ static PyMethodDef DNLPMethods[] = {
      "Create constant vector multiplication node (a ∘ f(x))"},
     {"make_power", py_make_power, METH_VARARGS, "Create power node"},
     {"make_prod", py_make_prod, METH_VARARGS, "Create prod node"},
+    {"make_prod_axis_zero", py_make_prod_axis_zero, METH_VARARGS,
+     "Create prod_axis_zero node"},
     {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"},
     {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"},
     {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"},
diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c
new file mode 100644
index 0000000..8d1807b
--- /dev/null
+++ b/src/other/prod_axis_zero.c
@@ -0,0 +1,400 @@
+#include "other.h"
+#include <assert.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+
+/* entries with magnitude below 1e-8 are treated as exact zeros */
+#define IS_ZERO(x) (fabs((x)) < 1e-8)
+
+/* Forward pass. Evaluates the child at u and, for every column of the child
+ * (read column-major: column col is value[col * d1 .. col * d1 + d1 - 1]),
+ * caches the number of zeros, the row of the last zero seen and the product
+ * of the nonzero entries. node->value[col] is the column product, which is 0
+ * as soon as the column contains a zero. */
+static void forward(expr *node, const double *u)
+{
+    /* forward pass of child */
+    expr *x = node->left;
+    x->forward(x, u);
+
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+    int d1 = x->d1;
+    int d2 = x->d2;
+
+    /* iterate over columns and compute columnwise products */
+    for (int col = 0; col < d2; col++)
+    {
+        double prod_nonzero = 1.0;
+        int zeros = 0;
+        int zero_row = -1;
+        int start = col * d1;
+        int end = start + d1;
+
+        /* iterate over rows of this column */
+        for (int idx = start; idx < end; idx++)
+        {
+            if (IS_ZERO(x->value[idx]))
+            {
+                zeros++;
+                zero_row = idx - start;
+            }
+            else
+            {
+                prod_nonzero *= x->value[idx];
+            }
+        }
+
+        /* store results for this column */
+        pnode->num_of_zeros[col] = zeros;
+        pnode->zero_index[col] = zero_row;
+        pnode->prod_nonzero[col] = prod_nonzero;
+        node->value[col] = (zeros > 0) ? 0.0 : prod_nonzero;
+    }
+}
+
+/* Allocates the Jacobian and fills its sparsity pattern: output row `col` has
+ * d1 entries, one for each entry of column `col` of the variable. Only a
+ * variable child is supported. */
+static void jacobian_init(expr *node)
+{
+    expr *x = node->left;
+
+    /* initialize child's jacobian */
+    x->jacobian_init(x);
+
+    /* if x is a variable */
+    if (x->var_id != NOT_A_VARIABLE)
+    {
+        node->jacobian = new_csr_matrix(node->size, node->n_vars, x->size);
+
+        /* set row pointers (each row has d1 nnzs) */
+        for (int row = 0; row < x->d2; row++)
+        {
+            node->jacobian->p[row] = row * x->d1;
+        }
+        node->jacobian->p[x->d2] = x->size;
+
+        /* set column indices */
+        for (int col = 0; col < x->d2; col++)
+        {
+            int start = col * x->d1;
+            for (int i = 0; i < x->d1; i++)
+            {
+                node->jacobian->i[start + i] = x->var_id + start + i;
+            }
+        }
+    }
+    else
+    {
+        assert(false && "child must be a variable");
+    }
+}
+
+/* Fills the Jacobian values from the statistics cached by forward():
+ *  - no zero in the column:    d f_col / d x_i = f_col / x_i
+ *  - exactly one zero (row p): only d f_col / d x_p = prod_nonzero is nonzero
+ *  - two or more zeros:        the whole row is zero */
+static void eval_jacobian(expr *node)
+{
+    expr *x = node->left;
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+
+    double *J_vals = node->jacobian->x;
+
+    /* if x is a variable */
+    if (x->var_id != NOT_A_VARIABLE)
+    {
+        /* process each column */
+        for (int col = 0; col < x->d2; col++)
+        {
+            int num_zeros = pnode->num_of_zeros[col];
+            int start = col * x->d1;
+
+            if (num_zeros == 0)
+            {
+                for (int i = 0; i < x->d1; i++)
+                {
+                    J_vals[start + i] = node->value[col] / x->value[start + i];
+                }
+            }
+            else if (num_zeros == 1)
+            {
+                memset(J_vals + start, 0, sizeof(double) * x->d1);
+                J_vals[start + pnode->zero_index[col]] = pnode->prod_nonzero[col];
+            }
+            else
+            {
+                memset(J_vals + start, 0, sizeof(double) * x->d1);
+            }
+        }
+    }
+    else
+    {
+        assert(false && "child must be a variable");
+    }
+}
+
+/* Allocates the weighted-sum Hessian and fills its sparsity pattern: d2 dense
+ * d1 x d1 diagonal blocks placed at the rows/columns of the variable. */
+static void wsum_hess_init(expr *node)
+{
+    expr *x = node->left;
+
+    /* if x is a variable */
+    if (x->var_id != NOT_A_VARIABLE)
+    {
+        /* Hessian has block diagonal structure: d2 blocks of size d1 x d1 */
+        int nnz = x->d2 * x->d1 * x->d1;
+        node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz);
+        CSR_Matrix *H = node->wsum_hess;
+
+        /* rows before the variable hold no entries; set them explicitly rather
+         * than relying on new_csr_matrix zero-initializing p */
+        for (int i = 0; i < x->var_id; i++)
+        {
+            H->p[i] = 0;
+        }
+
+        /* fill row pointers for the variable's rows (block diagonal) */
+        for (int i = 0; i < x->size; i++)
+        {
+            int block_idx = i / x->d1;
+            int row_in_block = i % x->d1;
+            H->p[x->var_id + i] = block_idx * x->d1 * x->d1 + row_in_block * x->d1;
+        }
+
+        /* fill row pointers for rows after the variable */
+        for (int i = x->var_id + x->size; i <= node->n_vars; i++)
+        {
+            H->p[i] = nnz;
+        }
+
+        /* fill column indices for the d2 diagonal blocks */
+        for (int block = 0; block < x->d2; block++)
+        {
+            int block_start_col = x->var_id + block * x->d1;
+            for (int i = 0; i < x->d1; i++)
+            {
+                for (int j = 0; j < x->d1; j++)
+                {
+                    int idx = block * x->d1 * x->d1 + i * x->d1 + j;
+                    H->i[idx] = block_start_col + j;
+                }
+            }
+        }
+    }
+    else
+    {
+        assert(false && "child must be a variable");
+    }
+}
+
+/* Column without zeros: H[i][j] = w_col * f_col / (x_i * x_j) for i != j and
+ * 0 on the diagonal. */
+static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int col,
+                                             int d1)
+{
+    expr *x = node->left;
+    double *H = node->wsum_hess->x;
+    int col_start = col * d1;
+    int block_start = col * d1 * d1;
+    double w_col = w[col];
+    double f_col = node->value[col];
+
+    for (int i = 0; i < d1; i++)
+    {
+        for (int j = 0; j < d1; j++)
+        {
+            if (i == j)
+            {
+                H[block_start + i * d1 + j] = 0.0;
+            }
+            else
+            {
+                int idx_i = col_start + i;
+                int idx_j = col_start + j;
+                H[block_start + i * d1 + j] =
+                    w_col * f_col / (x->value[idx_i] * x->value[idx_j]);
+            }
+        }
+    }
+}
+
+/* Column with exactly one zero at row p: only row p and column p of the block
+ * are nonzero, H[p][j] = H[j][p] = w_col * prod_nonzero / x_j for j != p. */
+static inline void wsum_hess_column_one_zero(expr *node, const double *w, int col,
+                                             int d1)
+{
+    expr *x = node->left;
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+    double *H = node->wsum_hess->x;
+    int col_start = col * d1;
+    int block_start = col * d1 * d1;
+
+    /* clear this column's block */
+    memset(&H[block_start], 0, sizeof(double) * d1 * d1);
+
+    int p = pnode->zero_index[col];
+    double prod_nonzero = pnode->prod_nonzero[col];
+    double w_prod = w[col] * prod_nonzero;
+
+    /* fill row p and column p of this block */
+    for (int j = 0; j < d1; j++)
+    {
+        if (j == p) continue;
+
+        int idx_j = col_start + j;
+        double hess_val = w_prod / x->value[idx_j];
+        H[block_start + p * d1 + j] = hess_val;
+        H[block_start + j * d1 + p] = hess_val;
+    }
+}
+
+/* Column with exactly two zeros at rows p and q: the only nonzero entries are
+ * H[p][q] = H[q][p] = w_col * prod_nonzero. */
+static inline void wsum_hess_column_two_zeros(expr *node, const double *w, int col,
+                                              int d1)
+{
+    expr *x = node->left;
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+    double *H = node->wsum_hess->x;
+    int col_start = col * d1;
+    int block_start = col * d1 * d1;
+
+    /* clear this column's block */
+    memset(&H[block_start], 0, sizeof(double) * d1 * d1);
+
+    /* find indices p and q where x[col_start + p] and x[col_start + q] are zero */
+    int p = -1, q = -1;
+    for (int i = 0; i < d1; i++)
+    {
+        if (IS_ZERO(x->value[col_start + i]))
+        {
+            if (p == -1)
+            {
+                p = i;
+            }
+            else
+            {
+                q = i;
+                break;
+            }
+        }
+    }
+    assert(p != -1 && q != -1);
+
+    double hess_val = w[col] * pnode->prod_nonzero[col];
+    H[block_start + p * d1 + q] = hess_val;
+    H[block_start + q * d1 + p] = hess_val;
+}
+
+/* Column with three or more zeros: the whole block vanishes. */
+static inline void wsum_hess_column_many_zeros(expr *node, const double *w, int col,
+                                               int d1)
+{
+    double *H = node->wsum_hess->x;
+    int block_start = col * d1 * d1;
+
+    /* clear this column's block */
+    memset(&H[block_start], 0, sizeof(double) * d1 * d1);
+    (void) w;
+}
+
+/* Fills the Hessian values of sum_col w[col] * f_col block by block,
+ * dispatching on the number of zeros forward() found in each column. */
+static void eval_wsum_hess(expr *node, const double *w)
+{
+    expr *x = node->left;
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+    int d1 = x->d1;
+    int d2 = x->d2;
+
+    /* if x is a variable */
+    if (x->var_id != NOT_A_VARIABLE)
+    {
+        /* process each column */
+        for (int col = 0; col < d2; col++)
+        {
+            int num_zeros = pnode->num_of_zeros[col];
+
+            if (num_zeros == 0)
+            {
+                wsum_hess_column_no_zeros(node, w, col, d1);
+            }
+            else if (num_zeros == 1)
+            {
+                wsum_hess_column_one_zero(node, w, col, d1);
+            }
+            else if (num_zeros == 2)
+            {
+                wsum_hess_column_two_zeros(node, w, col, d1);
+            }
+            else
+            {
+                wsum_hess_column_many_zeros(node, w, col, d1);
+            }
+        }
+    }
+    else
+    {
+        assert(false && "child must be a variable");
+    }
+}
+
+/* Always reports the node as non-affine. */
+static bool is_affine(const expr *node)
+{
+    (void) node;
+    return false;
+}
+
+/* Releases the per-column statistics arrays owned by this node type. */
+static void free_type_data(expr *node)
+{
+    prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node;
+    free(pnode->num_of_zeros);
+    free(pnode->zero_index);
+    free(pnode->prod_nonzero);
+}
+
+/* Creates a node computing the product of the entries of each column of
+ * `child`. Retains `child`. Returns NULL if a memory allocation fails. */
+expr *new_prod_axis_zero(expr *child)
+{
+    /* calloc(0, ...) may return NULL, so never request zero items */
+    size_t n_cols = child->d2 > 0 ? (size_t) child->d2 : 1;
+
+    prod_axis_zero_expr *pnode =
+        (prod_axis_zero_expr *) calloc(1, sizeof(prod_axis_zero_expr));
+    if (!pnode)
+    {
+        return NULL;
+    }
+
+    /* allocate arrays to store per-column statistics */
+    pnode->num_of_zeros = (int *) calloc(n_cols, sizeof(int));
+    pnode->zero_index = (int *) calloc(n_cols, sizeof(int));
+    pnode->prod_nonzero = (double *) calloc(n_cols, sizeof(double));
+    if (!pnode->num_of_zeros || !pnode->zero_index || !pnode->prod_nonzero)
+    {
+        /* init_expr has not run yet, so plain free() releases everything */
+        free(pnode->num_of_zeros);
+        free(pnode->zero_index);
+        free(pnode->prod_nonzero);
+        free(pnode);
+        return NULL;
+    }
+
+    expr *node = &pnode->base;
+
+    /* output is always a row vector 1 x d2 - TODO: is that correct? */
+    init_expr(node, 1, child->d2, child->n_vars, forward, jacobian_init,
+              eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess,
+              free_type_data);
+
+    node->left = child;
+    expr_retain(child);
+
+    return node;
+}
diff --git a/tests/all_tests.c b/tests/all_tests.c
index 0f8b9f6..164382b 100644
--- a/tests/all_tests.c
+++ b/tests/all_tests.c
@@ -14,6 +14,7 @@
 #include "forward_pass/composite/test_composite.h"
 #include "forward_pass/elementwise/test_exp.h"
 #include "forward_pass/elementwise/test_log.h"
+#include "forward_pass/test_prod_axis_zero.h"
 #include "jacobian_tests/test_broadcast.h"
 #include "jacobian_tests/test_composite.h"
 #include "jacobian_tests/test_const_scalar_mult.h"
@@ -25,6 +26,7 @@
 #include "jacobian_tests/test_log.h"
 #include "jacobian_tests/test_neg.h"
 #include "jacobian_tests/test_prod.h"
+#include "jacobian_tests/test_prod_axis_zero.h"
 #include "jacobian_tests/test_promote.h"
 #include "jacobian_tests/test_quad_form.h"
 #include "jacobian_tests/test_quad_over_lin.h"
@@ -53,6 +55,7 @@
 #include "wsum_hess/test_left_matmul.h"
 #include "wsum_hess/test_multiply.h"
 #include "wsum_hess/test_prod.h"
+#include "wsum_hess/test_prod_axis_zero.h"
 #include "wsum_hess/test_quad_form.h"
 #include "wsum_hess/test_quad_over_lin.h"
#include "wsum_hess/test_rel_entr.h" @@ -86,6 +89,7 @@ int main(void) mu_run_test(test_broadcast_row, tests_run); mu_run_test(test_broadcast_col, tests_run); mu_run_test(test_broadcast_matrix, tests_run); + mu_run_test(test_forward_prod_axis_zero, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -118,6 +122,7 @@ int main(void) mu_run_test(test_jacobian_prod_no_zero, tests_run); mu_run_test(test_jacobian_prod_one_zero, tests_run); mu_run_test(test_jacobian_prod_two_zeros, tests_run); + mu_run_test(test_jacobian_prod_axis_zero, tests_run); mu_run_test(test_jacobian_sum_log, tests_run); mu_run_test(test_jacobian_sum_mult, tests_run); mu_run_test(test_jacobian_sum_log_axis_0, tests_run); @@ -168,6 +173,9 @@ int main(void) mu_run_test(test_wsum_hess_prod_one_zero, tests_run); mu_run_test(test_wsum_hess_prod_two_zeros, tests_run); mu_run_test(test_wsum_hess_prod_many_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_no_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_one_zero, tests_run); + mu_run_test(test_wsum_hess_prod_axis_zero_mixed_zeros, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); diff --git a/tests/forward_pass/test_prod_axis_zero.h b/tests/forward_pass/test_prod_axis_zero.h new file mode 100644 index 0000000..6504502 --- /dev/null +++ b/tests/forward_pass/test_prod_axis_zero.h @@ -0,0 +1,33 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_forward_prod_axis_zero() +{ + /* Create a 2x3 constant matrix stored column-wise: + [1, 3, 5] + [2, 4, 6] + Stored as: [1, 2, 3, 4, 5, 6] + */ + double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + expr *const_node = new_constant(2, 3, 0, values); + expr *prod_node = new_prod_axis_zero(const_node); + 
prod_node->forward(prod_node, NULL); + + /* Expected: columnwise products, result is 1x3 + [1*2, 3*4, 5*6] = [2, 12, 30] + */ + double expected[3] = {2.0, 12.0, 30.0}; + + mu_assert("prod_axis_zero forward test failed", + cmp_double_array(prod_node->value, expected, 3)); + + free_expr(prod_node); + return 0; +} diff --git a/tests/jacobian_tests/test_prod_axis_zero.h b/tests/jacobian_tests/test_prod_axis_zero.h new file mode 100644 index 0000000..7b3179e --- /dev/null +++ b/tests/jacobian_tests/test_prod_axis_zero.h @@ -0,0 +1,45 @@ +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_jacobian_prod_axis_zero() +{ + /* x is 2x3 variable, global index 1, total 8 vars + * x = [1, 2, 3, 4, 5, 6] (column-major order) + * [1, 3, 5] + * [2, 4, 6] + * + * f = prod_axis_zero(x) = [2, 12, 30] + * + * Jacobian is 3x8: + * Row 0 (output[0] = 2): df[0]/dx = [2, 1, 0, 0, 0, 0, 0, 0] + * df[0]/dx[0] = 2/1 = 2, df[0]/dx[1] = 2/2 = 1 + * Row 1 (output[1] = 12): df[1]/dx = [0, 0, 4, 3, 0, 0, 0, 0] + * df[1]/dx[2] = 12/3 = 4, df[1]/dx[3] = 12/4 = 3 + * Row 2 (output[2] = 30): df[2]/dx = [0, 0, 0, 0, 6, 5, 0, 0] + * df[2]/dx[4] = 30/5 = 6, df[2]/dx[5] = 30/6 = 5 + */ + double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0}; + expr *x = new_variable(2, 3, 1, 8); + expr *p = new_prod_axis_zero(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x8 Jacobian with block diagonal structure */ + double expected_Ax[6] = {2.0, 1.0, 4.0, 3.0, 6.0, 5.0}; + int expected_Ap[4] = {0, 2, 4, 6}; + int expected_Ai[6] = {1, 2, 3, 4, 5, 6}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 6)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 6)); + + free_expr(p); + return 0; +} diff --git a/tests/wsum_hess/test_prod_axis_zero.h 
b/tests/wsum_hess/test_prod_axis_zero.h
new file mode 100644
index 0000000..e81f97d
--- /dev/null
+++ b/tests/wsum_hess/test_prod_axis_zero.h
@@ -0,0 +1,251 @@
+#include <stdio.h>
+#include <string.h>
+
+#include "expr.h"
+#include "minunit.h"
+#include "other.h"
+#include "test_helpers.h"
+
+const char *test_wsum_hess_prod_axis_zero_no_zeros()
+{
+    /* x is 2x3 variable, global index 1, total 8 vars
+     * x = [1, 2, 3, 4, 5, 6] (column-major)
+     *     [1, 3, 5]
+     *     [2, 4, 6]
+     * f = prod_axis_zero(x) = [2, 12, 30]
+     * w = [1, 2, 3]
+     *
+     * Hessian is block diagonal with three 2x2 blocks:
+     * Block 0 (col 0): w[0] * f[0] = 1 * 2 = 2
+     *   row 0 = [0, 2/(1*2)] = [0, 1]
+     *   row 1 = [2/(2*1), 0] = [1, 0]
+     * Block 1 (col 1): w[1] * f[1] = 2 * 12 = 24
+     *   row 0 = [0, 24/(3*4)] = [0, 2]
+     *   row 1 = [24/(4*3), 0] = [2, 0]
+     * Block 2 (col 2): w[2] * f[2] = 3 * 30 = 90
+     *   row 0 = [0, 90/(5*6)] = [0, 3]
+     *   row 1 = [90/(6*5), 0] = [3, 0]
+     */
+    double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0};
+    double w_vals[3] = {1.0, 2.0, 3.0};
+    expr *x = new_variable(2, 3, 1, 8);
+    expr *p = new_prod_axis_zero(x);
+
+    p->forward(p, u_vals);
+    p->wsum_hess_init(p);
+    p->eval_wsum_hess(p, w_vals);
+
+    /* Block diagonal structure: 3 blocks of 2x2 = 12 nnz total
+     * Each block has 4 entries (2x2 dense)
+     */
+    double expected_x[12] = {/* Block 0 */
+                             0.0, 1.0, 1.0, 0.0,
+                             /* Block 1 */
+                             0.0, 2.0, 2.0, 0.0,
+                             /* Block 2 */
+                             0.0, 3.0, 3.0, 0.0};
+
+    /* Row pointers: variable is at global rows 1-6, so:
+     * p[0] = 0 (row 0: before variable)
+     * p[1-6] = block diagonal entries for the variable
+     * p[7-8] = 12 (row 7 after the variable, plus the end pointer)
+     */
+    int expected_p[9] = {0, 0, 2, 4, 6, 8, 10, 12, 12};
+
+    /* Column indices: block diagonal structure (rows local to the variable)
+     * Row 0: cols 1, 2 (global indices for block 0)
+     * Row 1: cols 1, 2
+     * Row 2: cols 3, 4 (global indices for block 1)
+     * Row 3: cols 3, 4
+     * Row 4: cols 5, 6 (global indices for block 2)
+     * Row 5: cols 5, 6
+     */
+    int expected_i[12] = {1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6};
+
+    mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 12));
+    mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9));
+    mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 12));
+
+    free_expr(p);
+    return 0;
+}
+
+const char *test_wsum_hess_prod_axis_zero_mixed_zeros()
+{
+    /* x is 5x3 variable, global index 1, total 16 vars
+     * x = [1, 1, 1, 1, 1, 2, 0, 3, 4, 5, 1, 0, 0, 2, 3] (column-major)
+     * Matrix (column-major):
+     *     [1, 2, 1]
+     *     [1, 0, 0]
+     *     [1, 3, 0]
+     *     [1, 4, 2]
+     *     [1, 5, 3]
+     *
+     * f = prod_axis_zero(x) = [1, 0, 0] (products of nonzeros: [1, 120, 6])
+     * Column 0: 5 nonzeros, prod = 1*1*1*1*1 = 1
+     * Column 1: 1 zero at row 1, prod_nonzero = 2*3*4*5 = 120
+     * Column 2: 2 zeros at rows 1,2, prod_nonzero = 1*2*3 = 6
+     *
+     * w = [1, 2, 3]
+     *
+     * Block 0 (5x5): no zeros, diagonal=0, off-diagonal=w[0]*f[0]/(x[i]*x[j]) =
+     * 1/(1*1) = 1
+     * Block 1 (5x5): 1 zero at row 1, nonzeros only in row/col 1
+     *   (1,0): 2*120/2 = 120, (1,2): 2*120/3 = 80, (1,3): 2*120/4 = 60,
+     *   (1,4): 2*120/5 = 48
+     * Block 2 (5x5): zeros at rows 1,2; only (1,2) and (2,1) are nonzero: 3*6 = 18
+     */
+    double u_vals[16] = {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.0,
+                         3.0, 4.0, 5.0, 1.0, 0.0, 0.0, 2.0, 3.0};
+    /* This maps to: x[var_id+0..4]=column 0, x[var_id+5..9]=column 1,
+     * x[var_id+10..14]=column 2. Column 0: [1, 1, 1, 1, 1] - all nonzero.
+     * Column 1: [2, 0, 3, 4, 5] - one zero at index 1.
+     * Column 2: [1, 0, 0, 2, 3] - two zeros at indices 1, 2.
+     */
+    double w_vals[3] = {1.0, 2.0, 3.0};
+    expr *x = new_variable(5, 3, 1, 16);
+    expr *p = new_prod_axis_zero(x);
+
+    p->forward(p, u_vals);
+    p->wsum_hess_init(p);
+    p->eval_wsum_hess(p, w_vals);
+
+    /* Block 0: 5x5 all off-diagonal = 1.0, total = 25 entries (indices 0-24)
+     * Block 1: 5x5 with 1 zero, total = 25 entries (indices 25-49)
+     * Block 2: 5x5 with 2 zeros, total = 25 entries (indices 50-74)
+     * Total nnz = 75
+     */
+    double expected_x[75];
+    memset(expected_x, 0, sizeof(expected_x));
+
+    /* Block 0 (indices 0-24): all off-diagonal = 1.0 */
+    for (int i = 0; i < 5; i++)
+    {
+        for (int j = 0; j < 5; j++)
+        {
+            if (i != j)
+            {
+                expected_x[i * 5 + j] = 1.0;
+            }
+        }
+    }
+
+    /* Block 1 (indices 25-49): 1 zero at row p=1, w[1]*prod_nonzero[1] = 240.
+     * Column 1 is x[5..9] = [2, 0, 3, 4, 5] (rows 0-4). Only row p and column p
+     * have nonzeros: H[i,p] = H[p,i] = w_prod / x[i] for i != p.
+     * H[0,1] = H[1,0] = 240/2 = 120, H[2,1] = H[1,2] = 240/3 = 80,
+     * H[3,1] = H[1,3] = 240/4 = 60, H[4,1] = H[1,4] = 240/5 = 48
+     */
+    /* Row 0 of block 1: only (0,1) is nonzero */
+    expected_x[25 + 0 * 5 + 1] = 120.0;
+    /* Row 1 of block 1: (1,0), (1,2), (1,3), (1,4) are nonzero */
+    expected_x[25 + 1 * 5 + 0] = 120.0;
+    expected_x[25 + 1 * 5 + 2] = 80.0;
+    expected_x[25 + 1 * 5 + 3] = 60.0;
+    expected_x[25 + 1 * 5 + 4] = 48.0;
+    /* Row 2 of block 1: only (2,1) is nonzero */
+    expected_x[25 + 2 * 5 + 1] = 80.0;
+    /* Row 3 of block 1: only (3,1) is nonzero */
+    expected_x[25 + 3 * 5 + 1] = 60.0;
+    /* Row 4 of block 1: only (4,1) is nonzero */
+    expected_x[25 + 4 * 5 + 1] = 48.0;
+
+    /* Block 2 (indices 50-74): 2 zeros at rows p=1, q=2, w[2]*prod_nonzero[2] = 18.
+     * Column 2 is x[10..14] = [1, 0, 0, 2, 3] (rows 0-4). With exactly 2 zeros,
+     * only H[p,q] and H[q,p] are nonzero: H[1,2] = H[2,1] = 18
+     */
+    for (int i = 0; i < 5; i++)
+        for (int j = 0; j < 5; j++) expected_x[50 + i * 5 + j] = 0.0;
+    expected_x[50 + 1 * 5 + 2] = 18.0; /* H[1,2] */
+    expected_x[50 + 2 * 5 + 1] = 18.0; /* H[2,1] */
+
+    int expected_p[17] = {0,  0,  5,  10, 15, 20, 25, 30, 35,
+                          40, 45, 50, 55, 60, 65, 70, 75};
+
+    /* Column indices: block diagonal structure
+     * Block 0 (rows 1-5): each row has cols [1, 2, 3, 4, 5]
+     * Block 1 (rows 6-10): each row has cols [6, 7, 8, 9, 10]
+     * Block 2 (rows 11-15): each row has cols [11, 12, 13, 14, 15]
+     */
+    int expected_i[75];
+    for (int block = 0; block < 3; block++)
+    {
+        int offset = block * 25;
+        int col_start = block * 5 + 1;
+        for (int row = 0; row < 5; row++)
+        {
+            for (int col = 0; col < 5; col++)
+            {
+                expected_i[offset + row * 5 + col] = col_start + col;
+            }
+        }
+    }
+
+    mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 17));
+    mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 75));
+    mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 75));
+
+    free_expr(p);
+    return 0;
+}
+
+const char *test_wsum_hess_prod_axis_zero_one_zero()
+{
+    /* Test with a column that has exactly 1 zero
+     * x is 2x2 variable, global index 1, total 5 vars
+     * x = [1.0, 1.0, 2.0, 0.0] (column-major)
+     * Matrix (column-major):
+     *     [1, 2]
+     *     [1, 0]
+     *
+     * f = prod_axis_zero(x) = [1, 0]
+     * Column 0: no zeros, prod = 1
+     * Column 1: 1 zero at row 1, prod_nonzero = 2
+     * w = [1, 2]
+     */
+    double u_vals[5] = {0.0, 1.0, 1.0, 2.0, 0.0};
+    double w_vals[2] = {1.0, 2.0};
+    expr *x = new_variable(2, 2, 1, 5);
+    expr *p = new_prod_axis_zero(x);
+
+    p->forward(p, u_vals);
+    p->wsum_hess_init(p);
+    p->eval_wsum_hess(p, w_vals);
+
+    /* Block 0 (no zeros): w[0]*f[0] = 1
+     *   (0,1) = 1/(1*1) = 1, (1,0) = 1
+     * Block 1 (1 zero at row 1): w[1]*prod_nonzero[1] = 2*2 = 4
+     *   (1,0) = 4/2 = 2 (symmetric)
+     */
+    double expected_x[8];
+    memset(expected_x, 0, sizeof(expected_x));
+
+    /* Block 0 */
+    expected_x[0 * 2 + 1] = 1.0; /* (0,1) */
+    expected_x[1 * 2 + 0] = 1.0; /* (1,0) */
+
+    /* Block 1: one zero at row 1 */
+    expected_x[4 + 1 * 2 + 0] = 2.0; /* (1,0) */
+    expected_x[4 + 0 * 2 + 1] = 2.0; /* (0,1) symmetric */
+
+    /* Row pointers: variable is at global rows 1-4, so:
+     * p[0] = 0 (row 0: before variable)
+     * p[1-2] = block 0 (rows 1-2, each has 2 entries)
+     * p[3-4] = block 1 (rows 3-4, each has 2 entries)
+     * p[5] = 8 (end pointer of the last row)
+     */
+    int expected_p[6] = {0, 0, 2, 4, 6, 8};
+
+    /* Column indices: block diagonal structure
+     * Block 0: cols 1, 2 (global indices for block 0)
+     * Block 1: cols 3, 4 (global indices for block 1)
+     */
+    int expected_i[8] = {1, 2, 1, 2, 3, 4, 3, 4};
+
+    mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 8));
+    mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 6));
+    mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 8));
+
+    free_expr(p);
+    return 0;
+}