Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ typedef struct hstack_expr
expr base;
expr **args;
int n_args;
CSR_Matrix *CSR_work; /* for summing Hessians of children */
} hstack_expr;

#endif /* SUBEXPR_H */
3 changes: 2 additions & 1 deletion include/utils/CSR_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C);

/* Compute C = A + B where A, B, C are CSR matrices
* A and B must have same dimensions
* C must be pre-allocated with sufficient nnz capacity */
* C must be pre-allocated with sufficient nnz capacity.
* C must be different from A and B */
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);

/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
Expand Down
42 changes: 41 additions & 1 deletion src/affine/hstack.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ static void jacobian_init(expr *node)

node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);

/* precompute sparsity pattern of this jacobian's node */
/* precompute sparsity pattern of this node's jacobian */
int row_offset = 0;
CSR_Matrix *A = node->jacobian;
A->nnz = 0;
Expand Down Expand Up @@ -80,6 +80,40 @@ static void eval_jacobian(expr *node)
}
}

static void wsum_hess_init(expr *node)
{
/* initialize children's hessians */
hstack_expr *hnode = (hstack_expr *) node;
int nnz = 0;
for (int i = 0; i < hnode->n_args; i++)
{
hnode->args[i]->wsum_hess_init(hnode->args[i]);
nnz += hnode->args[i]->wsum_hess->nnz;
}

/* worst-case scenario the nnz of node->wsum_hess is the sum of children's
nnz */
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz);
hnode->CSR_work = new_csr_matrix(node->n_vars, node->n_vars, nnz);
}

static void wsum_hess_eval(expr *node, const double *w)
{
hstack_expr *hnode = (hstack_expr *) node;
CSR_Matrix *H = node->wsum_hess;
int row_offset = 0;
H->nnz = 0;

for (int i = 0; i < hnode->n_args; i++)
{
expr *child = hnode->args[i];
child->eval_wsum_hess(child, w + row_offset);
copy_csr_matrix(H, hnode->CSR_work);
sum_csr_matrices(hnode->CSR_work, child->wsum_hess, H);
row_offset += child->size;
}
}

static bool is_affine(const expr *node)
{
const hstack_expr *hnode = (const hstack_expr *) node;
Expand All @@ -100,7 +134,11 @@ static void free_type_data(expr *node)
for (int i = 0; i < hnode->n_args; i++)
{
free_expr(hnode->args[i]);
hnode->args[i] = NULL;
}

free_csr_matrix(hnode->CSR_work);
hnode->CSR_work = NULL;
}

expr *new_hstack(expr **args, int n_args, int n_vars)
Expand All @@ -121,6 +159,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
/* Set type-specific fields */
hnode->args = args;
hnode->n_args = n_args;
node->wsum_hess_init = wsum_hess_init;
node->eval_wsum_hess = wsum_hess_eval;

for (int i = 0; i < n_args; i++)
{
Expand Down
2 changes: 1 addition & 1 deletion src/bivariate/multiply.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static void jacobian_init(expr *node)
{
/* if a child is a variable we initialize its jacobian for a
short chain rule implementation */
if (node->left->var_id != -1)
if (node->left->var_id != NOT_A_VARIABLE)
{
node->left->jacobian_init(node->left);
}
Expand Down
95 changes: 95 additions & 0 deletions src/bivariate/rel_entr.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "bivariate.h"
#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

// --------------------------------------------------------------------
// Implementation of relative entropy when both arguments are vectors.
Expand Down Expand Up @@ -28,6 +30,8 @@ static void jacobian_init_vectors_args(expr *node)

expr *x = node->left;
expr *y = node->right;
assert(x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE);
assert(x->var_id != y->var_id);

/* if x has lower variable idx than y it should appear first */
if (x->var_id < y->var_id)
Expand Down Expand Up @@ -76,6 +80,95 @@ static void eval_jacobian_vector_args(expr *node)
}
}

static void wsum_hess_init_vector_args(expr *node)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1);
expr *x = node->left;
expr *y = node->right;

int i, var1_id, var2_id;

if (x->var_id < y->var_id)
{
var1_id = x->var_id;
var2_id = y->var_id;
}
else
{
var1_id = y->var_id;
var2_id = x->var_id;
}

/* var1 rows of Hessian */
for (i = 0; i < node->d1; i++)
{
node->wsum_hess->p[var1_id + i] = 2 * i;
node->wsum_hess->i[2 * i] = var1_id + i;
node->wsum_hess->i[2 * i + 1] = var2_id + i;
}

int nnz = 2 * node->d1;

/* rows between var1 and var2 */
for (i = var1_id + node->d1; i < var2_id; i++)
{
node->wsum_hess->p[i] = nnz;
}

/* var2 rows of Hessian */
for (i = 0; i < node->d1; i++)
{
node->wsum_hess->p[var2_id + i] = nnz + 2 * i;
}
memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int));

/* remaining rows */
for (i = var2_id + node->d1; i <= node->n_vars; i++)
{
node->wsum_hess->p[i] = 4 * node->d1;
}
}

static void eval_wsum_hess_vector_args(expr *node, const double *w)
{
double *x = node->left->value;
double *y = node->right->value;
double *hess = node->wsum_hess->x;

if (node->left->var_id < node->right->var_id)
{
for (int i = 0; i < node->d1; i++)
{
hess[2 * i] = w[i] / x[i];
hess[2 * i + 1] = -w[i] / y[i];
}

hess += 2 * node->d1;

for (int i = 0; i < node->d1; i++)
{
hess[2 * i] = -w[i] / y[i];
hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]);
}
}
else
{
for (int i = 0; i < node->d1; i++)
{
hess[2 * i] = w[i] * x[i] / (y[i] * y[i]);
hess[2 * i + 1] = -w[i] / y[i];
}

hess += 2 * node->d1;

for (int i = 0; i < node->d1; i++)
{
hess[2 * i] = -w[i] / y[i];
hess[2 * i + 1] = w[i] / x[i];
}
}
}

expr *new_rel_entr_vector_args(expr *left, expr *right)
{
expr *node = new_expr(left->d1, 1, left->n_vars);
Expand All @@ -86,6 +179,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
node->forward = forward_vector_args;
node->jacobian_init = jacobian_init_vectors_args;
node->eval_jacobian = eval_jacobian_vector_args;
node->wsum_hess_init = wsum_hess_init_vector_args;
node->eval_wsum_hess = eval_wsum_hess_vector_args;
// node->is_affine = is_affine_elementwise;
// node->local_jacobian = local_jacobian;
return node;
Expand Down
11 changes: 7 additions & 4 deletions src/utils/CSR_Matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C)

void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
{
/* A and B must be different from C */
assert(A != C && B != C);

C->nnz = 0;

for (int row = 0; row < A->m; row++)
Expand Down Expand Up @@ -138,17 +141,17 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
if (a_ptr < a_end)
{
int a_remaining = a_end - a_ptr;
memmove(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
memmove(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
memcpy(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
memcpy(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
C->nnz += a_remaining;
}

/* Copy remaining elements from B */
if (b_ptr < b_end)
{
int b_remaining = b_end - b_ptr;
memmove(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
memmove(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
memcpy(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
memcpy(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
C->nnz += b_remaining;
}
}
Expand Down
22 changes: 14 additions & 8 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
#include "jacobian_tests/test_sum.h"
#include "utils/test_csc_matrix.h"
#include "utils/test_csr_matrix.h"
#include "wsum_hess/test_entr.h"
#include "wsum_hess/test_exp.h"
#include "wsum_hess/test_hyperbolic.h"
#include "wsum_hess/test_log.h"
#include "wsum_hess/test_logistic.h"
#include "wsum_hess/test_power.h"
#include "wsum_hess/elementwise/test_entr.h"
#include "wsum_hess/elementwise/test_exp.h"
#include "wsum_hess/elementwise/test_hyperbolic.h"
#include "wsum_hess/elementwise/test_log.h"
#include "wsum_hess/elementwise/test_logistic.h"
#include "wsum_hess/elementwise/test_power.h"
#include "wsum_hess/elementwise/test_trig.h"
#include "wsum_hess/elementwise/test_xexp.h"
#include "wsum_hess/test_hstack.h"
#include "wsum_hess/test_rel_entr.h"
#include "wsum_hess/test_sum.h"
#include "wsum_hess/test_trig.h"
#include "wsum_hess/test_xexp.h"

int main(void)
{
Expand Down Expand Up @@ -95,6 +97,10 @@ int main(void)
mu_run_test(test_wsum_hess_sum_log_linear, tests_run);
mu_run_test(test_wsum_hess_sum_log_axis0, tests_run);
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
mu_run_test(test_wsum_hess_hstack, tests_run);
mu_run_test(test_wsum_hess_hstack_matrix, tests_run);

printf("\n--- Utility Tests ---\n");
mu_run_test(test_diag_csr_mult, tests_run);
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading