diff --git a/src/affine/add.c b/src/affine/add.c index d7994b3..397c0db 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -49,7 +49,7 @@ static void wsum_hess_init(expr *node) int nnz_max = node->left->wsum_hess->nnz + node->right->wsum_hess->nnz; node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max); - /* TODO: we should fill sparsity pattern here for consistency */ + /* fill sparsity pattern of hessian */ sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess); } diff --git a/src/affine/sum.c b/src/affine/sum.c index 7114124..ac40679 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -140,7 +140,7 @@ static void eval_wsum_hess(expr *node, const double *w) x->eval_wsum_hess(x, node->dwork); - /* copy values (TODO: is this necessary or can we just change pointers?)*/ + /* copy values */ memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double)); } diff --git a/src/bivariate/rel_entr.c b/src/bivariate/rel_entr.c index 1360a1a..61be542 100644 --- a/src/bivariate/rel_entr.c +++ b/src/bivariate/rel_entr.c @@ -24,11 +24,9 @@ static void forward_vector_args(expr *node, const double *u) } } -/* TODO: this probably doesn't work for matrices because we use node->d1 - instead of node->size */ static void jacobian_init_vectors_args(expr *node) { - node->jacobian = new_csr_matrix(node->d1, node->n_vars, 2 * node->d1); + node->jacobian = new_csr_matrix(node->size, node->n_vars, 2 * node->size); expr *x = node->left; expr *y = node->right; @@ -38,7 +36,7 @@ static void jacobian_init_vectors_args(expr *node) /* if x has lower variable idx than y it should appear first */ if (x->var_id < y->var_id) { - for (int j = 0; j < node->d1; j++) + for (int j = 0; j < node->size; j++) { node->jacobian->i[2 * j] = j + x->var_id; node->jacobian->i[2 * j + 1] = j + y->var_id; @@ -47,7 +45,7 @@ static void jacobian_init_vectors_args(expr *node) } else { - for (int j = 0; j < node->d1; j++) + for (int j = 0; j < node->size; j++) { node->jacobian->i[2 * j] = j + y->var_id; node->jacobian->i[2 * j + 1] = j + x->var_id; @@ -55,7 +53,7 @@ static void jacobian_init_vectors_args(expr *node) } } - node->jacobian->p[node->d1] = 2 * node->d1; + node->jacobian->p[node->size] = 2 * node->size; } static void eval_jacobian_vector_args(expr *node) @@ -82,11 +80,9 @@ static void eval_jacobian_vector_args(expr *node) } } -/* TODO: this probably doesn't work for matrices because we use node->d1 - instead of node->size */ static void wsum_hess_init_vector_args(expr *node) { - node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1); + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->size); expr *x = node->left; expr *y = node->right; @@ -104,32 +100,32 @@ static void wsum_hess_init_vector_args(expr *node) } /* var1 rows of Hessian */ - for (i = 0; i < node->d1; i++) + for (i = 0; i < node->size; i++) { node->wsum_hess->p[var1_id + i] = 2 * i; node->wsum_hess->i[2 * i] = var1_id + i; node->wsum_hess->i[2 * i + 1] = var2_id + i; } - int nnz = 2 * node->d1; + int nnz = 2 * node->size; /* rows between var1 and var2 */ - for (i = var1_id + node->d1; i < var2_id; i++) + for (i = var1_id + node->size; i < var2_id; i++) { node->wsum_hess->p[i] = nnz; } /* var2 rows of Hessian */ - for (i = 0; i < node->d1; i++) + for (i = 0; i < node->size; i++) { node->wsum_hess->p[var2_id + i] = nnz + 2 * i; } memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int)); /* remaining rows */ - for (i = var2_id + node->d1; i <= node->n_vars; i++) + for (i = var2_id + node->size; i <= node->n_vars; i++) { - node->wsum_hess->p[i] = 4 * node->d1; + node->wsum_hess->p[i] = 4 * node->size; } } @@ -141,15 +137,15 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) if (node->left->var_id < node->right->var_id) { - for (int i = 0; i < node->d1; i++) + for (int i = 0; i < node->size; i++) { hess[2 * i] = w[i] / x[i]; hess[2 * i + 1] = -w[i] / y[i]; } - hess += 2 * node->d1; + hess += 2 * node->size; - for (int i = 0; i < node->d1; i++) + for (int i = 0; i < node->size; i++) { hess[2 * i] = -w[i] / y[i]; hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]); @@ -157,15 +153,15 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) } else { - for (int i = 0; i < node->d1; i++) + for (int i = 0; i < node->size; i++) { hess[2 * i] = w[i] * x[i] / (y[i] * y[i]); hess[2 * i + 1] = -w[i] / y[i]; } - hess += 2 * node->d1; + hess += 2 * node->size; - for (int i = 0; i < node->d1; i++) + for (int i = 0; i < node->size; i++) { hess[2 * i] = -w[i] / y[i]; hess[2 * i + 1] = w[i] / x[i]; @@ -175,7 +171,7 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w) expr *new_rel_entr_vector_args(expr *left, expr *right) { - expr *node = new_expr(left->d1, 1, left->n_vars); + expr *node = new_expr(left->d1, left->d2, left->n_vars); node->left = left; node->right = right; expr_retain(left); @@ -185,8 +181,6 @@ expr *new_rel_entr_vector_args(expr *left, expr *right) node->eval_jacobian = eval_jacobian_vector_args; node->wsum_hess_init = wsum_hess_init_vector_args; node->eval_wsum_hess = eval_wsum_hess_vector_args; - // node->is_affine = is_affine_elementwise; - // node->local_jacobian = local_jacobian; return node; } diff --git a/tests/all_tests.c b/tests/all_tests.c index 502116e..ccfbd56 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -15,17 +15,17 @@ #include "forward_pass/elementwise/test_log.h" #include "jacobian_tests/test_composite.h" #include "jacobian_tests/test_elementwise_mult.h" -#include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_hstack.h" #include "jacobian_tests/test_log.h" -#include "jacobian_tests/test_promote.h" +#include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_prod.h" +#include "jacobian_tests/test_promote.h" #include "jacobian_tests/test_quad_form.h" #include "jacobian_tests/test_quad_over_lin.h" #include "jacobian_tests/test_rel_entr.h" #include "jacobian_tests/test_sum.h" -#include "problem/test_problem.h" #include "jacobian_tests/test_trace.h" +#include "problem/test_problem.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" #include "wsum_hess/elementwise/test_entr.h" @@ -76,6 +76,7 @@ int main(void) mu_run_test(test_jacobian_composite_log_add, tests_run); mu_run_test(test_jacobian_rel_entr_vector_args_1, tests_run); mu_run_test(test_jacobian_rel_entr_vector_args_2, tests_run); + mu_run_test(test_jacobian_rel_entr_matrix_args, tests_run); mu_run_test(test_jacobian_elementwise_mult_1, tests_run); mu_run_test(test_jacobian_elementwise_mult_2, tests_run); mu_run_test(test_jacobian_elementwise_mult_3, tests_run); @@ -129,6 +130,7 @@ int main(void) mu_run_test(test_wsum_hess_prod_many_zeros, tests_run); mu_run_test(test_wsum_hess_rel_entr_1, tests_run); mu_run_test(test_wsum_hess_rel_entr_2, tests_run); + mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); mu_run_test(test_wsum_hess_hstack, tests_run); mu_run_test(test_wsum_hess_hstack_matrix, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run); diff --git a/tests/jacobian_tests/test_rel_entr.h b/tests/jacobian_tests/test_rel_entr.h index 1e4b089..6cd5b9f 100644 --- a/tests/jacobian_tests/test_rel_entr.h +++ b/tests/jacobian_tests/test_rel_entr.h @@ -67,3 +67,46 @@ const char *test_jacobian_rel_entr_vector_args_2() free_expr(node); return 0; } + +const char *test_jacobian_rel_entr_matrix_args() +{ + // x, y are 3 x 2 matrices (vectorized columnwise) + // x has global variable index 0, y has global variable index 6 + // we compute jacobian of x log(x) - x log(y) + + double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0}; + + expr *x = new_variable(3, 2, 0, 12); + expr *y = new_variable(3, 2, 6, 12); + expr *node = new_rel_entr_vector_args(x, y); + + node->forward(node, u_vals); + node->jacobian_init(node); + node->eval_jacobian(node); + + double dx0 = log(1.0 / 6.0) + 1.0; + double dx1 = log(2.0 / 5.0) + 1.0; + double dx2 = log(3.0 / 4.0) + 1.0; + double dx3 = log(4.0 / 3.0) + 1.0; + double dx4 = log(5.0 / 2.0) + 1.0; + double dx5 = log(6.0 / 1.0) + 1.0; + + double dy0 = -1.0 / 6.0; + double dy1 = -2.0 / 5.0; + double dy2 = -3.0 / 4.0; + double dy3 = -4.0 / 3.0; + double dy4 = -5.0 / 2.0; + double dy5 = -6.0 / 1.0; + + double expected_Ax[12] = {dx0, dy0, dx1, dy1, dx2, dy2, + dx3, dy3, dx4, dy4, dx5, dy5}; + int expected_Ap[7] = {0, 2, 4, 6, 8, 10, 12}; + int expected_Ai[12] = {0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11}; + + mu_assert("vals fail", cmp_double_array(node->jacobian->x, expected_Ax, 12)); + mu_assert("rows fail", cmp_int_array(node->jacobian->p, expected_Ap, 7)); + mu_assert("cols fail", cmp_int_array(node->jacobian->i, expected_Ai, 12)); + + free_expr(node); + return 0; +} diff --git a/tests/wsum_hess/test_rel_entr.h b/tests/wsum_hess/test_rel_entr.h index 9bfafa4..3516882 100644 --- a/tests/wsum_hess/test_rel_entr.h +++ b/tests/wsum_hess/test_rel_entr.h @@ -26,14 +26,8 @@ const char *test_wsum_hess_rel_entr_1() int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12}; int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8}; - double expected_x[12] = { - 1.0, -0.25, // row 1 - 1.0, -0.4, // row 2 - 1.0, -0.5, // row 3 - -0.25, 0.0625, // row 6 - -0.4, 0.16, // row 7 - -0.5, 0.25 // row 8 - }; + double expected_x[12] = {1.0, -0.25, 1.0, -0.4, 1.0, -0.5, + -0.25, 0.0625, -0.4, 0.16, -0.5, 0.25}; mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11)); mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12)); @@ -63,14 +57,8 @@ const char *test_wsum_hess_rel_entr_2() int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12}; int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8}; - double expected_x[12] = { - 0.0625, -0.25, // row 1 - 0.16, -0.4, // row 2 - 0.25, -0.5, // row 3 - -0.25, 1.0, // row 6 - -0.4, 1.0, // row 7 - -0.5, 1.0 // row 8 - }; + double expected_x[12] = {0.0625, -0.25, 0.16, -0.4, 0.25, -0.5, + -0.25, 1.0, -0.4, 1.0, -0.5, 1.0}; mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11)); mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12)); @@ -79,3 +67,38 @@ const char *test_wsum_hess_rel_entr_2() free_expr(node); return 0; } + +const char *test_wsum_hess_rel_entr_matrix() +{ + // x, y are 3 x 2 matrices (vectorized columnwise) + // x has var_id = 0, y has var_id = 6 + // x = [1 2 3 4 5 6], y = [6 5 4 3 2 1] + // w = [1 2 3 4 5 6] + + double u_vals[12] = {1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1}; + double w[6] = {1, 2, 3, 4, 5, 6}; + + expr *x = new_variable(3, 2, 0, 12); + expr *y = new_variable(3, 2, 6, 12); + expr *node = new_rel_entr_vector_args(x, y); + + node->forward(node, u_vals); + node->wsum_hess_init(node); + node->eval_wsum_hess(node, w); + + int expected_p[13] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24}; + int expected_i[24] = {0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11, + 0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11}; + double expected_x[24] = { + 1.0, -1.0 / 6.0, 1.0, -0.4, 1.0, -0.75, + 1.0, -4.0 / 3.0, 1.0, -2.5, 1.0, -6.0, + -1.0 / 6.0, 1.0 / 36.0, -0.4, 0.16, -0.75, 0.5625, + -4.0 / 3.0, 1.7777777777777777, -2.5, 6.25, -6.0, 36.0}; + + mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 13)); + mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 24)); + mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 24)); + + free_expr(node); + return 0; +}