diff --git a/include/other.h b/include/other.h index 8bf1d53..e523552 100644 --- a/include/other.h +++ b/include/other.h @@ -13,4 +13,7 @@ expr *new_prod(expr *child); /* product of entries along axis=0 (columnwise products) */ expr *new_prod_axis_zero(expr *child); +/* product of entries along axis=1 (rowwise products) */ +expr *new_prod_axis_one(expr *child); + #endif /* OTHER_H */ diff --git a/include/subexpr.h b/include/subexpr.h index e96c40f..d8e56a5 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -56,14 +56,15 @@ typedef struct prod_expr double prod_nonzero; /* product of non-zero elements */ } prod_expr; -/* Product of entries along axis=0 (columnwise products) */ -typedef struct prod_axis_zero_expr +/* Product of entries along axis=0 (columnwise products) or axis = 1 (rowwise + * products) */ +typedef struct prod_axis { expr base; - int *num_of_zeros; /* num of zeros for each column */ - int *zero_index; /* stores idx of zero element per column */ - double *prod_nonzero; /* product of non-zero elements per column */ -} prod_axis_zero_expr; + int *num_of_zeros; /* num of zeros for each column / row depending on the axis*/ + int *zero_index; /* stores idx of zero element per column / row */ + double *prod_nonzero; /* product of non-zero elements per column / row */ +} prod_axis; /* Horizontal stack (concatenate) */ typedef struct hstack_expr diff --git a/python/atoms/prod_axis_one.h b/python/atoms/prod_axis_one.h new file mode 100644 index 0000000..d9c3b8a --- /dev/null +++ b/python/atoms/prod_axis_one.h @@ -0,0 +1,32 @@ +#ifndef ATOM_PROD_AXIS_ONE_H +#define ATOM_PROD_AXIS_ONE_H + +#include "common.h" +#include "other.h" + +static PyObject *py_make_prod_axis_one(PyObject *self, PyObject *args) +{ + (void) self; + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_prod_axis_one(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create prod_axis_one node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_PROD_AXIS_ONE_H */ diff --git a/python/bindings.c b/python/bindings.c index d83b271..0cd703e 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -23,6 +23,7 @@ #include "atoms/neg.h" #include "atoms/power.h" #include "atoms/prod.h" +#include "atoms/prod_axis_one.h" #include "atoms/prod_axis_zero.h" #include "atoms/promote.h" #include "atoms/quad_form.h" @@ -80,6 +81,8 @@ static PyMethodDef DNLPMethods[] = { {"make_prod", py_make_prod, METH_VARARGS, "Create prod node"}, {"make_prod_axis_zero", py_make_prod_axis_zero, METH_VARARGS, "Create prod_axis_zero node"}, + {"make_prod_axis_one", py_make_prod_axis_one, METH_VARARGS, + "Create prod_axis_one node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c new file mode 100644 index 0000000..600a20a --- /dev/null +++ b/src/other/prod_axis_one.c @@ -0,0 +1,390 @@ +#include "other.h" +#include +#include +#include +#include +#include + +#define IS_ZERO(x) (fabs((x)) < 1e-8) + +static void forward(expr *node, const double *u) +{ + /* forward pass of child */ + expr *x = node->left; + x->forward(x, u); + + prod_axis *pnode = (prod_axis *) node; + int d1 = x->d1; + int d2 = x->d2; + + /* initialize per-row statistics */ + for (int i = 0; i < d1; i++) + { + pnode->num_of_zeros[i] = 0; + pnode->zero_index[i] = -1; + pnode->prod_nonzero[i] = 1.0; + } + + /* iterate over columns */ + for (int col = 0; col < d2; col++) + { + int start = col * d1; + int end = start + d1; + + for (int idx = start; idx < end; idx++) + { + int row = idx - start; + if (IS_ZERO(x->value[idx])) + { + pnode->num_of_zeros[row]++; + pnode->zero_index[row] = col; + } + else + { + pnode->prod_nonzero[row] *= x->value[idx]; + } + } + } + + /* compute output values */ + for (int i = 0; i < d1; i++) + { + node->value[i] = (pnode->num_of_zeros[i] > 0) ? 0.0 : pnode->prod_nonzero[i]; + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's jacobian */ + x->jacobian_init(x); + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + node->jacobian = new_csr_matrix(node->size, node->n_vars, x->size); + + /* set row pointers (each row has d2 nnzs) */ + for (int row = 0; row < x->d1; row++) + { + node->jacobian->p[row] = row * x->d2; + } + node->jacobian->p[x->d1] = x->size; + + /* set column indices */ + for (int row = 0; row < x->d1; row++) + { + int start = row * x->d2; + for (int col = 0; col < x->d2; col++) + { + node->jacobian->i[start + col] = x->var_id + col * x->d1 + row; + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + prod_axis *pnode = (prod_axis *) node; + + double *J_vals = node->jacobian->x; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* process each row */ + for (int row = 0; row < x->d1; row++) + { + int num_zeros = pnode->num_of_zeros[row]; + int start = row * x->d2; + + if (num_zeros == 0) + { + for (int col = 0; col < x->d2; col++) + { + int idx = col * x->d1 + row; + J_vals[start + col] = node->value[row] / x->value[idx]; + } + } + else if (num_zeros == 1) + { + memset(J_vals + start, 0, sizeof(double) * x->d2); + J_vals[start + pnode->zero_index[row]] = pnode->prod_nonzero[row]; + } + else + { + memset(J_vals + start, 0, sizeof(double) * x->d2); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* each row i has d2-1 non-zero entries, with column indices corresponding to + the columns in that row (except the diagonal element). */ + int nnz = x->d1 * x->d2 * (x->d2 - 1); + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); + CSR_Matrix *H = node->wsum_hess; + + /* fill sparsity pattern */ + int nnz_per_row = x->d2 - 1; + for (int i = 0; i < x->size; i++) + { + int row = i % x->d1; + int col = i / x->d1; + int start = i * nnz_per_row; + H->p[x->var_id + i] = start; + + /* for this variable (at X[row, col]), the Hessian has entries with all + * columns in the same row, excluding the diagonal (col itself) */ + int offset = 0; + int col_offset = x->var_id + row; + for (int j = 0; j < x->d2; j++) + { + if (j != col) + { + H->i[start + offset] = col_offset + j * x->d1; + offset++; + } + } + } + + /* fill row pointers for rows after the variable */ + for (int i = x->var_id + x->size; i <= node->n_vars; i++) + { + H->p[i] = nnz; + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static inline void wsum_hess_row_no_zeros(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + CSR_Matrix *H = node->wsum_hess; + double scale = w[row] * node->value[row]; + + /* for each variable xk in this row, fill in Hessian entries + // TODO: more cache-friendly solution is possible? */ + for (int i = 0; i < d2; i++) + { + int k = x->var_id + row + i * x->d1; + int idx_k = i * x->d1 + row; + int offset = 0; + for (int j = 0; j < d2; j++) + { + if (j != i) + { + int idx_j = j * x->d1 + row; + H->x[H->p[k] + offset] = scale / (x->value[idx_k] * x->value[idx_j]); + offset++; + } + } + } +} + +static inline void wsum_hess_row_one_zero(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + prod_axis *pnode = (prod_axis *) node; + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; + int p = pnode->zero_index[row]; /* zero column index */ + double w_prod = w[row] * pnode->prod_nonzero[row]; + + /* For each variable in this row */ + for (int i = 0; i < d2; i++) + { + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + + /* Row var_i has d2-1 entries (excluding diagonal) */ + int offset = 0; + for (int j = 0; j < d2; j++) + { + if (j != i) + { + if (i == p && j != p) + { + /* row p (zero row), column j (nonzero) */ + int idx_j = j * x->d1 + row; + H_vals[row_start + offset] = w_prod / x->value[idx_j]; + } + else if (j == p && i != p) + { + /* row i (nonzero), column p (zero) */ + int idx_i = i * x->d1 + row; + H_vals[row_start + offset] = w_prod / x->value[idx_i]; + } + else + { + /* all other entries are zero */ + H_vals[row_start + offset] = 0.0; + } + offset++; + } + } + } +} + +static inline void wsum_hess_row_two_zeros(expr *node, const double *w, int row, + int d2) +{ + expr *x = node->left; + prod_axis *pnode = (prod_axis *) node; + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; + + /* find indices p and q where row has zeros */ + int p = -1, q = -1; + for (int c = 0; c < d2; c++) + { + if (IS_ZERO(x->value[c * node->left->d1 + row])) + { + if (p == -1) + { + p = c; + } + else + { + q = c; + break; + } + } + } + assert(p != -1 && q != -1); + + double hess_val = w[row] * pnode->prod_nonzero[row]; + + /* For each variable in this row */ + for (int i = 0; i < d2; i++) + { + int var_i = x->var_id + row + i * x->d1; + int row_start = H->p[var_i]; + + /* row var_i has d2-1 entries (excluding diagonal) */ + int offset = 0; + for (int j = 0; j < d2; j++) + { + if (j != i) + { + /* only (p,q) and (q,p) are nonzero */ + if ((i == p && j == q) || (i == q && j == p)) + { + H_vals[row_start + offset] = hess_val; + } + else + { + H_vals[row_start + offset] = 0.0; + } + offset++; + } + } + } +} + +static inline void wsum_hess_row_many_zeros(expr *node, int row, int d2) +{ + CSR_Matrix *H = node->wsum_hess; + double *H_vals = H->x; + expr *x = node->left; + + /* for each variable in this row, zero out all entries */ + for (int i = 0; i < d2; i++) + { + int var_i = x->var_id + row + i * x->d1; + memset(H_vals + H->p[var_i], 0, sizeof(double) * (d2 - 1)); + } +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + prod_axis *pnode = (prod_axis *) node; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + for (int row = 0; row < x->d1; row++) + { + int num_zeros = pnode->num_of_zeros[row]; + + if (num_zeros == 0) + { + wsum_hess_row_no_zeros(node, w, row, x->d2); + } + else if (num_zeros == 1) + { + wsum_hess_row_one_zero(node, w, row, x->d2); + } + else if (num_zeros == 2) + { + wsum_hess_row_two_zeros(node, w, row, x->d2); + } + else + { + wsum_hess_row_many_zeros(node, row, x->d2); + } + } + } + else + { + assert(false && "child must be a variable"); + } +} + +static bool is_affine(const expr *node) +{ + (void) node; + return false; +} + +static void free_type_data(expr *node) +{ + prod_axis *pnode = (prod_axis *) node; + free(pnode->num_of_zeros); + free(pnode->zero_index); + free(pnode->prod_nonzero); +} + +expr *new_prod_axis_one(expr *child) +{ + prod_axis *pnode = (prod_axis *) calloc(1, sizeof(prod_axis)); + expr *node = &pnode->base; + + /* output is always a row vector 1 x d1 (one product per row) */ + init_expr(node, 1, child->d1, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, + free_type_data); + + /* allocate arrays to store per-row statistics */ + pnode->num_of_zeros = (int *) calloc(child->d1, sizeof(int)); + pnode->zero_index = (int *) calloc(child->d1, sizeof(int)); + pnode->prod_nonzero = (double *) calloc(child->d1, sizeof(double)); + + node->left = child; + expr_retain(child); + + return node; +} diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c index 8d1807b..191f6ff 100644 --- a/src/other/prod_axis_zero.c +++ b/src/other/prod_axis_zero.c @@ -12,7 +12,7 @@ static void forward(expr *node, const double *u) expr *x = node->left; x->forward(x, u); - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; int d1 = x->d1; int d2 = x->d2; @@ -85,7 +85,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *J_vals = node->jacobian->x; @@ -175,8 +175,7 @@ static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int co double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; - double w_col = w[col]; - double f_col = node->value[col]; + double scale = w[col] * node->value[col]; for (int i = 0; i < d1; i++) { @@ -191,7 +190,7 @@ static inline void wsum_hess_column_no_zeros(expr *node, const double *w, int co int idx_i = col_start + i; int idx_j = col_start + j; H[block_start + i * d1 + j] = - w_col * f_col / (x->value[idx_i] * x->value[idx_j]); + scale / (x->value[idx_i] * x->value[idx_j]); } } } @@ -201,7 +200,7 @@ static inline void wsum_hess_column_one_zero(expr *node, const double *w, int co int d1) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; @@ -229,7 +228,7 @@ static inline void wsum_hess_column_two_zeros(expr *node, const double *w, int c int d1) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; double *H = node->wsum_hess->x; int col_start = col * d1; int block_start = col * d1 * d1; @@ -275,7 +274,7 @@ static inline void wsum_hess_column_many_zeros(expr *node, const double *w, int static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; int d1 = x->d1; int d2 = x->d2; @@ -319,16 +318,16 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - prod_axis_zero_expr *pnode = (prod_axis_zero_expr *) node; + prod_axis *pnode = (prod_axis *) node; free(pnode->num_of_zeros); free(pnode->zero_index); free(pnode->prod_nonzero); } +/* TODO: refactor to remove diagonal entry as nonzero since it's always zero */ expr *new_prod_axis_zero(expr *child) { - prod_axis_zero_expr *pnode = - (prod_axis_zero_expr *) calloc(1, sizeof(prod_axis_zero_expr)); + prod_axis *pnode = (prod_axis *) calloc(1, sizeof(prod_axis)); expr *node = &pnode->base; /* output is always a row vector 1 x d2 - TODO: is that correct? */ diff --git a/tests/all_tests.c b/tests/all_tests.c index 164382b..5de1b00 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -14,6 +14,7 @@ #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" +#include "forward_pass/test_prod_axis_one.h" #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" #include "jacobian_tests/test_composite.h" @@ -26,6 +27,7 @@ #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_prod.h" +#include "jacobian_tests/test_prod_axis_one.h" #include "jacobian_tests/test_prod_axis_zero.h" #include "jacobian_tests/test_promote.h" #include "jacobian_tests/test_quad_form.h" @@ -55,6 +57,7 @@ #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" +#include "wsum_hess/test_prod_axis_one.h" #include "wsum_hess/test_prod_axis_zero.h" #include "wsum_hess/test_quad_form.h" #include "wsum_hess/test_quad_over_lin.h" @@ -90,6 +93,7 @@ int main(void) mu_run_test(test_broadcast_col, tests_run); mu_run_test(test_broadcast_matrix, tests_run); mu_run_test(test_forward_prod_axis_zero, tests_run); + mu_run_test(test_forward_prod_axis_one, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -123,6 +127,8 @@ int main(void) mu_run_test(test_jacobian_prod_one_zero, tests_run); mu_run_test(test_jacobian_prod_two_zeros, tests_run); mu_run_test(test_jacobian_prod_axis_zero, tests_run); + mu_run_test(test_jacobian_prod_axis_one, tests_run); + mu_run_test(test_jacobian_prod_axis_one_one_zero, tests_run); mu_run_test(test_jacobian_sum_log, tests_run); mu_run_test(test_jacobian_sum_mult, tests_run); mu_run_test(test_jacobian_sum_log_axis_0, tests_run); @@ -176,6 +182,10 @@ int main(void) mu_run_test(test_wsum_hess_prod_axis_zero_no_zeros, tests_run); mu_run_test(test_wsum_hess_prod_axis_zero_one_zero, tests_run); mu_run_test(test_wsum_hess_prod_axis_zero_mixed_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_no_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_one_zero, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_mixed_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_axis_one_2x2, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); diff --git a/tests/forward_pass/test_prod_axis_one.h b/tests/forward_pass/test_prod_axis_one.h new file mode 100644 index 0000000..6f4b3bb --- /dev/null +++ b/tests/forward_pass/test_prod_axis_one.h @@ -0,0 +1,35 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_forward_prod_axis_one() +{ + /* Create a 2x3 constant matrix stored column-wise: + [1, 3, 5] + [2, 4, 6] + Stored as: [1, 2, 3, 4, 5, 6] + */ + double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + expr *const_node = new_constant(2, 3, 0, values); + expr *prod_node = new_prod_axis_one(const_node); + prod_node->forward(prod_node, NULL); + + /* Expected: rowwise products, result is 1x2 (row vector) + Row 0: 1*3*5 = 15 + Row 1: 2*4*6 = 48 + Result: [15, 48] + */ + double expected[2] = {15.0, 48.0}; + + mu_assert("prod_axis_one forward test failed", + cmp_double_array(prod_node->value, expected, 2)); + + free_expr(prod_node); + return 0; +} diff --git a/tests/jacobian_tests/test_prod_axis_one.h b/tests/jacobian_tests/test_prod_axis_one.h new file mode 100644 index 0000000..77993fb --- /dev/null +++ b/tests/jacobian_tests/test_prod_axis_one.h @@ -0,0 +1,94 @@ +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_jacobian_prod_axis_one() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 5, 6, 7, 8, 9] (column-major order) + * [1, 4, 7] + * [2, 5, 8] + * [3, 6, 9] + * + * f = prod_axis_one(x) = [28, 80, 162] (row vector) + * Row 0: 1*4*7 = 28 + * Row 1: 2*5*8 = 80 + * Row 2: 3*6*9 = 162 + * + * Jacobian is 3x10 (3 outputs, 10 total vars): + * Row 0 (output[0] = 28): df[0]/dx = [0, 28, 0, 0, 7, 0, 0, 4, 0, 0] + * df[0]/dx[0] = 28/1 = 28, df[0]/dx[3] = 28/4 = 7, df[0]/dx[6] = 28/7 = 4 + * Row 1 (output[1] = 80): df[1]/dx = [0, 0, 40, 0, 0, 16, 0, 0, 10, 0] + * df[1]/dx[1] = 80/2 = 40, df[1]/dx[4] = 80/5 = 16, df[1]/dx[7] = 80/8 = 10 + * Row 2 (output[2] = 162): df[2]/dx = [0, 0, 0, 54, 0, 0, 27, 0, 0, 18] + * df[2]/dx[2] = 162/3 = 54, df[2]/dx[5] = 162/6 = 27, df[2]/dx[8] = 162/9 = 18 + * + * (Note: global indices are offset by var_id=1) + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x10 Jacobian with row-strided structure */ + double expected_Ax[9] = {28.0, 7.0, 4.0, 40.0, 16.0, 10.0, 54.0, 27.0, 18.0}; + int expected_Ap[4] = {0, 3, 6, 9}; + int expected_Ai[9] = {1, 4, 7, 2, 5, 8, 3, 6, 9}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 9)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 9)); + + free_expr(p); + return 0; +} + +const char *test_jacobian_prod_axis_one_one_zero() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 0, 6, 7, 8, 9] (column-major order) + * [1, 4, 7] + * [2, 0, 8] + * [3, 6, 9] + * + * f = prod_axis_one(x) = [28, 0, 162] (row vector) + * Row 0: 1*4*7 = 28 + * Row 1: 2*0*8 = 0 (one zero at column 1) + * Row 2: 3*6*9 = 162 + * + * Jacobian is 3x10: + * Row 0 (output[0] = 28, no zeros): df[0]/dx = [0, 28, 0, 0, 7, 0, 0, 4, 0, 0] + * df[0]/dx[0] = 28/1 = 28, df[0]/dx[3] = 28/4 = 7, df[0]/dx[6] = 28/7 = 4 + * Row 1 (output[1] = 0, one zero): df[1]/dx = [0, 0, 0, 0, 0, 16, 0, 0, 0, 0] + * Only derivative w.r.t. zero element: df[1]/dx[4] = prod_nonzero = 2*8 = 16 + * Row 2 (output[2] = 162, no zeros): df[2]/dx = [0, 0, 0, 54, 0, 0, 27, 0, 0, + * 18] df[2]/dx[2] = 162/3 = 54, df[2]/dx[5] = 162/6 = 27, df[2]/dx[8] = 162/9 = + * 18 + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0, 9.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + /* CSR format for 3x10 Jacobian with row-strided structure */ + double expected_Ax[9] = {28.0, 7.0, 4.0, 0.0, 16.0, 0.0, 54.0, 27.0, 18.0}; + int expected_Ap[4] = {0, 3, 6, 9}; + int expected_Ai[9] = {1, 4, 7, 2, 5, 8, 3, 6, 9}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 9)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 4)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 9)); + + free_expr(p); + return 0; +} diff --git a/tests/wsum_hess/test_prod_axis_one.h b/tests/wsum_hess/test_prod_axis_one.h new file mode 100644 index 0000000..18524e9 --- /dev/null +++ b/tests/wsum_hess/test_prod_axis_one.h @@ -0,0 +1,378 @@ +#include +#include + +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_wsum_hess_prod_axis_one_no_zeros() +{ + /* x is 2x3 variable, global index 1, total 8 vars + * x = [1, 2, 3, 4, 5, 6] (column-major) + * [1, 3, 5] + * [2, 4, 6] + * f = prod_axis_one(x) = [15, 48] + * w = [1, 2] + * + * Hessian is block diagonal with two 3x3 blocks (one per row): + * Row 0 block (scale = 1 * 15 = 15): + * cols (c0=1, c1=3, c2=5) + * off-diagonals: (0,1)=5, (0,2)=3, (1,0)=5, (1,2)=1, (2,0)=3, (2,1)=1; diag=0 + * Row 1 block (scale = 2 * 48 = 96): + * cols (c0=2, c1=4, c2=6) + * off-diagonals: (0,1)=12, (0,2)=8, (1,0)=12, (1,2)=4, (2,0)=8, (2,1)=4; + * diag=0 + */ + double u_vals[8] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0}; + double w_vals[2] = {1.0, 2.0}; + expr *x = new_variable(2, 3, 1, 8); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[12] = {/* Var 1 (row 0, col 0): [5, 3] (excludes col 0) */ + 5.0, 3.0, + /* Var 2 (row 1, col 0): [12, 8] (excludes col 0) */ + 12.0, 8.0, + /* Var 3 (row 0, col 1): [5, 1] (excludes col 1) */ + 5.0, 1.0, + /* Var 4 (row 1, col 1): [12, 4] (excludes col 1) */ + 12.0, 4.0, + /* Var 5 (row 0, col 2): [3, 1] (excludes col 2) */ + 3.0, 1.0, + /* Var 6 (row 1, col 2): [8, 4] (excludes col 2) */ + 8.0, 4.0}; + + /* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[9] = {0, 0, 2, 4, 6, 8, 10, 12, 12}; + + /* Column indices (each row of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[12] = {/* Var 1 (row 0, col 0): cols 3,5 (excludes 1) */ + 3, 5, + /* Var 2 (row 1, col 0): cols 4,6 (excludes 2) */ + 4, 6, + /* Var 3 (row 0, col 1): cols 1,5 (excludes 3) */ + 1, 5, + /* Var 4 (row 1, col 1): cols 2,6 (excludes 4) */ + 2, 6, + /* Var 5 (row 0, col 2): cols 1,3 (excludes 5) */ + 1, 3, + /* Var 6 (row 1, col 2): cols 2,4 (excludes 6) */ + 2, 4}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 12)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 12)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_one_one_zero() +{ + /* x is 3x3 variable, global index 1, total 10 vars + * x = [1, 2, 3, 4, 0, 6, 7, 8, 9] (column-major) + * [1, 4, 7] + * [2, 0, 8] + * [3, 6, 9] + * f = prod_axis_one(x) = [28, 0, 162] + * w = [1, 2, 3] + * + * Blocks are 3x3 (one per row): + * Row 0 (no zeros, scale=1*28=28): + * cols (1,4,7): off-diag (0,1)=7, (0,2)=4, (1,0)=7, (1,2)=2, (2,0)=4, (2,1)=2 + * Row 1 (one zero at col 1, prod_nonzero=16, scale=2*16=32): + * only row/col 1 nonzero: (1,0)=32/2=16, (1,2)=32/8=4 (symmetric) + * Row 2 (no zeros, scale=3*162=486): + * cols (3,6,9): off-diag (0,1)=27, (0,2)=18, (1,0)=27, (1,2)=9, (2,0)=18, + * (2,1)=9 + */ + double u_vals[10] = {0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0, 9.0}; + double w_vals[3] = {1.0, 2.0, 3.0}; + expr *x = new_variable(3, 3, 1, 10); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[18]; + memset(expected_x, 0, sizeof(expected_x)); + + /* Var 1 (row 0, col 0): [7, 4] (excludes col 0) */ + expected_x[0] = 7.0; + expected_x[1] = 4.0; + + /* Var 2 (row 1, col 0): [16, 0] (excludes col 0, one zero at col 1) */ + expected_x[2] = 16.0; + expected_x[3] = 0.0; + + /* Var 3 (row 2, col 0): [27, 18] (excludes col 0) */ + expected_x[4] = 27.0; + expected_x[5] = 18.0; + + /* Var 4 (row 0, col 1): [7, 1] (excludes col 1) */ + expected_x[6] = 7.0; + expected_x[7] = 1.0; + + /* Var 5 (row 1, col 1): [16, 4] (excludes col 1, one zero at col 1) */ + expected_x[8] = 16.0; + expected_x[9] = 4.0; + + /* Var 6 (row 2, col 1): [27, 9] (excludes col 1) */ + expected_x[10] = 27.0; + expected_x[11] = 9.0; + + /* Var 7 (row 0, col 2): [4, 1] (excludes col 2) */ + expected_x[12] = 4.0; + expected_x[13] = 1.0; + + /* Var 8 (row 1, col 2): [0, 4] (excludes col 2, one zero at col 1) */ + expected_x[14] = 0.0; + expected_x[15] = 4.0; + + /* Var 9 (row 2, col 2): [18, 9] (excludes col 2) */ + expected_x[16] = 18.0; + expected_x[17] = 9.0; + + /* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[11] = {0, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18}; + + /* Column indices (each row of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[18] = {/* Var 1 (row 0, col 0): cols 4,7 (excludes 1) */ + 4, 7, + /* Var 2 (row 1, col 0): cols 5,8 (excludes 2) */ + 5, 8, + /* Var 3 (row 2, col 0): cols 6,9 (excludes 3) */ + 6, 9, + /* Var 4 (row 0, col 1): cols 1,7 (excludes 4) */ + 1, 7, + /* Var 5 (row 1, col 1): cols 2,8 (excludes 5) */ + 2, 8, + /* Var 6 (row 2, col 1): cols 3,9 (excludes 6) */ + 3, 9, + /* Var 7 (row 0, col 2): cols 1,4 (excludes 7) */ + 1, 4, + /* Var 8 (row 1, col 2): cols 2,5 (excludes 8) */ + 2, 5, + /* Var 9 (row 2, col 2): cols 3,6 (excludes 9) */ + 3, 6}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 18)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 11)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 18)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_axis_one_mixed_zeros() +{ + /* x is 5x3 variable, global index 1, total 16 vars + * Rows (axis=1 products): + * r0: [1, 2, 1] -> no zeros, prod=2 + * r1: [1, 0, 0] -> two zeros (cols 1,2), prod_nonzero=1 + * r2: [1, 3, 0] -> one zero (col 2), prod_nonzero=3 + * r3: [1, 4, 2] -> no zeros, prod=8 + * r4: [1, 5, 3] -> no zeros, prod=15 + * w = [1, 2, 3, 4, 5] + * Blocks are 3x3 (one per row, block diagonal): + */ + double u_vals[16] = {0.0, + /* col 0 (rows 0-4): 1,1,1,1,1 */ + 1.0, 1.0, 1.0, 1.0, 1.0, + /* col 1: 2,0,3,4,5 */ + 2.0, 0.0, 3.0, 4.0, 5.0, + /* col 2: 1,0,0,2,3 */ + 1.0, 0.0, 0.0, 2.0, 3.0}; + /* Actually store column-major: + * col0: [1,1,1,1,1] + * col1: [2,0,3,4,5] + * col2: [1,0,0,2,3] + */ + double w_vals[5] = {1.0, 2.0, 3.0, 4.0, 5.0}; + expr *x = new_variable(5, 3, 1, 16); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + double expected_x[30]; + memset(expected_x, 0, sizeof(expected_x)); + + /* For a 5x3 matrix with var_id=1, each row has 2 nnz (d2-1): + * CSR row pointers: p[i] = (i-1)*2 for i in [1,15] + * Var 1 (matrix [0,0]): p[1]=0 + * Var 2 (matrix [1,0]): p[2]=2 + * Var 3 (matrix [2,0]): p[3]=4 + * Var 4 (matrix [3,0]): p[4]=6 + * Var 5 (matrix [4,0]): p[5]=8 + * Var 6 (matrix [0,1]): p[6]=10 + * Var 7 (matrix [1,1]): p[7]=12 + * Var 8 (matrix [2,1]): p[8]=14 + * Var 9 (matrix [3,1]): p[9]=16 + * Var 10 (matrix [4,1]): p[10]=18 + * Var 11 (matrix [0,2]): p[11]=20 + * Var 12 (matrix [1,2]): p[12]=22 + * Var 13 (matrix [2,2]): p[13]=24 + * Var 14 (matrix [3,2]): p[14]=26 + * Var 15 (matrix [4,2]): p[15]=28 + */ + + /* Row 0 block (no zeros, scale = 1 * 2 = 2), x = [1,2,1] */ + /* Var 1 (matrix [0,0], col 0): [1, 2] (excludes col 0) */ + expected_x[0] = 1.0; /* 2/(1*2) */ + expected_x[1] = 2.0; /* 2/(1*1) */ + + /* Var 6 (matrix [0,1], col 1): [1, 1] (excludes col 1) */ + expected_x[10] = 1.0; /* 2/(2*1) */ + expected_x[11] = 1.0; /* 2/(2*1) */ + + /* Var 11 (matrix [0,2], col 2): [2, 1] (excludes col 2) */ + expected_x[20] = 2.0; /* 2/(1*1) */ + expected_x[21] = 1.0; /* 2/(1*2) */ + + /* Row 1 block (two zeros at cols 1,2), hess = w*prod_nonzero = 2*1 = 2 */ + /* Var 2 (matrix [1,0], col 0): [0, 0] (excludes col 0) */ + expected_x[2] = 0.0; + expected_x[3] = 0.0; + + /* Var 7 (matrix [1,1], col 1): [0, 2] (excludes col 1) */ + expected_x[12] = 0.0; + expected_x[13] = 2.0; /* (1,2) */ + + /* Var 12 (matrix [1,2], col 2): [0, 2] (excludes col 2) */ + expected_x[22] = 0.0; + expected_x[23] = 2.0; /* (2,1) */ + + /* Row 2 block (one zero at col 2), w_prod = 3 * 3 = 9, x = [1,3,0] */ + /* Var 3 (matrix [2,0], col 0): [0, 9] (excludes col 0) */ + expected_x[4] = 0.0; + expected_x[5] = 9.0; /* (0,2): w_prod/x[0] */ + + /* Var 8 (matrix [2,1], col 1): [0, 3] (excludes col 1) */ + expected_x[14] = 0.0; + expected_x[15] = 3.0; /* (1,2): w_prod/x[1] */ + + /* Var 13 (matrix [2,2], col 2): [9, 3] (excludes col 2) */ + expected_x[24] = 9.0; /* (2,0): w_prod/x[0] */ + expected_x[25] = 3.0; /* (2,1): w_prod/x[1] */ + + /* Row 3 block (no zeros, scale = 4 * 8 = 32), x = [1,4,2] */ + /* Var 4 (matrix [3,0], col 0): [8, 16] (excludes col 0) */ + expected_x[6] = 8.0; /* 32/(1*4) */ + expected_x[7] = 16.0; /* 32/(1*2) */ + + /* Var 9 (matrix [3,1], col 1): [8, 4] (excludes col 1) */ + expected_x[16] = 8.0; /* 32/(4*1) */ + expected_x[17] = 4.0; /* 32/(4*2) */ + + /* Var 14 (matrix [3,2], col 2): [16, 4] (excludes col 2) */ + expected_x[26] = 16.0; /* 32/(2*1) */ + expected_x[27] = 4.0; /* 32/(2*4) */ + + /* Row 4 block (no zeros, scale = 5 * 15 = 75), x = [1,5,3] */ + /* Var 5 (matrix [4,0], col 0): [15, 25] (excludes col 0) */ + expected_x[8] = 15.0; /* 75/(1*5) */ + expected_x[9] = 25.0; /* 75/(1*3) */ + + /* Var 10 (matrix [4,1], col 1): [15, 5] (excludes col 1) */ + expected_x[18] = 15.0; /* 75/(5*1) */ + expected_x[19] = 5.0; /* 75/(5*3) */ + + /* Var 15 (matrix [4,2], col 2): [25, 5] (excludes col 2) */ + expected_x[28] = 25.0; /* 75/(3*1) */ + expected_x[29] = 5.0; /* 75/(3*5) */ + + /* Row pointers (monotonically increasing for valid CSR format) */ + int expected_p[17] = {0, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + + /* Column indices (each row of the matrix interacts with its own columns, + * excluding diagonal) */ + int expected_i[30]; + for (int var_idx = 0; var_idx < 15; var_idx++) + { + int matrix_row = var_idx % 5; /* which row of the 5x3 matrix */ + int col_i = var_idx / 5; /* which column this variable is in */ + int nnz_start = var_idx * 2; + /* All columns from matrix_row, excluding col_i */ + int offset = 0; + for (int j = 0; j < 3; j++) + { + if (j != col_i) + { + expected_i[nnz_start + offset] = 1 + matrix_row + j * 5; + offset++; + } + } + } + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 30)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 17)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 30)); + + free_expr(p); + return 0; +} +const char *test_wsum_hess_prod_axis_one_2x2() +{ + /* x is 2x2 variable, global index 0, total 4 vars + * x = [2, 1, 3, 2] (column-major) + * [2, 3] + * [1, 2] + * f = prod_axis_one(x) = [2*3, 1*2] = [6, 2] + * w = [1, 1] + * + * For a 2x2 matrix with var_id=0: + * Var 0 (matrix [0,0]): p[0]=0, stores columns [0,1] + * Var 1 (matrix [1,0]): p[1]=2, stores columns [0,1] + * Var 2 (matrix [0,1]): p[2]=4, stores columns [0,1] + * Var 3 (matrix [1,1]): p[3]=6, stores columns [0,1] + * + * Row 0 Hessian (no zeros, x=[2,3], scale=1*6=6): + * (0,0)=0, (0,1)=1 -> stored at Var 0: [0, 1] + * (1,0)=1, (1,1)=0 -> stored at Var 2: [1, 0] + * + * Row 1 Hessian (no zeros, x=[1,2], scale=1*2=2): + * (0,0)=0, (0,1)=1 -> stored at Var 1: [0, 1] + * (1,0)=1, (1,1)=0 -> stored at Var 3: [1, 0] + */ + double u_vals[4] = {2.0, 1.0, 3.0, 2.0}; + double w_vals[2] = {1.0, 1.0}; + expr *x = new_variable(2, 2, 0, 4); + expr *p = new_prod_axis_one(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, w_vals); + + /* Expected sparse structure (nnz = 4, each row has 1 nnz) */ + double expected_x[4] = {1.0, /* Var 0 (excludes col 0) */ + 1.0, /* Var 1 (excludes col 0) */ + 1.0, /* Var 2 (excludes col 1) */ + 1.0}; /* Var 3 (excludes col 1) */ + + /* Row pointers (each row has 1 nnz) */ + int expected_p[5] = {0, 1, 2, 3, 4}; + + /* Column indices (each variable stores columns for its matrix row, excluding + * diagonal) */ + int expected_i[4] = {2, /* Var 0 (row 0, col 0): only col 1 */ + 3, /* Var 1 (row 1, col 0): only col 1 */ + 0, /* Var 2 (row 0, col 1): only col 0 */ + 1}; /* Var 3 (row 1, col 1): only col 0 */ + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 4)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 5)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 4)); + + free_expr(p); + return 0; +}