Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 145 additions & 2 deletions src/bivariate/quad_over_lin.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,153 @@ static void eval_jacobian(expr *node)
csc_matvec_fill_values(A_csc, node->dwork, node->jacobian);

/* insert derivative wrt y at right place (for correctness this assumes
that y does not appear in the denominator, but this will always be
the case since y is a new variable for the numerator) */
that y does not appear in the numerator, but this will always be
the case since y is a new variable for the denominator */
node->jacobian->x[node->iwork[0]] = -node->value[0] / y->value[0];
}
}

static void wsum_hess_init(expr *node)
{
expr *x = node->left;
expr *y = node->right;
int var_id_x = x->var_id;
int var_id_y = y->var_id;

/* if left node is a variable */
if (x->var_id != NOT_A_VARIABLE)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 3 * x->d1 + 1);
CSR_Matrix *H = node->wsum_hess;

/* if x has lower idx than y*/
if (var_id_x < var_id_y)
{
/* x rows: each row has 2 entries (diagonal element + element for y) */
for (int i = 0; i < x->d1; i++)
{
H->p[var_id_x + i] = 2 * i;
H->i[2 * i] = var_id_x + i;
H->i[2 * i + 1] = var_id_y;
}

/* rows between x and y are empty, all point to same offset */
int offset = 2 * x->d1;
for (int i = var_id_x + x->d1; i <= var_id_y; i++)
{
H->p[i] = offset;
}

/* y row has d1 + 1 entries */
H->p[var_id_y + 1] = offset + x->d1 + 1;
for (int i = 0; i < x->d1; i++)
{
H->i[offset + i] = var_id_x + i;
}
H->i[offset + x->d1] = var_id_y;

/* remaining rows are empty */
for (int i = var_id_y + 1; i <= node->n_vars; i++)
{
H->p[i] = 3 * x->d1 + 1;
}
}
else /* y has lower idx than x */
{
/* y row has d1 + 1 entries */
H->p[var_id_y + 1] = x->d1 + 1;
H->i[0] = var_id_y;
for (int i = 0; i < x->d1; i++)
{
H->i[i + 1] = var_id_x + i;
}

/* rows between y and x are empty, all point to same offset */
int offset = x->d1 + 1;
for (int i = var_id_y + 1; i <= var_id_x; i++)
{
H->p[i] = offset;
}

/* x rows: each row has 2 entries */
for (int i = 0; i < x->d1; i++)
{
H->p[var_id_x + i] = offset + 2 * i;
H->i[offset + 2 * i] = var_id_y;
H->i[offset + 2 * i + 1] = var_id_x + i;
}

/* remaining rows are empty */
for (int i = var_id_x + x->d1; i <= node->n_vars; i++)
{
H->p[i] = 3 * x->d1 + 1;
}
}
}
else
{
/* TODO: implement */
assert(false && "not implemented");
}
}

static void eval_wsum_hess(expr *node, const double *w)
{
double *x = node->left->value;
double y = node->right->value[0];
double *H = node->wsum_hess->x;
int var_id_x = node->left->var_id;
int var_id_y = node->right->var_id;
int x_d1 = node->left->d1;
double a = (2.0 * w[0]) / y;
double b = -(2.0 * w[0]) / (y * y);

/* if left node is a variable */
if (var_id_x != NOT_A_VARIABLE)
{
/* if x has lower idx than y*/
if (var_id_x < var_id_y)
{
/* x rows*/
for (int i = 0; i < x_d1; i++)
{
H[2 * i] = a;
H[2 * i + 1] = b * x[i];
}

/* y row */
int offset = 2 * x_d1;
for (int i = 0; i < x_d1; i++)
{
H[offset + i] = b * x[i];
}
H[offset + x_d1] = -b * node->value[0];
}
else /* y has lower idx than x */
{
/* y row */
H[0] = -b * node->value[0];
for (int i = 0; i < x_d1; i++)
{
H[i + 1] = b * x[i];
}

/* x rows*/
int offset = x_d1 + 1;
for (int i = 0; i < x_d1; i++)
{
H[offset + 2 * i] = b * x[i];
H[offset + 2 * i + 1] = a;
}
}
}
else
{
/* TODO: implement */
assert(false && "not implemented");
}
}

expr *new_quad_over_lin(expr *left, expr *right)
{
expr *node = new_expr(left->d1, 1, left->n_vars);
Expand All @@ -160,5 +301,7 @@ expr *new_quad_over_lin(expr *left, expr *right)
node->forward = forward;
node->jacobian_init = jacobian_init;
node->eval_jacobian = eval_jacobian;
node->wsum_hess_init = wsum_hess_init;
node->eval_wsum_hess = eval_wsum_hess;
return node;
}
2 changes: 1 addition & 1 deletion src/expr.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
node->d2 = d2;
node->size = d1 * d2;
node->n_vars = n_vars;
node->refcount = 1;
node->refcount = 0;
node->value = (double *) calloc(d1 * d2, sizeof(double));
node->var_id = NOT_A_VARIABLE;
node->forward = forward;
Expand Down
92 changes: 82 additions & 10 deletions src/other/quad_form.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

static void forward(expr *node, const double *u)
{
Expand All @@ -25,10 +26,78 @@ static void forward(expr *node, const double *u)

static void jacobian_init(expr *node)
{
assert(node->left->var_id != NOT_A_VARIABLE);
assert(node->left->d2 == 1);
expr *x = node->left;
node->dwork = (double *) malloc(x->d1 * sizeof(double));
node->jacobian = new_csr_matrix(1, node->n_vars, x->d1);
node->jacobian->p[0] = 0;
node->jacobian->p[1] = x->d1;

/* if x is a variable */
for (int j = 0; j < x->d1; j++)
{
node->jacobian->i[j] = x->var_id + j;
}
}

static void eval_jacobian(expr *node)
{
expr *x = node->left;
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;

// jacobian = 2 * Q * x
csr_matvec(Q, x->value, node->jacobian->x, 0);

for (int j = 0; j < x->d1; j++)
{
node->jacobian->x[j] *= 2.0;
}
}

static void wsum_hess_init(expr *node)
{
expr *x = node->left;
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
CSR_Matrix *H = new_csr_matrix(node->n_vars, node->n_vars, Q->nnz);

/* set global row pointers */
memcpy(H->p + x->var_id, Q->p, (x->d1 + 1) * sizeof(int));
for (int i = x->var_id + x->d1 + 1; i <= node->n_vars; i++)
{
H->p[i] = Q->nnz;
}

/* set global column indices */
for (int i = 0; i < Q->nnz; i++)
{
H->i[i] = Q->i[i] + x->var_id;
}

node->wsum_hess = H;
}

static void eval_wsum_hess(expr *node, const double *w)
{
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
double *H = node->wsum_hess->x;
double two_w = 2.0 * w[0];
for (int i = 0; i < Q->nnz; i++)
{
H[i] = two_w * Q->x[i];
}
}

/*
The following two functions are commented out. It supports the jacobian for
quad_form(Ax, Q), but after reconsideration, I think we should treat this as
quad_form(x, A.TQA) in the canonicalization so we don't need to support the chain
rule here.
static void jacobian_init(expr *node)
{
expr *x = node->left;
node->dwork = (double *) malloc(x->d1 * sizeof(double));

// if x is a variable
if (x->var_id != NOT_A_VARIABLE)
{
node->jacobian = new_csr_matrix(1, node->n_vars, x->d1);
Expand All @@ -40,14 +109,14 @@ static void jacobian_init(expr *node)
node->jacobian->i[j] = x->var_id + j;
}
}
else /* x is not a variable */
else // x is not a variable
{
/* compute required allocation and allocate jacobian */
// compute required allocation and allocate jacobian
bool *col_nz = (bool *) calloc(node->n_vars, sizeof(bool));
int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz);
node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1);

/* precompute column indices */
// precompute column indices
node->jacobian->nnz = 0;
for (int j = 0; j < node->n_vars; j++)
{
Expand All @@ -65,12 +134,13 @@ static void jacobian_init(expr *node)
}
}

static void eval_jacobian(expr *node)

static void eval_jacobian_old(expr *node)
{
expr *x = node->left;
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;

/* if x is a variable */
// if x is a variable
if (x->var_id != NOT_A_VARIABLE)
{
csr_matvec(Q, x->value, node->jacobian->x, 0);
Expand All @@ -80,22 +150,23 @@ static void eval_jacobian(expr *node)
node->jacobian->x[j] *= 2.0;
}
}
else /* x is not a variable */
else // x is not a variable
{
linear_op_expr *lin_x = (linear_op_expr *) x;

/* local jacobian */
// local jacobian
csr_matvec(Q, x->value, node->dwork, 0);

for (int j = 0; j < x->d1; j++)
{
node->dwork[j] *= 2.0;
}

/* chain rule using CSC format */
// chain rule using CSC format
csc_matvec_fill_values(lin_x->A_csc, node->dwork, node->jacobian);
}
}
*/

expr *new_quad_form(expr *left, CSR_Matrix *Q)
{
Expand All @@ -113,6 +184,7 @@ expr *new_quad_form(expr *left, CSR_Matrix *Q)

/* Set type-specific field */
qnode->Q = Q;

node->wsum_hess_init = wsum_hess_init;
node->eval_wsum_hess = eval_wsum_hess;
return node;
}
8 changes: 7 additions & 1 deletion tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include "wsum_hess/elementwise/test_xexp.h"
#include "wsum_hess/test_hstack.h"
#include "wsum_hess/test_multiply.h"
#include "wsum_hess/test_quad_form.h"
#include "wsum_hess/test_quad_over_lin.h"
#include "wsum_hess/test_rel_entr.h"
#include "wsum_hess/test_sum.h"

Expand Down Expand Up @@ -71,7 +73,8 @@ int main(void)
mu_run_test(test_quad_over_lin4, tests_run);
mu_run_test(test_quad_over_lin5, tests_run);
mu_run_test(test_quad_form, tests_run);
mu_run_test(test_quad_form2, tests_run);
/* commented out - see test_quad_form.h */
// mu_run_test(test_quad_form2, tests_run);
mu_run_test(test_jacobian_sum_log, tests_run);
mu_run_test(test_jacobian_sum_mult, tests_run);
mu_run_test(test_jacobian_sum_log_axis_0, tests_run);
Expand Down Expand Up @@ -104,6 +107,9 @@ int main(void)
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
mu_run_test(test_wsum_hess_hstack, tests_run);
mu_run_test(test_wsum_hess_hstack_matrix, tests_run);
mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run);
mu_run_test(test_wsum_hess_quad_over_lin_yx, tests_run);
mu_run_test(test_wsum_hess_quad_form, tests_run);
mu_run_test(test_wsum_hess_multiply_linear_ops, tests_run);
mu_run_test(test_wsum_hess_multiply_sparse_random, tests_run);

Expand Down
2 changes: 0 additions & 2 deletions tests/forward_pass/affine/test_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,5 @@ const char *test_addition()
double expected[2] = {4.0, 6.0};
mu_assert("Addition test failed", cmp_double_array(sum->value, expected, 2));
free_expr(sum);
free_expr(var);
free_expr(const_node);
return 0;
}
8 changes: 0 additions & 8 deletions tests/forward_pass/affine/test_hstack.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ const char *test_hstack_forward_vectors()
mu_assert("hstack forward failed", cmp_double_array(stack->value, expected, 9));

free_expr(stack);
free_expr(sin_x);
free_expr(exp_x);
free_expr(log_x);
free_expr(x);
return 0;
}

Expand Down Expand Up @@ -63,9 +59,5 @@ const char *test_hstack_forward_matrix()
cmp_double_array(stack->value, expected, 18));

free_expr(stack);
free_expr(sin_x);
free_expr(exp_x);
free_expr(log_x);
free_expr(x);
return 0;
}
1 change: 0 additions & 1 deletion tests/forward_pass/affine/test_linear_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const char *test_linear_op()
double expected[3] = {8, 7, 26};
mu_assert("fail", cmp_double_array(linear_node->value, expected, 3));
free_expr(linear_node);
free_expr(var);
free_csr_matrix(A);
return 0;
}
Loading