Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/affine/add.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static void wsum_hess_init(expr *node)
int nnz_max = node->left->wsum_hess->nnz + node->right->wsum_hess->nnz;
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);

/* TODO: we should fill sparsity pattern here for consistency */
/* fill sparsity pattern of hessian */
sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess,
node->wsum_hess);
}
Expand Down
2 changes: 1 addition & 1 deletion src/affine/sum.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ static void eval_wsum_hess(expr *node, const double *w)

x->eval_wsum_hess(x, node->dwork);

/* copy values (TODO: is this necessary or can we just change pointers?)*/
/* copy values */
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
}

Expand Down
42 changes: 18 additions & 24 deletions src/bivariate/rel_entr.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ static void forward_vector_args(expr *node, const double *u)
}
}

/* TODO: this probably doesn't work for matrices because we use node->d1
instead of node->size */
static void jacobian_init_vectors_args(expr *node)
{
node->jacobian = new_csr_matrix(node->d1, node->n_vars, 2 * node->d1);
node->jacobian = new_csr_matrix(node->size, node->n_vars, 2 * node->size);

expr *x = node->left;
expr *y = node->right;
Expand All @@ -38,7 +36,7 @@ static void jacobian_init_vectors_args(expr *node)
/* if x has lower variable idx than y it should appear first */
if (x->var_id < y->var_id)
{
for (int j = 0; j < node->d1; j++)
for (int j = 0; j < node->size; j++)
{
node->jacobian->i[2 * j] = j + x->var_id;
node->jacobian->i[2 * j + 1] = j + y->var_id;
Expand All @@ -47,15 +45,15 @@ static void jacobian_init_vectors_args(expr *node)
}
else
{
for (int j = 0; j < node->d1; j++)
for (int j = 0; j < node->size; j++)
{
node->jacobian->i[2 * j] = j + y->var_id;
node->jacobian->i[2 * j + 1] = j + x->var_id;
node->jacobian->p[j] = 2 * j;
}
}

node->jacobian->p[node->d1] = 2 * node->d1;
node->jacobian->p[node->size] = 2 * node->size;
}

static void eval_jacobian_vector_args(expr *node)
Expand All @@ -82,11 +80,9 @@ static void eval_jacobian_vector_args(expr *node)
}
}

/* TODO: this probably doesn't work for matrices because we use node->d1
instead of node->size */
static void wsum_hess_init_vector_args(expr *node)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1);
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->size);
expr *x = node->left;
expr *y = node->right;

Expand All @@ -104,32 +100,32 @@ static void wsum_hess_init_vector_args(expr *node)
}

/* var1 rows of Hessian */
for (i = 0; i < node->d1; i++)
for (i = 0; i < node->size; i++)
{
node->wsum_hess->p[var1_id + i] = 2 * i;
node->wsum_hess->i[2 * i] = var1_id + i;
node->wsum_hess->i[2 * i + 1] = var2_id + i;
}

int nnz = 2 * node->d1;
int nnz = 2 * node->size;

/* rows between var1 and var2 */
for (i = var1_id + node->d1; i < var2_id; i++)
for (i = var1_id + node->size; i < var2_id; i++)
{
node->wsum_hess->p[i] = nnz;
}

/* var2 rows of Hessian */
for (i = 0; i < node->d1; i++)
for (i = 0; i < node->size; i++)
{
node->wsum_hess->p[var2_id + i] = nnz + 2 * i;
}
memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int));

/* remaining rows */
for (i = var2_id + node->d1; i <= node->n_vars; i++)
for (i = var2_id + node->size; i <= node->n_vars; i++)
{
node->wsum_hess->p[i] = 4 * node->d1;
node->wsum_hess->p[i] = 4 * node->size;
}
}

Expand All @@ -141,31 +137,31 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w)

if (node->left->var_id < node->right->var_id)
{
for (int i = 0; i < node->d1; i++)
for (int i = 0; i < node->size; i++)
{
hess[2 * i] = w[i] / x[i];
hess[2 * i + 1] = -w[i] / y[i];
}

hess += 2 * node->d1;
hess += 2 * node->size;

for (int i = 0; i < node->d1; i++)
for (int i = 0; i < node->size; i++)
{
hess[2 * i] = -w[i] / y[i];
hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]);
}
}
else
{
for (int i = 0; i < node->d1; i++)
for (int i = 0; i < node->size; i++)
{
hess[2 * i] = w[i] * x[i] / (y[i] * y[i]);
hess[2 * i + 1] = -w[i] / y[i];
}

hess += 2 * node->d1;
hess += 2 * node->size;

for (int i = 0; i < node->d1; i++)
for (int i = 0; i < node->size; i++)
{
hess[2 * i] = -w[i] / y[i];
hess[2 * i + 1] = w[i] / x[i];
Expand All @@ -175,7 +171,7 @@ static void eval_wsum_hess_vector_args(expr *node, const double *w)

expr *new_rel_entr_vector_args(expr *left, expr *right)
{
expr *node = new_expr(left->d1, 1, left->n_vars);
expr *node = new_expr(left->d1, left->d2, left->n_vars);
node->left = left;
node->right = right;
expr_retain(left);
Expand All @@ -185,8 +181,6 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
node->eval_jacobian = eval_jacobian_vector_args;
node->wsum_hess_init = wsum_hess_init_vector_args;
node->eval_wsum_hess = eval_wsum_hess_vector_args;
// node->is_affine = is_affine_elementwise;
// node->local_jacobian = local_jacobian;
return node;
}

Expand Down
8 changes: 5 additions & 3 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
#include "forward_pass/elementwise/test_log.h"
#include "jacobian_tests/test_composite.h"
#include "jacobian_tests/test_elementwise_mult.h"
#include "jacobian_tests/test_neg.h"
#include "jacobian_tests/test_hstack.h"
#include "jacobian_tests/test_log.h"
#include "jacobian_tests/test_promote.h"
#include "jacobian_tests/test_neg.h"
#include "jacobian_tests/test_prod.h"
#include "jacobian_tests/test_promote.h"
#include "jacobian_tests/test_quad_form.h"
#include "jacobian_tests/test_quad_over_lin.h"
#include "jacobian_tests/test_rel_entr.h"
#include "jacobian_tests/test_sum.h"
#include "problem/test_problem.h"
#include "jacobian_tests/test_trace.h"
#include "problem/test_problem.h"
#include "utils/test_csc_matrix.h"
#include "utils/test_csr_matrix.h"
#include "wsum_hess/elementwise/test_entr.h"
Expand Down Expand Up @@ -76,6 +76,7 @@ int main(void)
mu_run_test(test_jacobian_composite_log_add, tests_run);
mu_run_test(test_jacobian_rel_entr_vector_args_1, tests_run);
mu_run_test(test_jacobian_rel_entr_vector_args_2, tests_run);
mu_run_test(test_jacobian_rel_entr_matrix_args, tests_run);
mu_run_test(test_jacobian_elementwise_mult_1, tests_run);
mu_run_test(test_jacobian_elementwise_mult_2, tests_run);
mu_run_test(test_jacobian_elementwise_mult_3, tests_run);
Expand Down Expand Up @@ -129,6 +130,7 @@ int main(void)
mu_run_test(test_wsum_hess_prod_many_zeros, tests_run);
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run);
mu_run_test(test_wsum_hess_hstack, tests_run);
mu_run_test(test_wsum_hess_hstack_matrix, tests_run);
mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run);
Expand Down
43 changes: 43 additions & 0 deletions tests/jacobian_tests/test_rel_entr.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,46 @@ const char *test_jacobian_rel_entr_vector_args_2()
free_expr(node);
return 0;
}

const char *test_jacobian_rel_entr_matrix_args()
{
// x, y are 3 x 2 matrices (vectorized columnwise)
// x has global variable index 0, y has global variable index 6
// we compute jacobian of x log(x) - x log(y)

double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0};

expr *x = new_variable(3, 2, 0, 12);
expr *y = new_variable(3, 2, 6, 12);
expr *node = new_rel_entr_vector_args(x, y);

node->forward(node, u_vals);
node->jacobian_init(node);
node->eval_jacobian(node);

double dx0 = log(1.0 / 6.0) + 1.0;
double dx1 = log(2.0 / 5.0) + 1.0;
double dx2 = log(3.0 / 4.0) + 1.0;
double dx3 = log(4.0 / 3.0) + 1.0;
double dx4 = log(5.0 / 2.0) + 1.0;
double dx5 = log(6.0 / 1.0) + 1.0;

double dy0 = -1.0 / 6.0;
double dy1 = -2.0 / 5.0;
double dy2 = -3.0 / 4.0;
double dy3 = -4.0 / 3.0;
double dy4 = -5.0 / 2.0;
double dy5 = -6.0 / 1.0;

double expected_Ax[12] = {dx0, dy0, dx1, dy1, dx2, dy2,
dx3, dy3, dx4, dy4, dx5, dy5};
int expected_Ap[7] = {0, 2, 4, 6, 8, 10, 12};
int expected_Ai[12] = {0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11};

mu_assert("vals fail", cmp_double_array(node->jacobian->x, expected_Ax, 12));
mu_assert("rows fail", cmp_int_array(node->jacobian->p, expected_Ap, 7));
mu_assert("cols fail", cmp_int_array(node->jacobian->i, expected_Ai, 12));

free_expr(node);
return 0;
}
55 changes: 39 additions & 16 deletions tests/wsum_hess/test_rel_entr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,8 @@ const char *test_wsum_hess_rel_entr_1()

int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12};
int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8};
double expected_x[12] = {
1.0, -0.25, // row 1
1.0, -0.4, // row 2
1.0, -0.5, // row 3
-0.25, 0.0625, // row 6
-0.4, 0.16, // row 7
-0.5, 0.25 // row 8
};
double expected_x[12] = {1.0, -0.25, 1.0, -0.4, 1.0, -0.5,
-0.25, 0.0625, -0.4, 0.16, -0.5, 0.25};

mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11));
mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12));
Expand Down Expand Up @@ -63,14 +57,8 @@ const char *test_wsum_hess_rel_entr_2()

int expected_p[11] = {0, 0, 2, 4, 6, 6, 6, 8, 10, 12, 12};
int expected_i[12] = {1, 6, 2, 7, 3, 8, 1, 6, 2, 7, 3, 8};
double expected_x[12] = {
0.0625, -0.25, // row 1
0.16, -0.4, // row 2
0.25, -0.5, // row 3
-0.25, 1.0, // row 6
-0.4, 1.0, // row 7
-0.5, 1.0 // row 8
};
double expected_x[12] = {0.0625, -0.25, 0.16, -0.4, 0.25, -0.5,
-0.25, 1.0, -0.4, 1.0, -0.5, 1.0};

mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 11));
mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 12));
Expand All @@ -79,3 +67,38 @@ const char *test_wsum_hess_rel_entr_2()
free_expr(node);
return 0;
}

const char *test_wsum_hess_rel_entr_matrix()
{
// x, y are 3 x 2 matrices (vectorized columnwise)
// x has var_id = 0, y has var_id = 6
// x = [1 2 3 4 5 6], y = [6 5 4 3 2 1]
// w = [1 2 3 4 5 6]

double u_vals[12] = {1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1};
double w[6] = {1, 2, 3, 4, 5, 6};

expr *x = new_variable(3, 2, 0, 12);
expr *y = new_variable(3, 2, 6, 12);
expr *node = new_rel_entr_vector_args(x, y);

node->forward(node, u_vals);
node->wsum_hess_init(node);
node->eval_wsum_hess(node, w);

int expected_p[13] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24};
int expected_i[24] = {0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11,
0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11};
double expected_x[24] = {
1.0, -1.0 / 6.0, 1.0, -0.4, 1.0, -0.75,
1.0, -4.0 / 3.0, 1.0, -2.5, 1.0, -6.0,
-1.0 / 6.0, 1.0 / 36.0, -0.4, 0.16, -0.75, 0.5625,
-4.0 / 3.0, 1.7777777777777777, -2.5, 6.25, -6.0, 36.0};

mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 13));
mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 24));
mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 24));

free_expr(node);
return 0;
}