diff --git a/README.md b/README.md index d6e10bc..d1d7b1f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ -1. power should be double -2. can we reuse calculations, like in hessian of logistic 3. more tests for chain rule elementwise univariate hessian 4. in the refactor, add consts 5. multiply with one constant vector/scalar argument -6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen. \ No newline at end of file +6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen. +7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code! This requires new infrastructure, I think. \ No newline at end of file diff --git a/include/affine.h b/include/affine.h index c2bcd74..e1b66ac 100644 --- a/include/affine.h +++ b/include/affine.h @@ -2,11 +2,13 @@ #define AFFINE_H #include "expr.h" +#include "subexpr.h" #include "utils/CSR_Matrix.h" expr *new_linear(expr *u, const CSR_Matrix *A); expr *new_add(expr *left, expr *right); + expr *new_sum(expr *child, int axis); expr *new_hstack(expr **args, int n_args, int n_vars); diff --git a/include/elementwise_univariate.h b/include/elementwise_univariate.h index 52cd4a1..f0ed791 100644 --- a/include/elementwise_univariate.h +++ b/include/elementwise_univariate.h @@ -3,6 +3,10 @@ #include "expr.h" +/* Helper function to initialize an elementwise expr (can be used with derived types) + */ +void init_elementwise(expr *node, expr *child); + expr *new_exp(expr *child); expr *new_log(expr *child); expr *new_entr(expr *child); @@ -14,7 +18,7 @@ expr *new_tanh(expr *child); expr *new_asinh(expr *child); expr *new_atanh(expr *child); expr *new_logistic(expr *child); -expr *new_power(expr *child, int p); +expr *new_power(expr *child, double p); expr *new_xexp(expr *child); /* the jacobian and wsum_hess for elementwise univariate atoms are always @@ -22,11 +26,11 @@ expr *new_xexp(expr *child); void jacobian_init_elementwise(expr *node); void eval_jacobian_elementwise(expr *node); void wsum_hess_init_elementwise(expr *node); -void eval_wsum_hess_elementwise(expr *node, double *w); +void eval_wsum_hess_elementwise(expr *node, const double *w); expr *new_elementwise(expr *child); /* no elementwise atoms are affine according to our convention, so we can have a common implementation */ -bool is_affine_elementwise(expr *node); +bool is_affine_elementwise(const expr *node); #endif /* ELEMENTWISE_UNIVARIATE_H */ diff --git a/include/expr.h b/include/expr.h index 83ee83e..f6381a3 100644 --- a/include/expr.h +++ b/include/expr.h @@ -7,70 +7,58 @@ #include #define JAC_IDXS_NOT_SET -1 - -/* Forward declarations */ -struct expr; -struct int_double_pair; +#define NOT_A_VARIABLE -1 /* Function pointer types */ +struct expr; typedef void (*forward_fn)(struct expr *node, const double *u); typedef void (*jacobian_init_fn)(struct expr *node); typedef void (*wsum_hess_init_fn)(struct expr *node); typedef void (*eval_jacobian_fn)(struct expr *node); -typedef void (*wsum_hess_fn)(struct expr *node, double *w); +typedef void (*wsum_hess_fn)(struct expr *node, const double *w); typedef void (*local_jacobian_fn)(struct expr *node, double *out); -typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w); -typedef bool (*is_affine_fn)(struct expr *node); - -/* TODO: implement proper polymorphism */ +typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w); +typedef bool (*is_affine_fn)(const struct expr *node); +typedef void (*free_type_data_fn)(struct expr *node); -/* Expression node structure */ +/* Base expression node structure - contains only common fields */ typedef struct expr { // ------------------------------------------------------------------------ // general quantities // ------------------------------------------------------------------------ - int d1, d2, size; - int n_vars; - int var_id; - int refcount; + int d1, d2, size, n_vars, refcount, var_id; struct expr *left; struct expr *right; - struct expr **args; /* hstack can have multiple arguments */ - int n_args; double *dwork; int *iwork; - struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ - int p; /* power of power expression */ - int axis; /* axis for sum or similar operations */ - CSR_Matrix *Q; /* Q for quad_form */ // ------------------------------------------------------------------------ - // forward pass related quantities + // oracle related quantities // ------------------------------------------------------------------------ double *value; - forward_fn forward; - - // ------------------------------------------------------------------------ - // jacobian related quantities - // ------------------------------------------------------------------------ CSR_Matrix *jacobian; CSR_Matrix *wsum_hess; - CSR_Matrix *CSR_work; + forward_fn forward; jacobian_init_fn jacobian_init; wsum_hess_init_fn wsum_hess_init; eval_jacobian_fn eval_jacobian; wsum_hess_fn eval_wsum_hess; - local_jacobian_fn local_jacobian; - local_wsum_hess_fn local_wsum_hess; - is_affine_fn is_affine; - // for every linear operator we store A in CSR and CSC - CSC_Matrix *A_csc; - CSR_Matrix *A_csr; + // ------------------------------------------------------------------------ + // other things + // ------------------------------------------------------------------------ + is_affine_fn is_affine; + local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/ + local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/ + free_type_data_fn free_type_data; /* Cleanup for type-specific fields */ } expr; +void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, + jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian, + is_affine_fn is_affine, free_type_data_fn free_type_data); + expr *new_expr(int d1, int d2, int n_vars); void free_expr(expr *node); diff --git a/include/other.h b/include/other.h index 3dbdf72..95fb035 100644 --- a/include/other.h +++ b/include/other.h @@ -2,6 +2,7 @@ #define OTHER_H #include "expr.h" +#include "subexpr.h" #include "utils/CSR_Matrix.h" expr *new_quad_form(expr *child, CSR_Matrix *Q); diff --git a/include/subexpr.h b/include/subexpr.h new file mode 100644 index 0000000..4f56ea4 --- /dev/null +++ b/include/subexpr.h @@ -0,0 +1,51 @@ +#ifndef SUBEXPR_H +#define SUBEXPR_H + +#include "expr.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" + +/* Forward declaration */ +struct int_double_pair; + +/* Type-specific expression structures that "inherit" from expr */ + +/* Linear operator: y = A * x */ +typedef struct linear_op_expr +{ + expr base; + CSC_Matrix *A_csc; + CSR_Matrix *A_csr; +} linear_op_expr; + +/* Power: y = x^p */ +typedef struct power_expr +{ + expr base; + double p; +} power_expr; + +/* Quadratic form: y = x'*Q*x */ +typedef struct quad_form_expr +{ + expr base; + CSR_Matrix *Q; +} quad_form_expr; + +/* Sum reduction along an axis */ +typedef struct sum_expr +{ + expr base; + int axis; + struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ +} sum_expr; + +/* Horizontal stack (concatenate) */ +typedef struct hstack_expr +{ + expr base; + expr **args; + int n_args; +} hstack_expr; + +#endif /* SUBEXPR_H */ diff --git a/include/utils/COO_Matrix.h b/include/utils/COO_Matrix.h deleted file mode 100644 index 6e922e3..0000000 --- a/include/utils/COO_Matrix.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef MATRIX_H -#define MATRIX_H - -/* Coordinate (COO) format sparse matrix */ -typedef struct COO_Matrix -{ - int nnz; - int *rows; - int *cols; - double *vals; -} COO_Matrix; - -COO_Matrix *new_coo_matrix(int nnz); -void free_coo_matrix(COO_Matrix *matrix); - -#endif /* MATRIX_H */ diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index 6636cea..b3eb358 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -39,4 +39,9 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A); */ void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); +/* C = z^T A where A is in CSC format and C is assumed to have one row. + * C must have column indices pre-computed. Fills in values of C only. + */ +void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C); + #endif /* CSC_MATRIX_H */ diff --git a/src/affine/add.c b/src/affine/add.c index 8133162..9f6b2d1 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -45,7 +45,7 @@ static void wsum_hess_init(expr *node) node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max); } -static void eval_wsum_hess(expr *node, double *w) +static void eval_wsum_hess(expr *node, const double *w) { /* evaluate children's wsum_hess */ node->left->eval_wsum_hess(node->left, w); @@ -55,20 +55,14 @@ static void eval_wsum_hess(expr *node, double *w) sum_csr_matrices(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left) && node->right->is_affine(node->right); } expr *new_add(expr *left, expr *right) { - if (!left || !right) return NULL; - if (left->d1 != right->d1) return NULL; - if (left->d2 != right->d2) return NULL; - expr *node = new_expr(left->d1, left->d2, left->n_vars); - if (!node) return NULL; - node->left = left; node->right = right; expr_retain(left); diff --git a/src/affine/constant.c b/src/affine/constant.c index f328137..d427ec9 100644 --- a/src/affine/constant.c +++ b/src/affine/constant.c @@ -8,20 +8,16 @@ static void forward(expr *node, const double *u) (void) u; } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { (void) node; - return true; /* constant is affine */ + return true; } expr *new_constant(int d1, int d2, int n_vars, const double *values) { expr *node = new_expr(d1, d2, n_vars); - if (!node) return NULL; - - /* Copy constant values */ memcpy(node->value, values, d1 * d2 * sizeof(double)); - node->forward = forward; node->is_affine = is_affine; diff --git a/src/affine/hstack.c b/src/affine/hstack.c index e2b8dd9..a00863c 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -1,21 +1,23 @@ #include "affine.h" #include +#include #include static void forward(expr *node, const double *u) { + hstack_expr *hnode = (hstack_expr *) node; /* children's forward passes */ - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - node->args[i]->forward(node->args[i], u); + hnode->args[i]->forward(hnode->args[i], u); } /* concatenate values horizontally */ int offset = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - expr *child = node->args[i]; + expr *child = hnode->args[i]; memcpy(node->value + offset, child->value, child->size * sizeof(double)); offset += child->size; } @@ -23,32 +25,29 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { + hstack_expr *hnode = (hstack_expr *) node; + /* initialize children's jacobians */ int nnz = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - node->args[i]->jacobian_init(node->args[i]); - nnz += node->args[i]->jacobian->nnz; + hnode->args[i]->jacobian_init(hnode->args[i]); + nnz += hnode->args[i]->jacobian->nnz; } node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); -} -static void eval_jacobian(expr *node) -{ - /* evaluate children's jacobians */ + /* precompute sparsity pattern of this jacobian's node */ int row_offset = 0; CSR_Matrix *A = node->jacobian; A->nnz = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - expr *child = node->args[i]; - child->eval_jacobian(child); + expr *child = hnode->args[i]; CSR_Matrix *B = child->jacobian; - /* copy columns and values */ - memcpy(A->x + A->nnz, B->x, B->nnz * sizeof(double)); + /* copy columns */ memcpy(A->i + A->nnz, B->i, B->nnz * sizeof(int)); /* set row pointers */ @@ -63,11 +62,31 @@ static void eval_jacobian(expr *node) A->p[node->size] = A->nnz; } -static bool is_affine(expr *node) +static void eval_jacobian(expr *node) +{ + hstack_expr *hnode = (hstack_expr *) node; + CSR_Matrix *A = node->jacobian; + A->nnz = 0; + + for (int i = 0; i < hnode->n_args; i++) + { + expr *child = hnode->args[i]; + child->eval_jacobian(child); + + /* copy values */ + memcpy(A->x + A->nnz, child->jacobian->x, + child->jacobian->nnz * sizeof(double)); + A->nnz += child->jacobian->nnz; + } +} + +static bool is_affine(const expr *node) { - for (int i = 0; i < node->n_args; i++) + const hstack_expr *hnode = (const hstack_expr *) node; + + for (int i = 0; i < hnode->n_args; i++) { - if (!node->args[i]->is_affine(node->args[i])) + if (!hnode->args[i]->is_affine(hnode->args[i])) { return false; } @@ -75,6 +94,15 @@ static bool is_affine(expr *node) return true; } +static void free_type_data(expr *node) +{ + hstack_expr *hnode = (hstack_expr *) node; + for (int i = 0; i < hnode->n_args; i++) + { + free_expr(hnode->args[i]); + } +} + expr *new_hstack(expr **args, int n_args, int n_vars) { /* compute second dimension */ @@ -84,20 +112,20 @@ expr *new_hstack(expr **args, int n_args, int n_vars) d2 += args[i]->d2; } - expr *node = new_expr(args[0]->d1, d2, n_vars); - if (!node) return NULL; - node->args = args; - node->n_args = n_args; + /* Allocate the type-specific struct */ + hstack_expr *hnode = (hstack_expr *) calloc(1, sizeof(hstack_expr)); + expr *node = &hnode->base; + init_expr(node, args[0]->d1, d2, n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); + + /* Set type-specific fields */ + hnode->args = args; + hnode->n_args = n_args; for (int i = 0; i < n_args; i++) { expr_retain(args[i]); } - node->forward = forward; - node->is_affine = is_affine; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; - return node; } diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 2bc988a..4e9caa8 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -4,6 +4,7 @@ static void forward(expr *node, const double *u) { expr *x = node->left; + /* child's forward pass */ node->left->forward(node->left, u); @@ -11,29 +12,40 @@ static void forward(expr *node, const double *u) csr_matvec(node->jacobian, x->value, node->value, x->var_id); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left); } -expr *new_linear(expr *u, const CSR_Matrix *A) +static void free_type_data(expr *node) { - expr *node = new_expr(A->m, 1, u->n_vars); - if (!node) return NULL; + linear_op_expr *lin_node = (linear_op_expr *) node; + /* memory pointing to by A_csr will already be freed when the + jacobian is freed, so free_csr_matrix(lin_node->A_csr) should + be commented out */ + // free_csr_matrix(lin_node->A_csr); + free_csc_matrix(lin_node->A_csc); + lin_node->A_csr = NULL; + lin_node->A_csc = NULL; +} +expr *new_linear(expr *u, const CSR_Matrix *A) +{ + /* Allocate the type-specific struct */ + linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr)); + expr *node = &lin_node->base; + init_expr(node, A->m, 1, u->n_vars, forward, NULL, NULL, is_affine, + free_type_data); node->left = u; expr_retain(u); - node->forward = forward; - node->is_affine = is_affine; - /* allocate jacobian and copy A into it */ - // TODO: this should eventually be removed - node->jacobian = new_csr_matrix(A->m, A->n, A->nnz); - copy_csr_matrix(A, node->jacobian); + /* Initialize type-specific fields */ + lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); + copy_csr_matrix(A, lin_node->A_csr); + lin_node->A_csc = csr_to_csc(A); - node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); - copy_csr_matrix(A, node->A_csr); - node->A_csc = csr_to_csc(A); + /* what if we have A @ phi(x). Then I don't think this is correct. */ + node->jacobian = lin_node->A_csr; return node; } diff --git a/src/affine/sum.c b/src/affine/sum.c index a4c11a7..ee9988b 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -10,11 +10,13 @@ static void forward(expr *node, const double *u) int i, j, end; double sum; expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; /* child's forward pass */ x->forward(x, u); - if (node->axis == -1) + if (axis == -1) { sum = 0.0; end = x->d1 * x->d2; @@ -26,7 +28,7 @@ static void forward(expr *node, const double *u) } node->value[0] = sum; } - else if (node->axis == 0) + else if (axis == 0) { /* sum rows together */ for (j = 0; j < x->d2; j++) @@ -40,7 +42,7 @@ static void forward(expr *node, const double *u) node->value[j] = sum; } } - else if (node->axis == 1) + else if (axis == 1) { memset(node->value, 0, node->d1 * sizeof(double)); @@ -59,36 +61,40 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + /* initialize child's jacobian */ x->jacobian_init(x); /* we never have to store more than the child's nnz */ node->jacobian = new_csr_matrix(node->d1, node->n_vars, x->jacobian->nnz); - node->int_double_pairs = new_int_double_pair_array(x->jacobian->nnz); + snode->int_double_pairs = new_int_double_pair_array(x->jacobian->nnz); } static void eval_jacobian(expr *node) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; /* evaluate child's jacobian */ x->eval_jacobian(x); /* sum rows or columns of child's jacobian */ - if (node->axis == -1) + if (axis == -1) { - sum_all_rows_csr(x->jacobian, node->jacobian, node->int_double_pairs); + sum_all_rows_csr(x->jacobian, node->jacobian, snode->int_double_pairs); } - else if (node->axis == 0) + else if (axis == 0) { - sum_block_of_rows_csr(x->jacobian, node->jacobian, node->int_double_pairs, + sum_block_of_rows_csr(x->jacobian, node->jacobian, snode->int_double_pairs, x->d1); } - else if (node->axis == 1) + else if (axis == 1) { sum_evenly_spaced_rows_csr(x->jacobian, node->jacobian, - node->int_double_pairs, node->d1); + snode->int_double_pairs, node->d1); } } @@ -103,19 +109,21 @@ static void wsum_hess_init(expr *node) node->dwork = malloc(x->size * sizeof(double)); } -static void eval_wsum_hess(expr *node, double *w) +static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; - if (node->axis == -1) + if (axis == -1) { scaled_ones(node->dwork, x->size, *w); } - else if (node->axis == 0) + else if (axis == 0) { repeat(node->dwork, w, x->d2, x->d1); } - else if (node->axis == 1) + else if (axis == 1) { tile(node->dwork, w, x->d1, x->d2); } @@ -126,11 +134,17 @@ static void eval_wsum_hess(expr *node, double *w) copy_csr_matrix(x->wsum_hess, node->wsum_hess); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left); } +static void free_type_data(expr *node) +{ + sum_expr *snode = (sum_expr *) node; + free_int_double_pair_array(snode->int_double_pairs); +} + expr *new_sum(expr *child, int axis) { int d1 = 0; @@ -151,18 +165,21 @@ expr *new_sum(expr *child, int axis) break; } - expr *node = new_expr(d1, 1, child->n_vars); - if (!node) return NULL; - + /* Allocate the type-specific struct */ + sum_expr *snode = (sum_expr *) calloc(1, sizeof(sum_expr)); + expr *node = &snode->base; + init_expr(node, d1, 1, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); node->left = child; expr_retain(child); - node->forward = forward; - node->is_affine = is_affine; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; + + /* hessian function pointers */ node->wsum_hess_init = wsum_hess_init; node->eval_wsum_hess = eval_wsum_hess; - node->axis = axis; + + /* Set type-specific fields */ + snode->axis = axis; + snode->int_double_pairs = NULL; return node; } diff --git a/src/affine/variable.c b/src/affine/variable.c index 122db20..46e2446 100644 --- a/src/affine/variable.c +++ b/src/affine/variable.c @@ -20,7 +20,7 @@ static void jacobian_init(expr *node) node->jacobian->p[size] = size; } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { (void) node; return true; @@ -29,13 +29,10 @@ static bool is_affine(expr *node) expr *new_variable(int d1, int d2, int var_id, int n_vars) { expr *node = new_expr(d1, d2, n_vars); - if (!node) return NULL; - node->forward = forward; node->var_id = var_id; node->is_affine = is_affine; node->jacobian_init = jacobian_init; - // node->jacobian = NULL; return node; } diff --git a/src/bivariate/multiply.c b/src/bivariate/multiply.c index 0bed4e9..eae64ed 100644 --- a/src/bivariate/multiply.c +++ b/src/bivariate/multiply.c @@ -32,7 +32,7 @@ static void jacobian_init(expr *node) node->left->jacobian_init(node->left); } - if (node->right->var_id != -1) + if (node->right->var_id != NOT_A_VARIABLE) { node->right->jacobian_init(node->right); } diff --git a/src/bivariate/quad_over_lin.c b/src/bivariate/quad_over_lin.c index ac31dfb..20be94f 100644 --- a/src/bivariate/quad_over_lin.c +++ b/src/bivariate/quad_over_lin.c @@ -1,4 +1,6 @@ #include "bivariate.h" +#include "subexpr.h" +#include "utils/CSC_Matrix.h" #include #include #include @@ -34,7 +36,7 @@ static void jacobian_init(expr *node) expr *y = node->right; /* if left node is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(1, node->n_vars, x->d1 + 1); node->jacobian->p[0] = 0; @@ -58,14 +60,15 @@ static void jacobian_init(expr *node) } } } - else /* left node is not a variable */ + else /* left node is not a variable (guaranteed to be a linear operator) */ { + linear_op_expr *lin_x = (linear_op_expr *) x; node->dwork = (double *) malloc(x->d1 * sizeof(double)); /* compute required allocation and allocate jacobian */ bool *col_nz = (bool *) calloc( node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/ - int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz); + int nonzero_cols = count_nonzero_cols(lin_x->base.jacobian, col_nz); node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1); /* precompute column indices */ @@ -88,11 +91,8 @@ static void jacobian_init(expr *node) node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; - /* store A^T of child's A to simplify chain rule computation */ - node->iwork = (int *) malloc(x->jacobian->n * sizeof(int)); - node->CSR_work = transpose(x->jacobian, node->iwork); - /* find position where y should be inserted */ + node->iwork = (int *) malloc(sizeof(int)); for (int j = 0; j < node->jacobian->nnz; j++) { if (node->jacobian->i[j] == y->var_id) @@ -110,7 +110,7 @@ static void eval_jacobian(expr *node) expr *y = node->right; /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { /* if x has lower idx than y*/ if (x->var_id < y->var_id) @@ -132,14 +132,16 @@ static void eval_jacobian(expr *node) } else /* x is not a variable */ { + CSC_Matrix *A_csc = ((linear_op_expr *) x)->A_csc; + /* local jacobian */ for (int j = 0; j < x->d1; j++) { node->dwork[j] = (2.0 * x->value[j]) / y->value[0]; } - /* chain rule (no derivative wrt y) */ - csr_matvec_fill_values(node->CSR_work, node->dwork, node->jacobian); + /* chain rule (no derivative wrt y) using CSC format */ + csc_matvec_fill_values(A_csc, node->dwork, node->jacobian); /* insert derivative wrt y at right place (for correctness this assumes that y does not appear in the denominator, but this will always be diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 8733226..986d26a 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -1,4 +1,6 @@ #include "elementwise_univariate.h" +#include "expr.h" +#include "subexpr.h" #include void jacobian_init_elementwise(expr *node) @@ -6,7 +8,7 @@ void jacobian_init_elementwise(expr *node) expr *child = node->left; /* if the variable is a child */ - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(node->size, node->n_vars, node->size); for (int j = 0; j < node->size; j++) @@ -16,7 +18,7 @@ void jacobian_init_elementwise(expr *node) } node->jacobian->p[node->size] = node->size; } - /* otherwise it should be a linear operator */ + /* otherwise it will be a linear operator */ else { node->jacobian = new_csr_matrix(child->jacobian->m, child->jacobian->n, @@ -29,14 +31,16 @@ void eval_jacobian_elementwise(expr *node) { expr *child = node->left; - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->local_jacobian(node, node->jacobian->x); } else { + /* Child will be a linear operator */ + linear_op_expr *lin_child = (linear_op_expr *) child; node->local_jacobian(node, node->dwork); - diag_csr_mult(node->dwork, child->A_csr, node->jacobian); + diag_csr_mult(node->dwork, lin_child->A_csr, node->jacobian); } } @@ -47,7 +51,7 @@ void wsum_hess_init_elementwise(expr *node) int i; /* if the variable is a child*/ - if (id != -1) + if (id != NOT_A_VARIABLE) { node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size); @@ -65,40 +69,59 @@ void wsum_hess_init_elementwise(expr *node) /* otherwise it will be a linear operator */ else { - node->wsum_hess = ATA_alloc(child->A_csc); + linear_op_expr *lin_child = (linear_op_expr *) child; + node->wsum_hess = ATA_alloc(lin_child->A_csc); } } -void eval_wsum_hess_elementwise(expr *node, double *w) +void eval_wsum_hess_elementwise(expr *node, const double *w) { expr *child = node->left; - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->local_wsum_hess(node, node->wsum_hess->x, w); } else { + /* Child will be a linear operator */ + linear_op_expr *lin_child = (linear_op_expr *) child; node->local_wsum_hess(node, node->dwork, w); - ATDA_values(child->A_csc, node->dwork, node->wsum_hess); + ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess); } } -bool is_affine_elementwise(expr *node) +bool is_affine_elementwise(const expr *node) { (void) node; return false; } -expr *new_elementwise(expr *child) +/* Helper function to initialize an already-allocated expr for elementwise operations + * This is called when a power_expr or other type-specific struct is allocated + * and we need to initialize the base expr fields + */ +void init_elementwise(expr *node, expr *child) { - expr *node = new_expr(child->d1, child->d2, child->n_vars); - node->left = child; - expr_retain(child); - node->is_affine = is_affine_elementwise; - node->jacobian_init = jacobian_init_elementwise; - node->eval_jacobian = eval_jacobian_elementwise; + /* Initialize base fields */ + init_expr(node, child->d1, child->d2, child->n_vars, NULL, + jacobian_init_elementwise, eval_jacobian_elementwise, + is_affine_elementwise, NULL); + + /* Set wsum_hess functions */ node->wsum_hess_init = wsum_hess_init_elementwise; node->eval_wsum_hess = eval_wsum_hess_elementwise; + + /* Set left child */ + node->left = child; + expr_retain(child); +} + +expr *new_elementwise(expr *child) +{ + expr *node = (expr *) calloc(1, sizeof(expr)); + if (!node) return NULL; + + init_elementwise(node, child); return node; } diff --git a/src/elementwise_univariate/entr.c b/src/elementwise_univariate/entr.c index c0d0240..282541a 100644 --- a/src/elementwise_univariate/entr.c +++ b/src/elementwise_univariate/entr.c @@ -24,7 +24,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/elementwise_univariate/exp.c b/src/elementwise_univariate/exp.c index b91f1c7..3b27c15 100644 --- a/src/elementwise_univariate/exp.c +++ b/src/elementwise_univariate/exp.c @@ -19,12 +19,11 @@ static void local_jacobian(expr *node, double *vals) memcpy(vals, node->value, node->size * sizeof(double)); } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; for (int j = 0; j < node->size; j++) { - out[j] = w[j] * exp(x[j]); + out[j] = w[j] * node->value[j]; } } diff --git a/src/elementwise_univariate/hyperbolic.c b/src/elementwise_univariate/hyperbolic.c index 65ae85f..c0bd9fa 100644 --- a/src/elementwise_univariate/hyperbolic.c +++ b/src/elementwise_univariate/hyperbolic.c @@ -21,12 +21,11 @@ static void sinh_local_jacobian(expr *node, double *vals) } } -static void sinh_local_wsum_hess(expr *node, double *out, double *w) +static void sinh_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; for (int j = 0; j < node->size; j++) { - out[j] = w[j] * sinh(x[j]); + out[j] = w[j] * node->value[j]; } } @@ -51,21 +50,21 @@ static void tanh_forward(expr *node, const double *u) static void tanh_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - double c = cosh(child->value[j]); + double c = cosh(x[j]); vals[j] = 1.0 / (c * c); } } -static void tanh_local_wsum_hess(expr *node, double *out, double *w) +static void tanh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) { double c = cosh(x[j]); - out[j] = w[j] * (-2.0 * tanh(x[j]) / (c * c)); + out[j] = w[j] * (-2.0 * node->value[j] / (c * c)); } } @@ -97,7 +96,7 @@ static void asinh_local_jacobian(expr *node, double *vals) } } -static void asinh_local_wsum_hess(expr *node, double *out, double *w) +static void asinh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) @@ -135,7 +134,7 @@ static void atanh_local_jacobian(expr *node, double *vals) } } -static void atanh_local_wsum_hess(expr *node, double *out, double *w) +static void atanh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) diff --git a/src/elementwise_univariate/log.c b/src/elementwise_univariate/log.c index 89337d2..637dde3 100644 --- a/src/elementwise_univariate/log.c +++ b/src/elementwise_univariate/log.c @@ -22,7 +22,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/elementwise_univariate/logistic.c b/src/elementwise_univariate/logistic.c index da57fac..ba7a64d 100644 --- a/src/elementwise_univariate/logistic.c +++ b/src/elementwise_univariate/logistic.c @@ -40,12 +40,11 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; double *sigmas; - if (node->left->var_id != -1) + if (node->left->var_id != NOT_A_VARIABLE) { sigmas = node->jacobian->x; } @@ -54,23 +53,8 @@ static void local_wsum_hess(expr *node, double *out, double *w) sigmas = node->dwork; } - // double sigma; - for (int j = 0; j < node->size; j++) { - /* - if (x[j] >= 0) - { - sigma = 1.0 / (1.0 + exp(-x[j])); - } - else - { - double exp_x = exp(x[j]); - sigma = exp_x / (1.0 + exp_x); - } - - out[j] = w[j] * sigma * (1.0 - sigma); - */ out[j] = sigmas[j] * (1.0 - sigmas[j]) * w[j]; } } diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index b00d358..c977749 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -1,5 +1,7 @@ #include "elementwise_univariate.h" +#include "subexpr.h" #include +#include static void forward(expr *node, const double *u) { @@ -7,28 +9,30 @@ static void forward(expr *node, const double *u) /* child's forward pass */ node->left->forward(node->left, u); - double *x = node->left->value; - /* local forward pass */ + double *x = node->left->value; + double p = ((power_expr *) node)->p; for (int i = 0; i < node->size; i++) { - node->value[i] = pow(x[i], node->p); + node->value[i] = pow(x[i], p); } } static void local_jacobian(expr *node, double *vals) { double *x = node->left->value; + double p = ((power_expr *) node)->p; + for (int j = 0; j < node->size; j++) { - vals[j] = node->p * pow(x[j], node->p - 1); + vals[j] = p * pow(x[j], p - 1); } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; - double p = (double) node->p; + double p = ((power_expr *) node)->p; for (int j = 0; j < node->size; j++) { @@ -36,12 +40,18 @@ static void local_wsum_hess(expr *node, double *out, double *w) } } -expr *new_power(expr *child, int p) +expr *new_power(expr *child, double p) { - expr *node = new_elementwise(child); - node->p = p; + /* Allocate the type-specific struct */ + power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr)); + expr *node = &pnode->base; + init_elementwise(node, child); node->forward = forward; node->local_jacobian = local_jacobian; node->local_wsum_hess = local_wsum_hess; + + /* Set type-specific field */ + pnode->p = p; + return node; } diff --git a/src/elementwise_univariate/trig.c b/src/elementwise_univariate/trig.c index 20902e3..900dd2b 100644 --- a/src/elementwise_univariate/trig.c +++ b/src/elementwise_univariate/trig.c @@ -13,20 +13,18 @@ static void sin_forward(expr *node, const double *u) static void sin_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = cos(child->value[j]); + vals[j] = cos(x[j]); } } -static void sin_local_wsum_hess(expr *node, double *out, double *w) +static void sin_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; - for (int j = 0; j < node->size; j++) { - out[j] = -w[j] * sin(x[j]); + out[j] = -w[j] * node->value[j]; } } @@ -51,20 +49,18 @@ static void cos_forward(expr *node, const double *u) static void cos_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = -sin(child->value[j]); + vals[j] = -sin(x[j]); } } -static void cos_local_wsum_hess(expr *node, double *out, double *w) +static void cos_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; - for (int j = 0; j < node->size; j++) { - out[j] = -w[j] * cos(x[j]); + out[j] = -w[j] * node->value[j]; } } @@ -97,14 +93,14 @@ static void tan_local_jacobian(expr *node, double *vals) } } -static void tan_local_wsum_hess(expr *node, double *out, double *w) +static void tan_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) { double c = cos(x[j]); - out[j] = 2.0 * w[j] * tan(x[j]) / (c * c); + out[j] = 2.0 * w[j] * node->value[j] / (c * c); } } diff --git a/src/elementwise_univariate/xexp.c b/src/elementwise_univariate/xexp.c index 9bf7a7f..cb3fdb6 100644 --- a/src/elementwise_univariate/xexp.c +++ b/src/elementwise_univariate/xexp.c @@ -16,14 +16,14 @@ static void forward(expr *node, const double *u) static void local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = (1.0 + child->value[j]) * exp(child->value[j]); + vals[j] = (1.0 + x[j]) * exp(x[j]); } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/expr.c b/src/expr.c index 520ffb2..0815825 100644 --- a/src/expr.c +++ b/src/expr.c @@ -3,35 +3,28 @@ #include #include -expr *new_expr(int d1, int d2, int n_vars) +void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, + jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian, + is_affine_fn is_affine, free_type_data_fn free_type_data) { - expr *node = (expr *) calloc(1, sizeof(expr)); - if (!node) return NULL; - node->d1 = d1; node->d2 = d2; node->size = d1 * d2; node->n_vars = n_vars; node->refcount = 1; node->value = (double *) calloc(d1 * d2, sizeof(double)); - if (!node->value) - { - free(node); - return NULL; - } - - // node->left = NULL; - // node->right = NULL; - // node->forward = NULL; - // node->jacobian = NULL; - // node->jacobian_init = NULL; - // node->eval_jacobian = NULL; - // node->is_affine = NULL; - // node->dwork = NULL; - // node->iwork = NULL; - // node->CSR_work = NULL; - node->var_id = -1; + node->var_id = NOT_A_VARIABLE; + node->forward = forward; + node->jacobian_init = jacobian_init; + node->eval_jacobian = eval_jacobian; + node->is_affine = is_affine; + node->free_type_data = free_type_data; +} +expr *new_expr(int d1, int d2, int n_vars) +{ + expr *node = (expr *) calloc(1, sizeof(expr)); + init_expr(node, d1, d2, n_vars, NULL, NULL, NULL, NULL, NULL); return node; } @@ -45,26 +38,26 @@ void free_expr(expr *node) /* recursively free children */ free_expr(node->left); free_expr(node->right); - - /* free arguments */ - if (node->args) - { - for (int i = 0; i < node->n_args; i++) - { - free_expr(node->args[i]); - } - } + node->left = NULL; + node->right = NULL; /* free value array and jacobian */ free(node->value); free_csr_matrix(node->jacobian); free_csr_matrix(node->wsum_hess); - free_csr_matrix(node->CSR_work); - free_csr_matrix(node->A_csr); - free_csc_matrix(node->A_csc); free(node->dwork); free(node->iwork); - free_int_double_pair_array(node->int_double_pairs); + node->value = NULL; + node->jacobian = NULL; + node->wsum_hess = NULL; + node->dwork = NULL; + node->iwork = NULL; + + /* free type-specific data */ + if (node->free_type_data) + { + node->free_type_data(node); + } /* free the node itself */ free(node); @@ -76,7 +69,7 @@ void expr_retain(expr *node) node->refcount++; } -bool is_affine(expr *node) +bool is_affine(const expr *node) { bool left_affine = true; bool right_affine = true; diff --git a/src/other/quad_form.c b/src/other/quad_form.c index d069ae7..9eb74a3 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -1,4 +1,6 @@ #include "other.h" +#include "subexpr.h" +#include "utils/CSC_Matrix.h" #include #include #include @@ -7,12 +9,12 @@ static void forward(expr *node, const double *u) { expr *x = node->left; - /* children's forward passes */ + /* child's forward pass */ x->forward(x, u); /* local forward pass */ - csr_matvec(node->Q, x->value, node->dwork, 0); - + CSR_Matrix *Q = ((quad_form_expr *) node)->Q; + csr_matvec(Q, x->value, node->dwork, 0); node->value[0] = 0.0; for (int i = 0; i < x->d1; i++) @@ -27,7 +29,7 @@ static void jacobian_init(expr *node) node->dwork = (double *) malloc(x->d1 * sizeof(double)); /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(1, node->n_vars, x->d1); node->jacobian->p[0] = 0; @@ -41,8 +43,7 @@ static void jacobian_init(expr *node) else /* x is not a variable */ { /* compute required allocation and allocate jacobian */ - bool *col_nz = (bool *) calloc( - node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/ + bool *col_nz = (bool *) calloc(node->n_vars, sizeof(bool)); int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz); node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1); @@ -61,20 +62,16 @@ static void jacobian_init(expr *node) node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; - - /* store A^T of child's A to simplify chain rule computation */ - node->iwork = (int *) malloc(x->jacobian->n * sizeof(int)); - node->CSR_work = transpose(x->jacobian, node->iwork); } } static void eval_jacobian(expr *node) { expr *x = node->left; - CSR_Matrix *Q = node->Q; + CSR_Matrix *Q = ((quad_form_expr *) node)->Q; /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { csr_matvec(Q, x->value, node->jacobian->x, 0); @@ -85,6 +82,8 @@ static void eval_jacobian(expr *node) } else /* x is not a variable */ { + linear_op_expr *lin_x = (linear_op_expr *) x; + /* local jacobian */ csr_matvec(Q, x->value, node->dwork, 0); @@ -93,19 +92,27 @@ static void eval_jacobian(expr *node) node->dwork[j] *= 2.0; } - /* chain rule */ - csr_matvec_fill_values(node->CSR_work, node->dwork, node->jacobian); + /* chain rule using CSC format */ + csc_matvec_fill_values(lin_x->A_csc, node->dwork, node->jacobian); } } expr *new_quad_form(expr *left, CSR_Matrix *Q) { - expr *node = new_expr(left->d1, 1, left->n_vars); + quad_form_expr *qnode = (quad_form_expr *) calloc(1, sizeof(quad_form_expr)); + expr *node = &qnode->base; + + /* Initialize base fields */ + assert(left->d2 == 1); + init_expr(node, left->d1, 1, left->n_vars, forward, jacobian_init, eval_jacobian, + NULL, NULL); + + /* Set left child */ node->left = left; expr_retain(left); - node->Q = Q; - node->forward = forward; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; + + /* Set type-specific field */ + qnode->Q = Q; + return node; } diff --git a/src/utils/COO_Matrix.c b/src/utils/COO_Matrix.c deleted file mode 100644 index 9162bb0..0000000 --- a/src/utils/COO_Matrix.c +++ /dev/null @@ -1,33 +0,0 @@ -#include "utils/COO_Matrix.h" -#include - -COO_Matrix *new_coo_matrix(int nnz) -{ - COO_Matrix *matrix = (COO_Matrix *) malloc(sizeof(COO_Matrix)); - if (!matrix) return NULL; - - matrix->nnz = nnz; - matrix->rows = (int *) calloc(nnz, sizeof(int)); - matrix->cols = (int *) calloc(nnz, sizeof(int)); - matrix->vals = (double *) calloc(nnz, sizeof(double)); - - if (!matrix->rows || !matrix->cols || !matrix->vals) - { - free(matrix->rows); - free(matrix->cols); - free(matrix->vals); - free(matrix); - return NULL; - } - - return matrix; -} - -void free_coo_matrix(COO_Matrix *matrix) -{ - if (!matrix) return; - free(matrix->rows); - free(matrix->cols); - free(matrix->vals); - free(matrix); -} diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index 9ef3b8d..fc94b27 100644 --- a/src/utils/CSC_Matrix.c +++ b/src/utils/CSC_Matrix.c @@ -196,3 +196,31 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A) free(count); return C; } +void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C) +{ + /* Compute C = z^T * A where A is in CSC format + * C is a single-row CSR matrix with column indices pre-computed + * This fills in the values of C only + */ + + for (int col = 0; col < A->n; col++) + { + double val = 0; + for (int j = A->p[col]; j < A->p[col + 1]; j++) + { + val += z[A->i[j]] * A->x[j]; + } + + if (A->p[col + 1] - A->p[col] == 0) continue; + + /* find position in C and fill value */ + for (int k = 0; k < C->nnz; k++) + { + if (C->i[k] == col) + { + C->x[k] = val; + break; + } + } + } +} diff --git a/tests/wsum_hess/test_power.h b/tests/wsum_hess/test_power.h index a8f82c5..f201e8c 100644 --- a/tests/wsum_hess/test_power.h +++ b/tests/wsum_hess/test_power.h @@ -15,7 +15,7 @@ const char *test_wsum_hess_power() double w[3] = {1.0, 2.0, 3.0}; expr *x = new_variable(3, 1, 0, 3); - expr *power_node = new_power(x, 3); + expr *power_node = new_power(x, 3.0); power_node->forward(power_node, u_vals); power_node->wsum_hess_init(power_node); power_node->eval_wsum_hess(power_node, w);