From af1f9a31dc889206d506c0b2312bc35a00a90c7e Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 3 Jan 2026 23:31:03 -0800 Subject: [PATCH 1/5] relative entropy hessian --- src/bivariate/rel_entr.c | 96 +++++++++++++++++++ tests/all_tests.c | 19 ++-- tests/wsum_hess/{ => elementwise}/test_entr.h | 0 tests/wsum_hess/{ => elementwise}/test_exp.h | 0 .../{ => elementwise}/test_hyperbolic.h | 0 tests/wsum_hess/{ => elementwise}/test_log.h | 0 .../{ => elementwise}/test_logistic.h | 0 .../wsum_hess/{ => elementwise}/test_power.h | 0 tests/wsum_hess/{ => elementwise}/test_trig.h | 0 tests/wsum_hess/{ => elementwise}/test_xexp.h | 0 tests/wsum_hess/test_rel_entr.h | 85 ++++++++++++++++ 11 files changed, 192 insertions(+), 8 deletions(-) rename tests/wsum_hess/{ => elementwise}/test_entr.h (100%) rename tests/wsum_hess/{ => elementwise}/test_exp.h (100%) rename tests/wsum_hess/{ => elementwise}/test_hyperbolic.h (100%) rename tests/wsum_hess/{ => elementwise}/test_log.h (100%) rename tests/wsum_hess/{ => elementwise}/test_logistic.h (100%) rename tests/wsum_hess/{ => elementwise}/test_power.h (100%) rename tests/wsum_hess/{ => elementwise}/test_trig.h (100%) rename tests/wsum_hess/{ => elementwise}/test_xexp.h (100%) create mode 100644 tests/wsum_hess/test_rel_entr.h diff --git a/src/bivariate/rel_entr.c b/src/bivariate/rel_entr.c index 70ae0f1..bddf777 100644 --- a/src/bivariate/rel_entr.c +++ b/src/bivariate/rel_entr.c @@ -1,4 +1,5 @@ #include "bivariate.h" +#include #include #include @@ -28,6 +29,8 @@ static void jacobian_init_vectors_args(expr *node) expr *x = node->left; expr *y = node->right; + assert(x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE); + assert(x->var_id != y->var_id); /* if x has lower variable idx than y it should appear first */ if (x->var_id < y->var_id) @@ -76,6 +79,97 @@ static void eval_jacobian_vector_args(expr *node) } } +static void wsum_hess_init_vector_args(expr *node) +{ + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1); + expr *x = node->left; + expr *y = node->right; + + int i, var1_id, var2_id; + + if (x->var_id < y->var_id) + { + var1_id = x->var_id; + var2_id = y->var_id; + } + else + { + var1_id = y->var_id; + var2_id = x->var_id; + } + + /* var1 rows of Hessian */ + for (i = 0; i < node->d1; i++) + { + node->wsum_hess->p[var1_id + i] = 2 * i; + node->wsum_hess->i[2 * i] = var1_id + i; + node->wsum_hess->i[2 * i + 1] = var2_id + i; + } + + int nnz = 2 * node->d1; + + /* rows between var1 and var2 */ + for (i = var1_id + node->d1; i < var2_id; i++) + { + node->wsum_hess->p[i] = nnz; + } + + /* var2 rows of Hessian */ + for (i = 0; i < node->d1; i++) + { + node->wsum_hess->p[var2_id + i] = nnz + 2 * i; + } + memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int)); + + /* remaining rows */ + for (i = var2_id + node->d1; i <= node->n_vars; i++) + { + node->wsum_hess->p[i] = 4 * node->d1; + } +} + +static void eval_wsum_hess_vector_args(expr *node, const double *w) +{ + + int i, x_id, y_id; + double *x = node->left->value; + double *y = node->right->value; + double *hess = node->wsum_hess->x; + + if (node->left->var_id < node->right->var_id) + { + for (i = 0; i < node->d1; i++) + { + hess[2 * i] = w[i] / x[i]; + hess[2 * i + 1] = -w[i] / y[i]; + } + + hess += 2 * node->d1; + + for (i = 0; i < node->d1; i++) + { + hess[2 * i] = -w[i] / y[i]; + hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]); + } + } + else + { + for (i = 0; i < node->d1; i++) + { + hess[2 * i] = w[i] * x[i] / (y[i] * y[i]); + hess[2 * i + 1] = -w[i] / y[i]; + } + + hess += 2 * node->d1; + + for (i = 0; i < node->d1; i++) + { + hess[2 * i] = -w[i] / y[i]; + hess[2 * i + 1] = w[i] / x[i]; + } + } +} + expr *new_rel_entr_vector_args(expr *left, expr *right) { expr *node = new_expr(left->d1, 1, left->n_vars); @@ -86,6 +180,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right) node->forward = forward_vector_args; node->jacobian_init = jacobian_init_vectors_args; node->eval_jacobian = eval_jacobian_vector_args; + node->wsum_hess_init = wsum_hess_init_vector_args; + node->eval_wsum_hess = eval_wsum_hess_vector_args; // node->is_affine = is_affine_elementwise; // node->local_jacobian = local_jacobian; return node; diff --git a/tests/all_tests.c b/tests/all_tests.c index 856aa77..d01eb5f 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -21,15 +21,16 @@ #include "jacobian_tests/test_sum.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" -#include "wsum_hess/test_entr.h" -#include "wsum_hess/test_exp.h" -#include "wsum_hess/test_hyperbolic.h" -#include "wsum_hess/test_log.h" -#include "wsum_hess/test_logistic.h" -#include "wsum_hess/test_power.h" +#include "wsum_hess/elementwise/test_entr.h" +#include "wsum_hess/elementwise/test_exp.h" +#include "wsum_hess/elementwise/test_hyperbolic.h" +#include "wsum_hess/elementwise/test_log.h" +#include "wsum_hess/elementwise/test_logistic.h" +#include "wsum_hess/elementwise/test_power.h" +#include "wsum_hess/elementwise/test_trig.h" +#include "wsum_hess/elementwise/test_xexp.h" +#include "wsum_hess/test_rel_entr.h" #include "wsum_hess/test_sum.h" -#include "wsum_hess/test_trig.h" -#include "wsum_hess/test_xexp.h" int main(void) { @@ -95,6 +96,8 @@ int main(void) mu_run_test(test_wsum_hess_sum_log_linear, tests_run); mu_run_test(test_wsum_hess_sum_log_axis0, tests_run); mu_run_test(test_wsum_hess_sum_log_axis1, tests_run); + mu_run_test(test_wsum_hess_rel_entr_1, tests_run); + mu_run_test(test_wsum_hess_rel_entr_2, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_diag_csr_mult, tests_run); diff --git a/tests/wsum_hess/test_entr.h b/tests/wsum_hess/elementwise/test_entr.h similarity index 100% rename from tests/wsum_hess/test_entr.h rename to tests/wsum_hess/elementwise/test_entr.h diff --git a/tests/wsum_hess/test_exp.h b/tests/wsum_hess/elementwise/test_exp.h similarity index 100% rename from tests/wsum_hess/test_exp.h rename to tests/wsum_hess/elementwise/test_exp.h diff --git a/tests/wsum_hess/test_hyperbolic.h b/tests/wsum_hess/elementwise/test_hyperbolic.h similarity index 100% rename from tests/wsum_hess/test_hyperbolic.h rename to tests/wsum_hess/elementwise/test_hyperbolic.h diff --git a/tests/wsum_hess/test_log.h b/tests/wsum_hess/elementwise/test_log.h similarity index 100% rename from tests/wsum_hess/test_log.h rename to tests/wsum_hess/elementwise/test_log.h diff --git a/tests/wsum_hess/test_logistic.h b/tests/wsum_hess/elementwise/test_logistic.h similarity index 100% rename from tests/wsum_hess/test_logistic.h rename to tests/wsum_hess/elementwise/test_logistic.h diff --git a/tests/wsum_hess/test_power.h b/tests/wsum_hess/elementwise/test_power.h similarity index 100% rename from tests/wsum_hess/test_power.h rename to tests/wsum_hess/elementwise/test_power.h diff --git a/tests/wsum_hess/test_trig.h b/tests/wsum_hess/elementwise/test_trig.h similarity index 100% rename from tests/wsum_hess/test_trig.h rename to tests/wsum_hess/elementwise/test_trig.h diff --git a/tests/wsum_hess/test_xexp.h b/tests/wsum_hess/elementwise/test_xexp.h similarity index 100% rename from tests/wsum_hess/test_xexp.h rename to tests/wsum_hess/elementwise/test_xexp.h diff --git a/tests/wsum_hess/test_rel_entr.h b/tests/wsum_hess/test_rel_entr.h new file mode 100644 index 0000000..4934806 --- /dev/null +++ b/tests/wsum_hess/test_rel_entr.h @@ -0,0 +1,85 @@ +#include "affine.h" +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_wsum_hess_rel_entr_1() +{ + // Total 10 variables: [?, x0, x1, x2, ?, ?, y0, y1, y2, ?] + // x has var_id = 1, y has var_id = 6 + // x = [1, 2, 3], y = [4, 5, 6] + // w = [1, 2, 3] + + double u_vals[10] = {0, 1.0, 2.0, 3.0, 0, 0, 4.0, 5.0, 6.0, 0}; + double w[3] = {1.0, 2.0, 3.0}; + + expr *x = new_variable(3, 1, 1, 10); + expr *y = new_variable(3, 1, 6, 10); + expr *node = new_rel_entr_vector_args(x, y); + + node->forward(node, u_vals); + node->wsum_hess_init(node); + node->eval_wsum_hess(node, w); + + int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12}; + int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8}; + double expected_x[12] = { + 1.0, -0.25, // row 1 + 1.0, -0.4, // row 2 + 1.0, -0.5, // row 3 + -0.25, 0.0625, // row 6 + -0.4, 0.16, // row 7 + -0.5, 0.25 // row 8 + }; + + mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11)); + mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12)); + mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 12)); + + free_expr(node); + free_expr(x); + free_expr(y); + return 0; +} + +const char *test_wsum_hess_rel_entr_2() +{ + // Total 10 variables: [?, y0, y1, y2, ?, ?, x0, x1, x2, ?] + // y has var_id = 1, x has var_id = 6 + // x = [1, 2, 3], y = [4, 5, 6] + // w = [1, 2, 3] + + double u_vals[10] = {0, 4.0, 5.0, 6.0, 0, 0, 1.0, 2.0, 3.0, 0}; + double w[3] = {1.0, 2.0, 3.0}; + + expr *x = new_variable(3, 1, 6, 10); + expr *y = new_variable(3, 1, 1, 10); + expr *node = new_rel_entr_vector_args(x, y); + + node->forward(node, u_vals); + node->wsum_hess_init(node); + node->eval_wsum_hess(node, w); + + int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12}; + int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8}; + double expected_x[12] = { + 0.0625, -0.25, // row 1 + 0.16, -0.4, // row 2 + 0.25, -0.5, // row 3 + -0.25, 1.0, // row 6 + -0.4, 1.0, // row 7 + -0.5, 1.0 // row 8 + }; + + mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11)); + mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12)); + mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 12)); + + free_expr(node); + free_expr(x); + free_expr(y); + return 0; +} From 62327cd12d643a658f1ab60a9505e7da667792a2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 3 Jan 2026 23:35:35 -0800 Subject: [PATCH 2/5] removed unused variable --- src/bivariate/rel_entr.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/bivariate/rel_entr.c b/src/bivariate/rel_entr.c index bddf777..9a5b667 100644 --- a/src/bivariate/rel_entr.c +++ b/src/bivariate/rel_entr.c @@ -130,15 +130,13 @@ static void wsum_hess_init_vector_args(expr *node) static void eval_wsum_hess_vector_args(expr *node, const double *w) { - - int i, x_id, y_id; double *x = node->left->value; double *y = node->right->value; double *hess = node->wsum_hess->x; if (node->left->var_id < node->right->var_id) { - for (i = 0; i < node->d1; i++) + for (int i = 0; i < node->d1; i++) { hess[2 * i] = w[i] / x[i]; hess[2 * i + 1] = -w[i] / y[i]; @@ -146,7 +144,7 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) hess += 2 * node->d1; - for (i = 0; i < node->d1; i++) + for (int i = 0; i < node->d1; i++) { hess[2 * i] = -w[i] / y[i]; hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]); @@ -154,7 +152,7 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) } else { - for (i = 0; i < node->d1; i++) + for (int i = 0; i < node->d1; i++) { hess[2 * i] = w[i] * x[i] / (y[i] * y[i]); hess[2 * i + 1] = -w[i] / y[i]; @@ -162,7 +160,7 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) hess += 2 * node->d1; - for (i = 0; i < node->d1; i++) + for (int i = 0; i < node->d1; i++) { hess[2 * i] = -w[i] / y[i]; hess[2 * i + 1] = w[i] / x[i]; From 315c0a1dc2af26dc537c33a1cc80377b189f74aa Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 3 Jan 2026 23:38:19 -0800 Subject: [PATCH 3/5] minor --- src/bivariate/multiply.c | 2 +- src/bivariate/rel_entr.c | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bivariate/multiply.c b/src/bivariate/multiply.c index eae64ed..a7f81c1 100644 --- a/src/bivariate/multiply.c +++ b/src/bivariate/multiply.c @@ -27,7 +27,7 @@ static void jacobian_init(expr *node) { /* if a child is a variable we initialize its jacobian for a short chain rule implementation */ - if (node->left->var_id != -1) + if (node->left->var_id != NOT_A_VARIABLE) { node->left->jacobian_init(node->left); } diff --git a/src/bivariate/rel_entr.c b/src/bivariate/rel_entr.c index 9a5b667..59f385d 100644 --- a/src/bivariate/rel_entr.c +++ b/src/bivariate/rel_entr.c @@ -2,6 +2,7 @@ #include #include #include +#include // -------------------------------------------------------------------- // Implementation of relative entropy when both arguments are vectors. From 2cd8f348cff841556fe32b9e204416f3f67c6751 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 4 Jan 2026 00:22:38 -0800 Subject: [PATCH 4/5] added hessian for hstack --- include/subexpr.h | 1 + include/utils/CSR_Matrix.h | 3 +- src/affine/hstack.c | 38 +++++- src/utils/CSR_Matrix.c | 11 +- tests/all_tests.c | 3 + tests/wsum_hess/test_hstack.h | 225 ++++++++++++++++++++++++++++++++++ 6 files changed, 275 insertions(+), 6 deletions(-) create mode 100644 tests/wsum_hess/test_hstack.h diff --git a/include/subexpr.h b/include/subexpr.h index 4f56ea4..033fd7f 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -46,6 +46,7 @@ typedef struct hstack_expr expr base; expr **args; int n_args; + CSR_Matrix *CSR_work; /* for summing Hessians of children */ } hstack_expr; #endif /* SUBEXPR_H */ diff --git a/include/utils/CSR_Matrix.h b/include/utils/CSR_Matrix.h index beeb7bd..89f83f2 100644 --- a/include/utils/CSR_Matrix.h +++ b/include/utils/CSR_Matrix.h @@ -55,7 +55,8 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C); /* Compute C = A + B where A, B, C are CSR matrices * A and B must have same dimensions - * C must be pre-allocated with sufficient nnz capacity */ + * C must be pre-allocated with sufficient nnz capacity. + * C must be different from A and B */ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C); /* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */ diff --git a/src/affine/hstack.c b/src/affine/hstack.c index a00863c..548806a 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -37,7 +37,7 @@ static void jacobian_init(expr *node) node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); - /* precompute sparsity pattern of this jacobian's node */ + /* precompute sparsity pattern of this node's jacobian */ int row_offset = 0; CSR_Matrix *A = node->jacobian; A->nnz = 0; @@ -80,6 +80,40 @@ static void eval_jacobian(expr *node) } } +static void wsum_hess_init(expr *node) +{ + /* initialize children's hessians */ + hstack_expr *hnode = (hstack_expr *) node; + int nnz = 0; + for (int i = 0; i < hnode->n_args; i++) + { + hnode->args[i]->wsum_hess_init(hnode->args[i]); + nnz += hnode->args[i]->wsum_hess->nnz; + } + + /* worst-case scenario the nnz of node->wsum_hess is the sum of children's + nnz */ + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz); + hnode->CSR_work = new_csr_matrix(node->n_vars, node->n_vars, nnz); +} + +static void wsum_hess_eval(expr *node, const double *w) +{ + hstack_expr *hnode = (hstack_expr *) node; + CSR_Matrix *H = node->wsum_hess; + int row_offset = 0; + H->nnz = 0; + + for (int i = 0; i < hnode->n_args; i++) + { + expr *child = hnode->args[i]; + child->eval_wsum_hess(child, w + row_offset); + copy_csr_matrix(H, hnode->CSR_work); + sum_csr_matrices(hnode->CSR_work, child->wsum_hess, H); + row_offset += child->size; + } +} + static bool is_affine(const expr *node) { const hstack_expr *hnode = (const hstack_expr *) node; @@ -121,6 +155,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars) /* Set type-specific fields */ hnode->args = args; hnode->n_args = n_args; + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = wsum_hess_eval; for (int i = 0; i < n_args; i++) { diff --git a/src/utils/CSR_Matrix.c b/src/utils/CSR_Matrix.c index cbf7ce9..620180f 100644 --- a/src/utils/CSR_Matrix.c +++ b/src/utils/CSR_Matrix.c @@ -99,6 +99,9 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C) void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) { + /* A and B must be different from C */ + assert(A != C && B != C); + C->nnz = 0; for (int row = 0; row < A->m; row++) @@ -138,8 +141,8 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) if (a_ptr < a_end) { int a_remaining = a_end - a_ptr; - memmove(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int)); - memmove(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double)); + memcpy(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int)); + memcpy(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double)); C->nnz += a_remaining; } @@ -147,8 +150,8 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) if (b_ptr < b_end) { int b_remaining = b_end - b_ptr; - memmove(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int)); - memmove(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double)); + memcpy(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int)); + memcpy(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double)); C->nnz += b_remaining; } } diff --git a/tests/all_tests.c b/tests/all_tests.c index d01eb5f..d57c183 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -29,6 +29,7 @@ #include "wsum_hess/elementwise/test_power.h" #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" +#include "wsum_hess/test_hstack.h" #include "wsum_hess/test_rel_entr.h" #include "wsum_hess/test_sum.h" @@ -98,6 +99,8 @@ int main(void) mu_run_test(test_wsum_hess_sum_log_axis1, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); + mu_run_test(test_wsum_hess_hstack, tests_run); + mu_run_test(test_wsum_hess_hstack_matrix, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_diag_csr_mult, tests_run); diff --git a/tests/wsum_hess/test_hstack.h b/tests/wsum_hess/test_hstack.h new file mode 100644 index 0000000..b819fef --- /dev/null +++ b/tests/wsum_hess/test_hstack.h @@ -0,0 +1,225 @@ +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_wsum_hess_hstack() +{ + /* Test: hstack([log(x), log(z), exp(x), sin(y)]) + * Variables: x at idx 0, z at idx 3, y at idx 6 + * x = [1, 2, 3], z = [4, 5, 6], y = [7, 8, 9] + * Total 9 variables + * hStacked vectorized output is 12x1: [log(x), + * log(z), + * exp(x), + * sin(y)] + * w = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + */ + + double u_vals[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + double w[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; + + expr *x = new_variable(3, 1, 0, 9); + expr *z = new_variable(3, 1, 3, 9); + expr *y = new_variable(3, 1, 6, 9); + + expr *log_x = new_log(x); + expr *log_z = new_log(z); + expr *exp_x = new_exp(x); + expr *sin_y = new_sin(y); + + expr *args[4] = {log_x, log_z, exp_x, sin_y}; + expr *hstack_node = new_hstack(args, 4, 9); + + hstack_node->forward(hstack_node, u_vals); + hstack_node->wsum_hess_init(hstack_node); + hstack_node->eval_wsum_hess(hstack_node, w); + + /* Expected Hessian: + * log(x): d²/dx² = -1/x² + * w[0] * (-1/1²) = 1 * (-1) = -1 at (0,0) + * w[1] * (-1/2²) = 2 * (-0.25) = -0.5 at (1,1) + * w[2] * (-1/3²) = 3 * (-1/9) = -1/3 at (2,2) + * + * log(z): d²/dz² = -1/z² + * w[3] * (-1/4²) = 4 * (-1/16) = -0.25 at (3,3) + * w[4] * (-1/5²) = 5 * (-1/25) = -0.2 at (4,4) + * w[5] * (-1/6²) = 6 * (-1/36) = -1/6 at (5,5) + * + * exp(x): d²/dx² = exp(x) + * w[6] * exp(1) = 7 * e at (0,0) + * w[7] * exp(2) = 8 * e² at (1,1) + * w[8] * exp(3) = 9 * e³ at (2,2) + * + * sin(y): d²/dy² = -sin(y) + * w[9] * (-sin(7)) at (6,6) + * w[10] * (-sin(8)) at (7,7) + * w[11] * (-sin(9)) at (8,8) + * + * Accumulated: + * (0,0): -1 + 7*e + * (1,1): -0.5 + 8*e² + * (2,2): -1/3 + 9*e³ + * (3,3): -0.25 + * (4,4): -0.2 + * (5,5): -1/6 + * (6,6): -10*sin(7) + * (7,7): -11*sin(8) + * (8,8): -12*sin(9) + */ + + double e = exp(1.0); + double e2 = exp(2.0); + double e3 = exp(3.0); + + double expected_x[9] = {-1.0 + 7.0 * e, + -0.5 + 8.0 * e2, + -1.0 / 3.0 + 9.0 * e3, + -0.25, + -0.2, + -1.0 / 6.0, + -10.0 * sin(7.0), + -11.0 * sin(8.0), + -12.0 * sin(9.0)}; + + int expected_p[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + int expected_i[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + + mu_assert("vals incorrect", + cmp_double_array(hstack_node->wsum_hess->x, expected_x, 9)); + mu_assert("rows incorrect", + cmp_int_array(hstack_node->wsum_hess->p, expected_p, 10)); + mu_assert("cols incorrect", + cmp_int_array(hstack_node->wsum_hess->i, expected_i, 9)); + + free_expr(hstack_node); + free_expr(sin_y); + free_expr(exp_x); + free_expr(log_z); + free_expr(log_x); + free_expr(y); + free_expr(z); + free_expr(x); + + return 0; +} + +const char *test_wsum_hess_hstack_matrix() +{ + /* Test: hstack([log(x), log(z), exp(x), sin(y)]) with matrix variables + * Variables: x at idx 0, z at idx 6, y at idx 12 + * Each is 3x2, so 6 elements per variable + * x = [1 4] z = [7 10] y = [13 16] + * [2 5] [8 11] [14 17] + * [3 6] [9 12] [15 18] + * Vectorized column-wise: x = [1,2,3,4,5,6], z = [7,8,9,10,11,12], y = + * [13,14,15,16,17,18] Total 18 variables Stacked output is 24x1: [log(x), + * log(z), exp(x), sin(y)] each 6x1 w = [1, 2, 3, ..., 24] + */ + + double u_vals[18] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0}; + double w[24]; + for (int i = 0; i < 24; i++) + { + w[i] = i + 1.0; + } + + expr *x = new_variable(3, 2, 0, 18); + expr *z = new_variable(3, 2, 6, 18); + expr *y = new_variable(3, 2, 12, 18); + + expr *log_x = new_log(x); + expr *log_z = new_log(z); + expr *exp_x = new_exp(x); + expr *sin_y = new_sin(y); + + expr *args[4] = {log_x, log_z, exp_x, sin_y}; + expr *hstack_node = new_hstack(args, 4, 18); + + hstack_node->forward(hstack_node, u_vals); + hstack_node->wsum_hess_init(hstack_node); + hstack_node->eval_wsum_hess(hstack_node, w); + + /* Expected Hessian (diagonal): + * log(x): w[0:5] * (-1/x[0:5]²) at indices 0-5 + * log(z): w[6:11] * (-1/z[0:5]²) at indices 6-11 + * exp(x): w[12:17] * exp(x[0:5]) at indices 0-5 (accumulates with log(x)) + * sin(y): w[18:23] * (-sin(y[0:5])) at indices 12-17 + * + * For x indices (0-5): + * i=0: -1/1² + 13*e¹ = -1 + 13*e + * i=1: -2/2² + 14*e² = -0.5 + 14*e² + * i=2: -3/3² + 15*e³ = -1/3 + 15*e³ + * i=3: -4/4² + 16*e⁴ = -0.25 + 16*e⁴ + * i=4: -5/5² + 17*e⁵ = -0.2 + 17*e⁵ + * i=5: -6/6² + 18*e⁶ = -1/6 + 18*e⁶ + * + * For z indices (6-11): + * i=0: -7/7² = -1/7 + * i=1: -8/8² = -1/8 + * i=2: -9/9² = -1/9 + * i=3: -10/10² = -0.1 + * i=4: -11/11² = -11/121 + * i=5: -12/12² = -1/12 + * + * For y indices (12-17): + * i=0: -19*sin(13) + * i=1: -20*sin(14) + * i=2: -21*sin(15) + * i=3: -22*sin(16) + * i=4: -23*sin(17) + * i=5: -24*sin(18) + */ + + double expected_x[18]; + // x indices (0-5) - accumulation of log and exp + expected_x[0] = -1.0 + 13.0 * exp(1.0); + expected_x[1] = -0.5 + 14.0 * exp(2.0); + expected_x[2] = -1.0 / 3.0 + 15.0 * exp(3.0); + expected_x[3] = -0.25 + 16.0 * exp(4.0); + expected_x[4] = -0.2 + 17.0 * exp(5.0); + expected_x[5] = -1.0 / 6.0 + 18.0 * exp(6.0); + + // z indices (6-11) - only log + expected_x[6] = -1.0 / 7.0; + expected_x[7] = -1.0 / 8.0; + expected_x[8] = -1.0 / 9.0; + expected_x[9] = -0.1; + expected_x[10] = -11.0 / 121.0; + expected_x[11] = -1.0 / 12.0; + + // y indices (12-17) - only sin + expected_x[12] = -19.0 * sin(13.0); + expected_x[13] = -20.0 * sin(14.0); + expected_x[14] = -21.0 * sin(15.0); + expected_x[15] = -22.0 * sin(16.0); + expected_x[16] = -23.0 * sin(17.0); + expected_x[17] = -24.0 * sin(18.0); + + int expected_p[19] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int expected_i[18] = {0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17}; + + mu_assert("vals incorrect", + cmp_double_array(hstack_node->wsum_hess->x, expected_x, 18)); + mu_assert("rows incorrect", + cmp_int_array(hstack_node->wsum_hess->p, expected_p, 19)); + mu_assert("cols incorrect", + cmp_int_array(hstack_node->wsum_hess->i, expected_i, 18)); + + free_expr(hstack_node); + free_expr(sin_y); + free_expr(exp_x); + free_expr(log_z); + free_expr(log_x); + free_expr(y); + free_expr(z); + free_expr(x); + + return 0; +} From c16d66381240ed5a8f2651ce0a2fc4793763ae92 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 4 Jan 2026 00:24:37 -0800 Subject: [PATCH 5/5] fixed memory leak --- src/affine/hstack.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/affine/hstack.c b/src/affine/hstack.c index 548806a..8300282 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -134,7 +134,11 @@ static void free_type_data(expr *node) for (int i = 0; i < hnode->n_args; i++) { free_expr(hnode->args[i]); + hnode->args[i] = NULL; } + + free_csr_matrix(hnode->CSR_work); + hnode->CSR_work = NULL; } expr *new_hstack(expr **args, int n_args, int n_vars)