Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/other.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@

expr *new_quad_form(expr *child, CSR_Matrix *Q);

/* product of all entries, without axis argument */
expr *new_prod(expr *child);

#endif /* OTHER_H */
9 changes: 9 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ typedef struct sum_expr
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
} sum_expr;

/* Product of all entries */
typedef struct prod_expr
{
expr base;
int num_of_zeros;
int zero_index; /* index of zero element when num_of_zeros == 1 */
double prod_nonzero; /* product of non-zero elements */
} prod_expr;

/* Horizontal stack (concatenate) */
typedef struct hstack_expr
{
Expand Down
276 changes: 276 additions & 0 deletions src/other/prod.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
#include "other.h"
#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

#define IS_ZERO(x) (fabs((x)) < 1e-8)

static inline void wsum_hess_no_zeros(expr *node, const double *w);
static inline void wsum_hess_one_zero(expr *node, const double *w);
static inline void wsum_hess_two_zeros(expr *node, const double *w);
static inline void wsum_hess_many_zeros(expr *node, const double *w);

static void forward(expr *node, const double *u)
{
expr *x = node->left;

/* forward pass of child */
x->forward(x, u);

/* local forward pass and count zeros */
double prod_nonzero = 1.0;
int zeros = 0;
int zero_idx = -1;
for (int i = 0; i < x->size; i++)
{
if (IS_ZERO(x->value[i]))
{
zeros++;
zero_idx = i;
}
else
{
prod_nonzero *= x->value[i];
}
}

node->value[0] = (zeros > 0) ? 0.0 : prod_nonzero;
prod_expr *pnode = (prod_expr *) node;
pnode->num_of_zeros = zeros;
pnode->zero_index = zero_idx;
pnode->prod_nonzero = prod_nonzero;
}

static void jacobian_init(expr *node)
{
expr *x = node->left;

/* initialize child's jacobian */
x->jacobian_init(x);

/* if x is a variable */
if (x->var_id != NOT_A_VARIABLE)
{
node->jacobian = new_csr_matrix(1, node->n_vars, x->size);
node->jacobian->p[0] = 0;
node->jacobian->p[1] = x->size;
for (int j = 0; j < x->size; j++)
{
node->jacobian->i[j] = x->var_id + j;
}
}
else
{
assert(false && "not implemented");
}
}

static void eval_jacobian(expr *node)
{
expr *x = node->left;
prod_expr *pnode = (prod_expr *) node;
int num_of_zeros = pnode->num_of_zeros;

/* if x is a variable */
if (x->var_id != NOT_A_VARIABLE)
{
if (num_of_zeros == 0)
{
for (int j = 0; j < x->size; j++)
{
node->jacobian->x[j] = node->value[0] / x->value[j];
}
}
else if (num_of_zeros == 1)
{
memset(node->jacobian->x, 0, sizeof(double) * x->size);
node->jacobian->x[pnode->zero_index] = pnode->prod_nonzero;
}
else
{
memset(node->jacobian->x, 0, sizeof(double) * x->size);
}
}
else
{
assert(false && "not implemented");
}
}

static void wsum_hess_init(expr *node)
{
expr *x = node->left;

/* if x is a variable */
if (x->var_id != NOT_A_VARIABLE)
{
/* allocate n_vars x n_vars CSR matrix with dense block */
int block_size = x->size;
int nnz = block_size * block_size;
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz);

/* fill row pointers for the dense block */
for (int i = 0; i < block_size; i++)
{
node->wsum_hess->p[x->var_id + i] = i * block_size;
}

/* fill row pointers for rows after the block */
for (int i = x->var_id + block_size; i <= node->n_vars; i++)
{
node->wsum_hess->p[i] = nnz;
}

/* fill column indices for the dense block */
for (int i = 0; i < block_size; i++)
{
for (int j = 0; j < block_size; j++)
{
node->wsum_hess->i[i * block_size + j] = x->var_id + j;
}
}
}
else
{
assert(false && "not implemented");
}
}

static void eval_wsum_hess(expr *node, const double *w)
{
expr *x = node->left;
int num_of_zeros = ((prod_expr *) node)->num_of_zeros;

/* if x is a variable */
if (x->var_id != NOT_A_VARIABLE)
{
if (num_of_zeros == 0)
{
wsum_hess_no_zeros(node, w);
}
else if (num_of_zeros == 1)
{
wsum_hess_one_zero(node, w);
}
else if (num_of_zeros == 2)
{
wsum_hess_two_zeros(node, w);
}
else
{
wsum_hess_many_zeros(node, w);
}
}
else
{
assert(false && "not implemented");
}
}

static bool is_affine(const expr *node)
{
(void) node;
return false;
}

static void free_type_data(expr *node)
{
(void) node;
}

expr *new_prod(expr *child)
{
/* Output is scalar: 1 x 1 */
prod_expr *pnode = (prod_expr *) calloc(1, sizeof(prod_expr));
expr *node = &pnode->base;
init_expr(node, 1, 1, child->n_vars, forward, jacobian_init, eval_jacobian,
is_affine, free_type_data);
node->wsum_hess_init = wsum_hess_init;
node->eval_wsum_hess = eval_wsum_hess;
node->left = child;
expr_retain(child);
return node;
}

// ---------------------------------------------------------------------------------------
// Helper functions for Hessian evaluation
// ---------------------------------------------------------------------------------------
static inline void wsum_hess_no_zeros(expr *node, const double *w)
{
double *x = node->left->value;
int n = node->left->size;
double wf = w[0] * node->value[0];

for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
if (i == j)
{
node->wsum_hess->x[i * n + j] = 0.0;
}
else
{
node->wsum_hess->x[i * n + j] = wf / (x[i] * x[j]);
}
}
}
}

static inline void wsum_hess_one_zero(expr *node, const double *w)
{
expr *x = node->left;
double *H = node->wsum_hess->x;
memset(H, 0, sizeof(double) * (x->size * x->size));
int p = ((prod_expr *) node)->zero_index;
double prod_nonzero = ((prod_expr *) node)->prod_nonzero;
double w_prod = w[0] * prod_nonzero;

/* fill row p and column p */
for (int j = 0; j < x->size; j++)
{
if (j == p) continue;

double hess_val = w_prod / x->value[j];
H[p * x->size + j] = hess_val;
H[j * x->size + p] = hess_val;
}
}

static inline void wsum_hess_two_zeros(expr *node, const double *w)
{
expr *x = node->left;
int n = x->size;
memset(node->wsum_hess->x, 0, sizeof(double) * (n * n));

/* find indices p and q where x[p] = x[q] = 0 */
int p = -1, q = -1;
for (int i = 0; i < n; i++)
{
if (IS_ZERO(x->value[i]))
{
if (p == -1)
{
p = i;
}
else
{
q = i;
break;
}
}
}
assert(p != -1 && q != -1);

double hess_val = w[0] * ((prod_expr *) node)->prod_nonzero;
node->wsum_hess->x[p * n + q] = hess_val;
node->wsum_hess->x[q * n + p] = hess_val;
}

static inline void wsum_hess_many_zeros(expr *node, const double *w)
{
expr *x = node->left;
memset(node->wsum_hess->x, 0, sizeof(double) * (x->size * x->size));
(void) w;
}
9 changes: 9 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "jacobian_tests/test_elementwise_mult.h"
#include "jacobian_tests/test_hstack.h"
#include "jacobian_tests/test_log.h"
#include "jacobian_tests/test_prod.h"
#include "jacobian_tests/test_quad_form.h"
#include "jacobian_tests/test_quad_over_lin.h"
#include "jacobian_tests/test_rel_entr.h"
Expand All @@ -31,6 +32,7 @@
#include "wsum_hess/elementwise/test_xexp.h"
#include "wsum_hess/test_hstack.h"
#include "wsum_hess/test_multiply.h"
#include "wsum_hess/test_prod.h"
#include "wsum_hess/test_quad_form.h"
#include "wsum_hess/test_quad_over_lin.h"
#include "wsum_hess/test_rel_entr.h"
Expand Down Expand Up @@ -75,6 +77,9 @@ int main(void)
mu_run_test(test_quad_form, tests_run);
/* commented out - see test_quad_form.h */
// mu_run_test(test_quad_form2, tests_run);
mu_run_test(test_jacobian_prod_no_zero, tests_run);
mu_run_test(test_jacobian_prod_one_zero, tests_run);
mu_run_test(test_jacobian_prod_two_zeros, tests_run);
mu_run_test(test_jacobian_sum_log, tests_run);
mu_run_test(test_jacobian_sum_mult, tests_run);
mu_run_test(test_jacobian_sum_log_axis_0, tests_run);
Expand Down Expand Up @@ -103,6 +108,10 @@ int main(void)
mu_run_test(test_wsum_hess_sum_log_linear, tests_run);
mu_run_test(test_wsum_hess_sum_log_axis0, tests_run);
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
mu_run_test(test_wsum_hess_prod_no_zero, tests_run);
mu_run_test(test_wsum_hess_prod_one_zero, tests_run);
mu_run_test(test_wsum_hess_prod_two_zeros, tests_run);
mu_run_test(test_wsum_hess_prod_many_zeros, tests_run);
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
mu_run_test(test_wsum_hess_hstack, tests_run);
Expand Down
Loading