From 61fb579786fc3f0aa41b1c7fce7e8fbde8a7f583 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 7 Jan 2026 01:34:11 -0800 Subject: [PATCH 1/3] jacobian and hessian for prod --- include/other.h | 3 + include/subexpr.h | 9 + src/other/prod.c | 279 +++++++++++++++++++++++++++++++ tests/all_tests.c | 9 + tests/jacobian_tests/test_prod.h | 81 +++++++++ tests/wsum_hess/test_prod.h | 123 ++++++++++++++ 6 files changed, 504 insertions(+) create mode 100644 src/other/prod.c create mode 100644 tests/jacobian_tests/test_prod.h create mode 100644 tests/wsum_hess/test_prod.h diff --git a/include/other.h b/include/other.h index 95fb035..d44871a 100644 --- a/include/other.h +++ b/include/other.h @@ -7,4 +7,7 @@ expr *new_quad_form(expr *child, CSR_Matrix *Q); +/* product of all entries, without axis argument */ +expr *new_prod(expr *child); + #endif /* OTHER_H */ diff --git a/include/subexpr.h b/include/subexpr.h index 8c5661b..e6efbc3 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -40,6 +40,15 @@ typedef struct sum_expr struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ } sum_expr; +/* Product of all entries */ +typedef struct prod_expr +{ + expr base; + int num_of_zeros; + int zero_index; /* index of zero element when num_of_zeros == 1 */ + double prod_nonzero; /* product of non-zero elements when num_of_zeros == 1 */ +} prod_expr; + /* Horizontal stack (concatenate) */ typedef struct hstack_expr { diff --git a/src/other/prod.c b/src/other/prod.c new file mode 100644 index 0000000..52910a7 --- /dev/null +++ b/src/other/prod.c @@ -0,0 +1,279 @@ +#include "other.h" +#include +#include +#include +#include + /* NOTE(review): the four bare #include directives above carry no header name, probably lost together with their angle brackets during extraction. The code in this file calls assert, fabs, calloc and memset, so restore the matching standard headers before applying. */ +#define IS_ZERO(x) (fabs((x)) < 1e-8) /* an entry whose magnitude is below 1e-8 is counted as a zero */ + +static inline void wsum_hess_no_zeros(expr *node, const double *w); +static inline void wsum_hess_one_zero(expr *node, const double *w); +static inline void wsum_hess_two_zeros(expr *node, const double *w); +static inline void wsum_hess_many_zeros(expr *node, const double *w); + /* Forward pass: runs the child's forward pass, then writes prod(x) to node->value[0] and caches the number of zeros, a zero index and the product of the non-zero entries on the prod_expr. */ +static void 
forward(expr *node, const double *u) +{ + expr *x = node->left; + + /* forward pass of child */ + x->forward(x, u); + + /* local forward pass and count zeros */ + double prod_nonzero = 1.0; + int zeros = 0; + int zero_idx = -1; + for (int i = 0; i < x->size; i++) + { + if (IS_ZERO(x->value[i])) + { + zeros++; + zero_idx = i; /* with several zeros this keeps the last one seen */ + } + else + { + prod_nonzero *= x->value[i]; + } + } + + node->value[0] = (zeros > 0) ? 0.0 : prod_nonzero; /* any entry counted as zero makes the stored product exactly 0.0 */ + prod_expr *pnode = (prod_expr *) node; + pnode->num_of_zeros = zeros; + pnode->zero_index = zero_idx; + pnode->prod_nonzero = prod_nonzero; +} + /* Sets up the Jacobian pattern: one row with x->size entries in columns x->var_id + j. Only a variable child is supported; any other child trips the assert. */ +static void jacobian_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's jacobian */ + x->jacobian_init(x); + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + node->jacobian = new_csr_matrix(1, node->n_vars, x->size); + node->jacobian->p[0] = 0; + node->jacobian->p[1] = x->size; + for (int j = 0; j < x->size; j++) + { + node->jacobian->i[j] = x->var_id + j; + } + } + else + { + assert(false && "not implemented"); + } +} + /* Fills the Jacobian values from the state cached by forward(): prod / x_j with no zeros, prod_nonzero at the zero's slot with one zero, all zeros otherwise. */ +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + prod_expr *pnode = (prod_expr *) node; + int num_of_zeros = pnode->num_of_zeros; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + if (num_of_zeros == 0) + { + for (int j = 0; j < x->size; j++) + { + node->jacobian->x[j] = node->value[0] / x->value[j]; /* division is only done in the branch where no entry was counted as zero */ + } + } + else if (num_of_zeros == 1) + { + memset(node->jacobian->x, 0, sizeof(double) * x->size); + node->jacobian->x[pnode->zero_index] = pnode->prod_nonzero; + } + else + { + memset(node->jacobian->x, 0, sizeof(double) * x->size); + } + } + else + { + assert(false && "not implemented"); + } +} + /* Sets up the weighted-sum Hessian pattern: a dense x->size by x->size block on rows and columns x->var_id .. x->var_id + x->size - 1. NOTE(review): row pointers for rows before x->var_id are never written here; presumably new_csr_matrix returns them zeroed - confirm. */ +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + /* allocate n_vars x n_vars CSR matrix with dense block */ + int block_size = x->size; + int nnz = block_size * block_size; + node->wsum_hess = new_csr_matrix(node->n_vars, 
node->n_vars, nnz); + + /* fill row pointers for the dense block */ + for (int i = 0; i < block_size; i++) + { + node->wsum_hess->p[x->var_id + i] = i * block_size; + } + + /* fill row pointers for rows after the block */ + for (int i = x->var_id + block_size; i <= node->n_vars; i++) + { + node->wsum_hess->p[i] = nnz; + } + + /* fill column indices for the dense block */ + for (int i = 0; i < block_size; i++) + { + for (int j = 0; j < block_size; j++) + { + node->wsum_hess->i[i * block_size + j] = x->var_id + j; + } + } + } + else + { + /* not implemented for non-variable case */ + } +} + +static void eval_wsum_hess(expr *node, const double *w) +{ /* Picks the Hessian helper that matches the number of zero entries recorded on the prod_expr; the helpers read only w[0]. */ + expr *x = node->left; + int num_of_zeros = ((prod_expr *) node)->num_of_zeros; + + /* if x is a variable */ + if (x->var_id != NOT_A_VARIABLE) + { + if (num_of_zeros == 0) + { + wsum_hess_no_zeros(node, w); + } + else if (num_of_zeros == 1) + { + wsum_hess_one_zero(node, w); + } + else if (num_of_zeros == 2) + { + wsum_hess_two_zeros(node, w); + } + else + { + wsum_hess_many_zeros(node, w); + } + } + else + { + assert(false && "not implemented"); + } +} + /* Always reports false for a prod node. */ +static bool is_affine(const expr *node) +{ + (void) node; + return false; +} + /* No-op: nothing is released here. */ +static void free_type_data(expr *node) +{ + (void) node; +} + /* Builds a 1 x 1 prod node over child: registers the callbacks defined in this file, stores child as node->left and retains it. Returns the embedded base expr. */ +expr *new_prod(expr *child) +{ + /* Output is scalar: 1 x 1 */ + prod_expr *pnode = (prod_expr *) calloc(1, sizeof(prod_expr)); /* NOTE(review): the calloc result is used without a NULL check */ + expr *node = &pnode->base; + init_expr(node, 1, 1, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; + node->left = child; + expr_retain(child); + + pnode->num_of_zeros = 0; + + return node; +} + +// --------------------------------------------------------------------------------------- +// Helper functions for Hessian evaluation +// --------------------------------------------------------------------------------------- +static inline void wsum_hess_no_zeros(expr *node, const double *w) 
+{ /* No zero entries: entry (i, j) with i != j is w[0] * prod / (x_i * x_j); the diagonal is 0. */ + double *x = node->left->value; + int n = node->left->size; + double wf = w[0] * node->value[0]; + + for (int i = 0; i < n; i++) + { + for (int j = 0; j < n; j++) + { + if (i == j) + { + node->wsum_hess->x[i * n + j] = 0.0; + } + else + { + node->wsum_hess->x[i * n + j] = wf / (x[i] * x[j]); + } + } + } +} + /* Exactly one zero, at index p: after clearing the block, only row p and column p are filled, with w[0] * prod_nonzero / x_j at (p, j) and (j, p). */ +static inline void wsum_hess_one_zero(expr *node, const double *w) +{ + expr *x = node->left; + double *H = node->wsum_hess->x; + memset(H, 0, sizeof(double) * (x->size * x->size)); + int p = ((prod_expr *) node)->zero_index; + double prod_nonzero = ((prod_expr *) node)->prod_nonzero; + double w_prod = w[0] * prod_nonzero; + + /* fill row p and column p */ + for (int j = 0; j < x->size; j++) + { + if (j == p) continue; + + double hess_val = w_prod / x->value[j]; + H[p * x->size + j] = hess_val; + H[j * x->size + p] = hess_val; + } +} + /* Exactly two zeros, at p and q: after clearing the block, only (p, q) and (q, p) are set, both to w[0] * prod_nonzero. */ +static inline void wsum_hess_two_zeros(expr *node, const double *w) +{ + expr *x = node->left; + int n = x->size; + memset(node->wsum_hess->x, 0, sizeof(double) * (n * n)); + + /* find indices p and q where x[p] = x[q] = 0 */ + int p = -1, q = -1; + for (int i = 0; i < n; i++) + { + if (IS_ZERO(x->value[i])) + { + if (p == -1) + { + p = i; + } + else + { + q = i; + break; + } + } + } + assert(p != -1 && q != -1); + + double hess_val = w[0] * ((prod_expr *) node)->prod_nonzero; + node->wsum_hess->x[p * n + q] = hess_val; + node->wsum_hess->x[q * n + p] = hess_val; +} + /* Reached from eval_wsum_hess when more than two entries are zero: the whole block is cleared and w is unused. */ +static inline void wsum_hess_many_zeros(expr *node, const double *w) +{ + expr *x = node->left; + memset(node->wsum_hess->x, 0, sizeof(double) * (x->size * x->size)); + (void) w; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 16bb1ec..2801647 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -15,6 +15,7 @@ #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" #include "jacobian_tests/test_log.h" +#include "jacobian_tests/test_prod.h" #include "jacobian_tests/test_quad_form.h" #include 
"jacobian_tests/test_quad_over_lin.h" #include "jacobian_tests/test_rel_entr.h" @@ -31,6 +32,7 @@ #include "wsum_hess/elementwise/test_xexp.h" #include "wsum_hess/test_hstack.h" #include "wsum_hess/test_multiply.h" +#include "wsum_hess/test_prod.h" #include "wsum_hess/test_quad_form.h" #include "wsum_hess/test_quad_over_lin.h" #include "wsum_hess/test_rel_entr.h" @@ -75,6 +77,9 @@ int main(void) mu_run_test(test_quad_form, tests_run); /* commented out - see test_quad_form.h */ // mu_run_test(test_quad_form2, tests_run); + mu_run_test(test_jacobian_prod_no_zero, tests_run); + mu_run_test(test_jacobian_prod_one_zero, tests_run); + mu_run_test(test_jacobian_prod_two_zeros, tests_run); mu_run_test(test_jacobian_sum_log, tests_run); mu_run_test(test_jacobian_sum_mult, tests_run); mu_run_test(test_jacobian_sum_log_axis_0, tests_run); @@ -103,6 +108,10 @@ int main(void) mu_run_test(test_wsum_hess_sum_log_linear, tests_run); mu_run_test(test_wsum_hess_sum_log_axis0, tests_run); mu_run_test(test_wsum_hess_sum_log_axis1, tests_run); + mu_run_test(test_wsum_hess_prod_no_zero, tests_run); + mu_run_test(test_wsum_hess_prod_one_zero, tests_run); + mu_run_test(test_wsum_hess_prod_two_zeros, tests_run); + mu_run_test(test_wsum_hess_prod_many_zeros, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_hstack, tests_run); diff --git a/tests/jacobian_tests/test_prod.h b/tests/jacobian_tests/test_prod.h new file mode 100644 index 0000000..8fbfe99 --- /dev/null +++ b/tests/jacobian_tests/test_prod.h @@ -0,0 +1,81 @@ +#include + /* NOTE(review): the #include above has no header name, probably lost in extraction; restore it before applying. */ +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +const char *test_jacobian_prod_no_zero() +{ + /* x is 4x1 variable, global index 2, total 8 vars + * x = [1, 2, 3, 4] + * f = prod(x) = 24 + * df/dx = [24, 12, 8, 6] + */ + double u_vals[8] = {0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0}; + expr *x = new_variable(4, 1, 2, 8); + expr *p = 
new_prod(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + double expected_Ax[4] = {24.0, 12.0, 8.0, 6.0}; + int expected_Ap[2] = {0, 4}; + int expected_Ai[4] = {2, 3, 4, 5}; /* columns var_id + j, with var_id = 2 */ + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 4)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 2)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 4)); + + free_expr(p); /* x is never freed directly in this test; presumably free_expr(p) also drops the child retained by new_prod - confirm */ + return 0; +} + +const char *test_jacobian_prod_one_zero() +{ + /* x = [1, 0, 3, 4], zero at index 1 + * df/dx = [0, prod_nonzero, 0, 0] = [0, 12, 0, 0] + */ + double u_vals[8] = {0.0, 0.0, 1.0, 0.0, 3.0, 4.0, 0.0, 0.0}; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + double expected_Ax[4] = {0.0, 12.0, 0.0, 0.0}; + int expected_Ap[2] = {0, 4}; + int expected_Ai[4] = {2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 4)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 2)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 4)); + + free_expr(p); + return 0; +} + +const char *test_jacobian_prod_two_zeros() +{ + /* x = [1, 0, 0, 4], two zeros -> Jacobian all zeros */ + double u_vals[8] = {0.0, 0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 0.0}; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->jacobian_init(p); + p->eval_jacobian(p); + + double expected_Ax[4] = {0.0, 0.0, 0.0, 0.0}; + int expected_Ap[2] = {0, 4}; + int expected_Ai[4] = {2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->jacobian->x, expected_Ax, 4)); + mu_assert("rows fail", cmp_int_array(p->jacobian->p, expected_Ap, 2)); + mu_assert("cols fail", cmp_int_array(p->jacobian->i, expected_Ai, 4)); + + free_expr(p); + return 0; +} diff --git a/tests/wsum_hess/test_prod.h b/tests/wsum_hess/test_prod.h new file mode 100644 index 0000000..33a41b0 --- /dev/null +++ 
b/tests/wsum_hess/test_prod.h @@ -0,0 +1,123 @@ +#include +#include + /* NOTE(review): the two #include directives above have no header names, probably lost in extraction; the tests below call memset, so restore the headers before applying. */ +#include "expr.h" +#include "minunit.h" +#include "other.h" +#include "test_helpers.h" + +/* Common setup: x is 4x1 variable, global index 2, total 8 vars */ + +const char *test_wsum_hess_prod_no_zero() +{ + /* x = [1, 2, 3, 4], f = 24 */ + double u_vals[8] = {0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0}; + double w = 1.0; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, &w); + + /* Row-major over dense 4x4 block */ + double expected_x[16] = {0.0, 12.0, 8.0, 6.0, 12.0, 0.0, 4.0, 3.0, + 8.0, 4.0, 0.0, 2.0, 6.0, 3.0, 2.0, 0.0}; + + int expected_p[9] = {0, 0, 0, 4, 8, 12, 16, 16, 16}; /* rows 0-1 and 6-7 are empty; rows 2-5 hold 4 entries each */ + int expected_i[16] = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 16)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 16)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_one_zero() +{ + /* x = [1, 0, 3, 4], zero at index 1, prod_nonzero = 12 */ + double u_vals[8] = {0.0, 0.0, 1.0, 0.0, 3.0, 4.0, 0.0, 0.0}; + double w = 1.0; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, &w); + + double expected_x[16]; + memset(expected_x, 0, sizeof(expected_x)); + /* Row 1 / col 1 nonzeros */ + expected_x[1 * 4 + 0] = 12.0; /* (1,0) */ + expected_x[1 * 4 + 2] = 4.0; /* (1,2) */ + expected_x[1 * 4 + 3] = 3.0; /* (1,3) */ + expected_x[0 * 4 + 1] = 12.0; /* symmetric */ + expected_x[2 * 4 + 1] = 4.0; /* (2,1) */ + expected_x[3 * 4 + 1] = 3.0; /* (3,1) */ + + int expected_p[9] = {0, 0, 0, 4, 8, 12, 16, 16, 16}; + int expected_i[16] = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 16)); + mu_assert("rows fail", 
cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 16)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_two_zeros() +{ + /* x = [1, 0, 0, 4], zeros at 1 and 2, prod over others = 4 */ + double u_vals[8] = {0.0, 0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 0.0}; + double w = 1.0; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, &w); + + double expected_x[16]; + memset(expected_x, 0, sizeof(expected_x)); + expected_x[1 * 4 + 2] = 4.0; /* (1,2) */ + expected_x[2 * 4 + 1] = 4.0; /* (2,1) */ + + int expected_p[9] = {0, 0, 0, 4, 8, 12, 16, 16, 16}; + int expected_i[16] = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 16)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 16)); + + free_expr(p); + return 0; +} + +const char *test_wsum_hess_prod_many_zeros() +{ + /* x = [0, 0, 0, 4], three zeros => Hessian all zeros */ + double u_vals[8] = {0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0}; + double w = 1.0; + expr *x = new_variable(4, 1, 2, 8); + expr *p = new_prod(x); + + p->forward(p, u_vals); + p->wsum_hess_init(p); + p->eval_wsum_hess(p, &w); + + double expected_x[16]; + memset(expected_x, 0, sizeof(expected_x)); + + int expected_p[9] = {0, 0, 0, 4, 8, 12, 16, 16, 16}; + int expected_i[16] = {2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5}; + + mu_assert("vals fail", cmp_double_array(p->wsum_hess->x, expected_x, 16)); + mu_assert("rows fail", cmp_int_array(p->wsum_hess->p, expected_p, 9)); + mu_assert("cols fail", cmp_int_array(p->wsum_hess->i, expected_i, 16)); + + free_expr(p); + return 0; +} From 08da38b7106cbf80220e9f992e60ae31dd481cbe Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 7 Jan 2026 01:36:07 -0800 Subject: [PATCH 2/3] added not implemented 
assertion --- src/other/prod.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/other/prod.c b/src/other/prod.c index 52910a7..c065c42 100644 --- a/src/other/prod.c +++ b/src/other/prod.c @@ -133,7 +133,7 @@ static void wsum_hess_init(expr *node) } else { - /* not implemented for non-variable case */ + assert(false && "not implemented"); } } From d824b83d907c6bef14843727d7c86b512e4c4647 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 7 Jan 2026 01:41:32 -0800 Subject: [PATCH 3/3] minor --- include/subexpr.h | 2 +- src/other/prod.c | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index e6efbc3..4b29416 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -46,7 +46,7 @@ typedef struct prod_expr expr base; int num_of_zeros; int zero_index; /* index of zero element when num_of_zeros == 1 */ - double prod_nonzero; /* product of non-zero elements when num_of_zeros == 1 */ + double prod_nonzero; /* product of non-zero elements */ } prod_expr; /* Horizontal stack (concatenate) */ diff --git a/src/other/prod.c b/src/other/prod.c index c065c42..92fbe13 100644 --- a/src/other/prod.c +++ b/src/other/prod.c @@ -190,9 +190,6 @@ expr *new_prod(expr *child) node->eval_wsum_hess = eval_wsum_hess; node->left = child; expr_retain(child); - - pnode->num_of_zeros = 0; - return node; }