From 105d2eecf22ab5d63dec8a3007167d851a8cccc9 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 09:56:48 -0800 Subject: [PATCH 01/10] started refactoring polymorphism --- README.md | 3 +- include/affine.h | 11 ++++ include/elementwise_univariate.h | 4 ++ include/expr.h | 75 +++++++++++++++------- include/other.h | 3 + src/affine/hstack.c | 95 +++++++++++++++++++++------ src/affine/linear_op.c | 62 +++++++++++++++--- src/affine/sum.c | 99 ++++++++++++++++++++++------- src/elementwise_univariate/common.c | 48 +++++++++++--- src/elementwise_univariate/power.c | 26 ++++++-- src/expr.c | 28 ++------ src/other/quad_form.c | 62 +++++++++++++++--- 12 files changed, 394 insertions(+), 122 deletions(-) diff --git a/README.md b/README.md index d6e10bc..8e8b174 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,5 @@ 3. more tests for chain rule elementwise univariate hessian 4. in the refactor, add consts 5. multiply with one constant vector/scalar argument -6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen. \ No newline at end of file +6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen. +7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code? \ No newline at end of file diff --git a/include/affine.h b/include/affine.h index c2bcd74..20d4b46 100644 --- a/include/affine.h +++ b/include/affine.h @@ -4,9 +4,20 @@ #include "expr.h" #include "utils/CSR_Matrix.h" +/* Helper function to initialize a linear operator expr (can be used with derived + * types) */ +void init_linear_op(expr *node, expr *child, int d1, int d2); + expr *new_linear(expr *u, const CSR_Matrix *A); expr *new_add(expr *left, expr *right); + +/* Helper function to initialize a sum expr (can be used with derived types) */ +void init_sum(expr *node, expr *child, int d1); + +/* Helper function to initialize an hstack expr (can be used with derived types) */ +void init_hstack(expr *node, int d1, int d2, int n_vars); + expr *new_sum(expr *child, int axis); expr *new_hstack(expr **args, int n_args, int n_vars); diff --git a/include/elementwise_univariate.h b/include/elementwise_univariate.h index 52cd4a1..d190720 100644 --- a/include/elementwise_univariate.h +++ b/include/elementwise_univariate.h @@ -3,6 +3,10 @@ #include "expr.h" +/* Helper function to initialize an elementwise expr (can be used with derived types) + */ +void init_elementwise(expr *node, expr *child); + expr *new_exp(expr *child); expr *new_log(expr *child); expr *new_entr(expr *child); diff --git a/include/expr.h b/include/expr.h index 83ee83e..a20f4d4 100644 --- a/include/expr.h +++ b/include/expr.h @@ -21,55 +21,82 @@ typedef void (*wsum_hess_fn)(struct expr *node, double *w); typedef void (*local_jacobian_fn)(struct expr *node, double *out); typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w); typedef bool (*is_affine_fn)(struct expr *node); +typedef void (*free_type_data_fn)(struct expr *node); -/* TODO: implement proper polymorphism */ - -/* Expression node structure */ +/* Base expression node structure - contains only common fields */ typedef struct expr { // ------------------------------------------------------------------------ // general quantities // ------------------------------------------------------------------------ - int d1, d2, size; - int n_vars; - int var_id; - int refcount; + int d1, d2, size, n_vars, refcount, var_id; struct expr *left; struct expr *right; - struct expr **args; /* hstack can have multiple arguments */ - int n_args; double *dwork; int *iwork; - struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ - int p; /* power of power expression */ - int axis; /* axis for sum or similar operations */ - CSR_Matrix *Q; /* Q for quad_form */ // ------------------------------------------------------------------------ - // forward pass related quantities + // oracle related quantities // ------------------------------------------------------------------------ double *value; - forward_fn forward; - - // ------------------------------------------------------------------------ - // jacobian related quantities - // ------------------------------------------------------------------------ CSR_Matrix *jacobian; CSR_Matrix *wsum_hess; - CSR_Matrix *CSR_work; + forward_fn forward; jacobian_init_fn jacobian_init; wsum_hess_init_fn wsum_hess_init; eval_jacobian_fn eval_jacobian; wsum_hess_fn eval_wsum_hess; - local_jacobian_fn local_jacobian; - local_wsum_hess_fn local_wsum_hess; + + // ------------------------------------------------------------------------ + // other things + // ------------------------------------------------------------------------ + CSR_Matrix *CSR_work; is_affine_fn is_affine; + local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/ + local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/ + free_type_data_fn free_type_data; /* Cleanup for type-specific fields */ - // for every linear operator we store A in CSR and CSC +} expr; + +/* Type-specific expression structures that "inherit" from expr */ + +/* Linear operator: y = A * x */ +typedef struct linear_op_expr +{ + expr base; /* MUST be first member for casting to work */ CSC_Matrix *A_csc; CSR_Matrix *A_csr; +} linear_op_expr; -} expr; +/* Power: y = x^p */ +typedef struct power_expr +{ + expr base; /* MUST be first member for casting to work */ + int p; +} power_expr; + +/* Quadratic form: y = x'*Q*x */ +typedef struct quad_form_expr +{ + expr base; /* MUST be first member for casting to work */ + CSR_Matrix *Q; +} quad_form_expr; + +/* Sum reduction along an axis */ +typedef struct sum_expr +{ + expr base; /* MUST be first member for casting to work */ + int axis; + struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ +} sum_expr; + +/* Horizontal stack (concatenate) */ +typedef struct hstack_expr +{ + expr base; /* MUST be first member for casting to work */ + expr **args; + int n_args; +} hstack_expr; expr *new_expr(int d1, int d2, int n_vars); void free_expr(expr *node); diff --git a/include/other.h b/include/other.h index 3dbdf72..01ba411 100644 --- a/include/other.h +++ b/include/other.h @@ -4,6 +4,9 @@ #include "expr.h" #include "utils/CSR_Matrix.h" +/* Helper function to initialize a quad_form expr (can be used with derived types) */ +void init_quad_form(expr *node, expr *child); + expr *new_quad_form(expr *child, CSR_Matrix *Q); #endif /* OTHER_H */ diff --git a/src/affine/hstack.c b/src/affine/hstack.c index e2b8dd9..f1b269e 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -1,21 +1,23 @@ #include "affine.h" #include +#include #include static void forward(expr *node, const double *u) { + hstack_expr *hnode = (hstack_expr *) node; /* children's forward passes */ - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - node->args[i]->forward(node->args[i], u); + hnode->args[i]->forward(hnode->args[i], u); } /* concatenate values horizontally */ int offset = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - expr *child = node->args[i]; + expr *child = hnode->args[i]; memcpy(node->value + offset, child->value, child->size * sizeof(double)); offset += child->size; } @@ -23,12 +25,14 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { + hstack_expr *hnode = (hstack_expr *) node; + /* initialize children's jacobians */ int nnz = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - node->args[i]->jacobian_init(node->args[i]); - nnz += node->args[i]->jacobian->nnz; + hnode->args[i]->jacobian_init(hnode->args[i]); + nnz += hnode->args[i]->jacobian->nnz; } node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); @@ -36,14 +40,16 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { + hstack_expr *hnode = (hstack_expr *) node; + /* evaluate children's jacobians */ int row_offset = 0; CSR_Matrix *A = node->jacobian; A->nnz = 0; - for (int i = 0; i < node->n_args; i++) + for (int i = 0; i < hnode->n_args; i++) { - expr *child = node->args[i]; + expr *child = hnode->args[i]; child->eval_jacobian(child); CSR_Matrix *B = child->jacobian; @@ -65,9 +71,11 @@ static void eval_jacobian(expr *node) static bool is_affine(expr *node) { - for (int i = 0; i < node->n_args; i++) + hstack_expr *hnode = (hstack_expr *) node; + + for (int i = 0; i < hnode->n_args; i++) { - if (!node->args[i]->is_affine(node->args[i])) + if (!hnode->args[i]->is_affine(hnode->args[i])) { return false; } @@ -75,6 +83,43 @@ static bool is_affine(expr *node) return true; } +static void free_type_data(expr *node) +{ + hstack_expr *hnode = (hstack_expr *) node; + for (int i = 0; i < hnode->n_args; i++) + { + free_expr(hnode->args[i]); + } +} + +/* Helper function to initialize an hstack expr */ +void init_hstack(expr *node, int d1, int d2, int n_vars) +{ + node->d1 = d1; + node->d2 = d2; + node->size = d1 * d2; + node->n_vars = n_vars; + node->var_id = -1; + node->refcount = 1; + node->left = NULL; + node->right = NULL; + node->dwork = NULL; + node->iwork = NULL; + node->value = (double *) calloc(node->size, sizeof(double)); + node->jacobian = NULL; + node->wsum_hess = NULL; + node->CSR_work = NULL; + node->jacobian_init = jacobian_init; + node->wsum_hess_init = NULL; + node->eval_jacobian = eval_jacobian; + node->eval_wsum_hess = NULL; + node->local_jacobian = NULL; + node->local_wsum_hess = NULL; + node->forward = forward; + node->is_affine = is_affine; + node->free_type_data = free_type_data; +} + expr *new_hstack(expr **args, int n_args, int n_vars) { /* compute second dimension */ @@ -84,20 +129,30 @@ expr *new_hstack(expr **args, int n_args, int n_vars) d2 += args[i]->d2; } - expr *node = new_expr(args[0]->d1, d2, n_vars); - if (!node) return NULL; - node->args = args; - node->n_args = n_args; + /* Allocate the type-specific struct */ + hstack_expr *hnode = (hstack_expr *) malloc(sizeof(hstack_expr)); + if (!hnode) return NULL; + + expr *node = &hnode->base; + + /* Initialize base hstack fields */ + init_hstack(node, args[0]->d1, d2, n_vars); + + /* Check if allocation succeeded */ + if (!node->value) + { + free(hnode); + return NULL; + } + + /* Set type-specific fields */ + hnode->args = args; + hnode->n_args = n_args; for (int i = 0; i < n_args; i++) { expr_retain(args[i]); } - node->forward = forward; - node->is_affine = is_affine; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; - return node; } diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 2bc988a..a6f289c 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -16,24 +16,70 @@ static bool is_affine(expr *node) return node->left->is_affine(node->left); } -expr *new_linear(expr *u, const CSR_Matrix *A) +static void free_type_data(expr *node) { - expr *node = new_expr(A->m, 1, u->n_vars); - if (!node) return NULL; + linear_op_expr *lin_node = (linear_op_expr *) node; + free_csr_matrix(lin_node->A_csr); + free_csc_matrix(lin_node->A_csc); +} - node->left = u; - expr_retain(u); +/* Helper function to initialize a linear operator expr */ +void init_linear_op(expr *node, expr *child, int d1, int d2) +{ + node->d1 = d1; + node->d2 = d2; + node->size = d1 * d2; + node->n_vars = child->n_vars; + node->var_id = -1; + node->refcount = 1; + node->left = child; + node->right = NULL; + node->dwork = NULL; + node->iwork = NULL; + node->value = (double *) calloc(node->size, sizeof(double)); + node->jacobian = NULL; + node->wsum_hess = NULL; + node->CSR_work = NULL; + node->jacobian_init = NULL; + node->wsum_hess_init = NULL; + node->eval_jacobian = NULL; + node->eval_wsum_hess = NULL; + node->local_jacobian = NULL; + node->local_wsum_hess = NULL; node->forward = forward; node->is_affine = is_affine; + node->free_type_data = free_type_data; + + expr_retain(child); +} + +expr *new_linear(expr *u, const CSR_Matrix *A) +{ + /* Allocate the type-specific struct */ + linear_op_expr *lin_node = (linear_op_expr *) malloc(sizeof(linear_op_expr)); + if (!lin_node) return NULL; + + expr *node = &lin_node->base; + + /* Initialize base linear operator fields */ + init_linear_op(node, u, A->m, 1); + + /* Check if allocation succeeded */ + if (!node->value) + { + free(lin_node); + return NULL; + } /* allocate jacobian and copy A into it */ // TODO: this should eventually be removed node->jacobian = new_csr_matrix(A->m, A->n, A->nnz); copy_csr_matrix(A, node->jacobian); - node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); - copy_csr_matrix(A, node->A_csr); - node->A_csc = csr_to_csc(A); + /* Initialize type-specific fields */ + lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); + copy_csr_matrix(A, lin_node->A_csr); + lin_node->A_csc = csr_to_csc(A); return node; } diff --git a/src/affine/sum.c b/src/affine/sum.c index a4c11a7..fe68336 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -10,11 +10,13 @@ static void forward(expr *node, const double *u) int i, j, end; double sum; expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; /* child's forward pass */ x->forward(x, u); - if (node->axis == -1) + if (axis == -1) { sum = 0.0; end = x->d1 * x->d2; @@ -26,7 +28,7 @@ static void forward(expr *node, const double *u) } node->value[0] = sum; } - else if (node->axis == 0) + else if (axis == 0) { /* sum rows together */ for (j = 0; j < x->d2; j++) @@ -40,7 +42,7 @@ static void forward(expr *node, const double *u) node->value[j] = sum; } } - else if (node->axis == 1) + else if (axis == 1) { memset(node->value, 0, node->d1 * sizeof(double)); @@ -59,36 +61,40 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + /* initialize child's jacobian */ x->jacobian_init(x); /* we never have to store more than the child's nnz */ node->jacobian = new_csr_matrix(node->d1, node->n_vars, x->jacobian->nnz); - node->int_double_pairs = new_int_double_pair_array(x->jacobian->nnz); + snode->int_double_pairs = new_int_double_pair_array(x->jacobian->nnz); } static void eval_jacobian(expr *node) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; /* evaluate child's jacobian */ x->eval_jacobian(x); /* sum rows or columns of child's jacobian */ - if (node->axis == -1) + if (axis == -1) { - sum_all_rows_csr(x->jacobian, node->jacobian, node->int_double_pairs); + sum_all_rows_csr(x->jacobian, node->jacobian, snode->int_double_pairs); } - else if (node->axis == 0) + else if (axis == 0) { - sum_block_of_rows_csr(x->jacobian, node->jacobian, node->int_double_pairs, + sum_block_of_rows_csr(x->jacobian, node->jacobian, snode->int_double_pairs, x->d1); } - else if (node->axis == 1) + else if (axis == 1) { sum_evenly_spaced_rows_csr(x->jacobian, node->jacobian, - node->int_double_pairs, node->d1); + snode->int_double_pairs, node->d1); } } @@ -106,16 +112,18 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, double *w) { expr *x = node->left; + sum_expr *snode = (sum_expr *) node; + int axis = snode->axis; - if (node->axis == -1) + if (axis == -1) { scaled_ones(node->dwork, x->size, *w); } - else if (node->axis == 0) + else if (axis == 0) { repeat(node->dwork, w, x->d2, x->d1); } - else if (node->axis == 1) + else if (axis == 1) { tile(node->dwork, w, x->d1, x->d2); } @@ -131,6 +139,42 @@ static bool is_affine(expr *node) return node->left->is_affine(node->left); } +static void free_type_data(expr *node) +{ + sum_expr *snode = (sum_expr *) node; + free_int_double_pair_array(snode->int_double_pairs); +} + +/* Helper function to initialize a sum expr */ +void init_sum(expr *node, expr *child, int d1) +{ + node->d1 = d1; + node->d2 = 1; + node->size = d1 * 1; + node->n_vars = child->n_vars; + node->var_id = -1; + node->refcount = 1; + node->left = child; + node->right = NULL; + node->dwork = NULL; + node->iwork = NULL; + node->value = (double *) calloc(node->size, sizeof(double)); + node->jacobian = NULL; + node->wsum_hess = NULL; + node->CSR_work = NULL; + node->jacobian_init = jacobian_init; + node->wsum_hess_init = wsum_hess_init; + node->eval_jacobian = eval_jacobian; + node->eval_wsum_hess = eval_wsum_hess; + node->local_jacobian = NULL; + node->local_wsum_hess = NULL; + node->is_affine = is_affine; + node->forward = forward; + node->free_type_data = free_type_data; + + expr_retain(child); +} + expr *new_sum(expr *child, int axis) { int d1 = 0; @@ -151,18 +195,25 @@ expr *new_sum(expr *child, int axis) break; } - expr *node = new_expr(d1, 1, child->n_vars); - if (!node) return NULL; + /* Allocate the type-specific struct */ + sum_expr *snode = (sum_expr *) malloc(sizeof(sum_expr)); + if (!snode) return NULL; - node->left = child; - expr_retain(child); - node->forward = forward; - node->is_affine = is_affine; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; - node->wsum_hess_init = wsum_hess_init; - node->eval_wsum_hess = eval_wsum_hess; - node->axis = axis; + expr *node = &snode->base; + + /* Initialize base sum fields */ + init_sum(node, child, d1); + + /* Check if allocation succeeded */ + if (!node->value) + { + free(snode); + return NULL; + } + + /* Set type-specific fields */ + snode->axis = axis; + snode->int_double_pairs = NULL; return node; } diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 8733226..31f1340 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -35,8 +35,10 @@ void eval_jacobian_elementwise(expr *node) } else { + /* Child must be a linear operator */ + linear_op_expr *lin_child = (linear_op_expr *) child; node->local_jacobian(node, node->dwork); - diag_csr_mult(node->dwork, child->A_csr, node->jacobian); + diag_csr_mult(node->dwork, lin_child->A_csr, node->jacobian); } } @@ -65,7 +67,8 @@ void wsum_hess_init_elementwise(expr *node) /* otherwise it will be a linear operator */ else { - node->wsum_hess = ATA_alloc(child->A_csc); + linear_op_expr *lin_child = (linear_op_expr *) child; + node->wsum_hess = ATA_alloc(lin_child->A_csc); } } @@ -79,8 +82,10 @@ void eval_wsum_hess_elementwise(expr *node, double *w) } else { + /* Child must be a linear operator */ + linear_op_expr *lin_child = (linear_op_expr *) child; node->local_wsum_hess(node, node->dwork, w); - ATDA_values(child->A_csc, node->dwork, node->wsum_hess); + ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess); } } @@ -90,15 +95,42 @@ bool is_affine_elementwise(expr *node) return false; } -expr *new_elementwise(expr *child) +/* Helper function to initialize an already-allocated expr for elementwise operations + */ +void init_elementwise(expr *node, expr *child) { - expr *node = new_expr(child->d1, child->d2, child->n_vars); + node->d1 = child->d1; + node->d2 = child->d2; + node->size = child->d1 * child->d2; + node->n_vars = child->n_vars; + node->var_id = -1; + node->refcount = 1; node->left = child; - expr_retain(child); - node->is_affine = is_affine_elementwise; + node->right = NULL; + node->dwork = NULL; + node->iwork = NULL; + node->value = (double *) calloc(node->size, sizeof(double)); + node->jacobian = NULL; + node->wsum_hess = NULL; + node->CSR_work = NULL; node->jacobian_init = jacobian_init_elementwise; - node->eval_jacobian = eval_jacobian_elementwise; node->wsum_hess_init = wsum_hess_init_elementwise; + node->eval_jacobian = eval_jacobian_elementwise; node->eval_wsum_hess = eval_wsum_hess_elementwise; + node->local_jacobian = NULL; + node->local_wsum_hess = NULL; + node->is_affine = is_affine_elementwise; + node->forward = NULL; + node->free_type_data = NULL; + + expr_retain(child); +} + +expr *new_elementwise(expr *child) +{ + expr *node = new_expr(child->d1, child->d2, child->n_vars); + if (!node) return NULL; + + init_elementwise(node, child); return node; } diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index b00d358..32be97d 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -1,5 +1,6 @@ #include "elementwise_univariate.h" #include +#include static void forward(expr *node, const double *u) { @@ -7,28 +8,30 @@ static void forward(expr *node, const double *u) /* child's forward pass */ node->left->forward(node->left, u); - double *x = node->left->value; - /* local forward pass */ + double *x = node->left->value; + double p = ((power_expr *) node)->p; for (int i = 0; i < node->size; i++) { - node->value[i] = pow(x[i], node->p); + node->value[i] = pow(x[i], p); } } static void local_jacobian(expr *node, double *vals) { double *x = node->left->value; + double p = ((power_expr *) node)->p; + for (int j = 0; j < node->size; j++) { - vals[j] = node->p * pow(x[j], node->p - 1); + vals[j] = p * pow(x[j], p - 1); } } static void local_wsum_hess(expr *node, double *out, double *w) { double *x = node->left->value; - double p = (double) node->p; + double p = ((power_expr *) node)->p; for (int j = 0; j < node->size; j++) { @@ -38,10 +41,19 @@ static void local_wsum_hess(expr *node, double *out, double *w) expr *new_power(expr *child, int p) { - expr *node = new_elementwise(child); - node->p = p; + /* Allocate the type-specific struct */ + power_expr *pnode = (power_expr *) malloc(sizeof(power_expr)); + expr *node = &pnode->base; + + /* Initialize base elementwise fields */ + init_elementwise(node, child); + node->forward = forward; node->local_jacobian = local_jacobian; node->local_wsum_hess = local_wsum_hess; + + /* Set type-specific field */ + pnode->p = p; + return node; } diff --git a/src/expr.c b/src/expr.c index 520ffb2..8906360 100644 --- a/src/expr.c +++ b/src/expr.c @@ -20,16 +20,6 @@ expr *new_expr(int d1, int d2, int n_vars) return NULL; } - // node->left = NULL; - // node->right = NULL; - // node->forward = NULL; - // node->jacobian = NULL; - // node->jacobian_init = NULL; - // node->eval_jacobian = NULL; - // node->is_affine = NULL; - // node->dwork = NULL; - // node->iwork = NULL; - // node->CSR_work = NULL; node->var_id = -1; return node; @@ -46,25 +36,19 @@ void free_expr(expr *node) free_expr(node->left); free_expr(node->right); - /* free arguments */ - if (node->args) - { - for (int i = 0; i < node->n_args; i++) - { - free_expr(node->args[i]); - } - } - /* free value array and jacobian */ free(node->value); free_csr_matrix(node->jacobian); free_csr_matrix(node->wsum_hess); free_csr_matrix(node->CSR_work); - free_csr_matrix(node->A_csr); - free_csc_matrix(node->A_csc); free(node->dwork); free(node->iwork); - free_int_double_pair_array(node->int_double_pairs); + + /* free type-specific data */ + if (node->free_type_data) + { + node->free_type_data(node); + } /* free the node itself */ free(node); diff --git a/src/other/quad_form.c b/src/other/quad_form.c index d069ae7..2647beb 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -3,6 +3,8 @@ #include #include +/* Note: Q is not freed here because it's owned by the caller */ + static void forward(expr *node, const double *u) { expr *x = node->left; @@ -11,7 +13,8 @@ static void forward(expr *node, const double *u) x->forward(x, u); /* local forward pass */ - csr_matvec(node->Q, x->value, node->dwork, 0); + quad_form_expr *qnode = (quad_form_expr *) node; + csr_matvec(qnode->Q, x->value, node->dwork, 0); node->value[0] = 0.0; @@ -71,7 +74,8 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - CSR_Matrix *Q = node->Q; + quad_form_expr *qnode = (quad_form_expr *) node; + CSR_Matrix *Q = qnode->Q; /* if x is a variable */ if (x->var_id != -1) @@ -98,14 +102,56 @@ static void eval_jacobian(expr *node) } } -expr *new_quad_form(expr *left, CSR_Matrix *Q) +/* Helper function to initialize a quad_form expr */ +void init_quad_form(expr *node, expr *child) { - expr *node = new_expr(left->d1, 1, left->n_vars); - node->left = left; - expr_retain(left); - node->Q = Q; - node->forward = forward; + node->d1 = child->d1; + node->d2 = 1; + node->size = child->d1 * 1; + node->n_vars = child->n_vars; + node->var_id = -1; + node->refcount = 1; + node->left = child; + node->right = NULL; + node->dwork = NULL; + node->iwork = NULL; + node->value = (double *) calloc(node->size, sizeof(double)); + node->jacobian = NULL; + node->wsum_hess = NULL; + node->CSR_work = NULL; node->jacobian_init = jacobian_init; + node->wsum_hess_init = NULL; node->eval_jacobian = eval_jacobian; + node->eval_wsum_hess = NULL; + node->local_jacobian = NULL; + node->local_wsum_hess = NULL; + node->is_affine = NULL; + node->forward = forward; + node->free_type_data = NULL; /* Q is owned by caller */ + + expr_retain(child); +} + +expr *new_quad_form(expr *left, CSR_Matrix *Q) +{ + /* Allocate the type-specific struct */ + quad_form_expr *qnode = (quad_form_expr *) malloc(sizeof(quad_form_expr)); + if (!qnode) return NULL; + + expr *node = &qnode->base; + + /* Initialize base quad_form fields */ + init_quad_form(node, left); + + /* Check if allocation succeeded */ + if (!node->value) + { + free(qnode); + return NULL; + } + + /* Set type-specific field */ + qnode->Q = Q; + return node; } From 43f51b86bddc02420bda558989ace39aa3e950c5 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 10:04:46 -0800 Subject: [PATCH 02/10] started refactoring polymorphism --- src/elementwise_univariate/common.c | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 31f1340..3026d69 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -96,20 +96,26 @@ bool is_affine_elementwise(expr *node) } /* Helper function to initialize an already-allocated expr for elementwise operations + * This is called when a power_expr or other type-specific struct is allocated + * and we need to initialize the base expr fields */ void init_elementwise(expr *node, expr *child) { + /* Initialize dimensions from child */ node->d1 = child->d1; node->d2 = child->d2; node->size = child->d1 * child->d2; node->n_vars = child->n_vars; node->var_id = -1; node->refcount = 1; + + /* Allocate the value array */ + node->value = (double *) calloc(node->size, sizeof(double)); + node->left = child; node->right = NULL; node->dwork = NULL; node->iwork = NULL; - node->value = (double *) calloc(node->size, sizeof(double)); node->jacobian = NULL; node->wsum_hess = NULL; node->CSR_work = NULL; @@ -128,7 +134,7 @@ void init_elementwise(expr *node, expr *child) expr *new_elementwise(expr *child) { - expr *node = new_expr(child->d1, child->d2, child->n_vars); + expr *node = (expr *) malloc(sizeof(expr)); if (!node) return NULL; init_elementwise(node, child); From af8da7dcb4826d751bf37fe111130863a498c455 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 11:18:34 -0800 Subject: [PATCH 03/10] removed CSR_work from expr --- include/affine.h | 1 + include/elementwise_univariate.h | 1 + include/expr.h | 46 +------------------------- include/other.h | 1 + include/subexpr.h | 51 +++++++++++++++++++++++++++++ include/utils/CSC_Matrix.h | 5 +++ src/affine/hstack.c | 1 - src/affine/linear_op.c | 1 - src/affine/sum.c | 1 - src/bivariate/quad_over_lin.c | 18 +++++----- src/elementwise_univariate/common.c | 1 - src/expr.c | 1 - src/other/quad_form.c | 36 +++++++------------- src/utils/CSC_Matrix.c | 28 ++++++++++++++++ 14 files changed, 110 insertions(+), 82 deletions(-) create mode 100644 include/subexpr.h diff --git a/include/affine.h b/include/affine.h index 20d4b46..30c1da6 100644 --- a/include/affine.h +++ b/include/affine.h @@ -2,6 +2,7 @@ #define AFFINE_H #include "expr.h" +#include "subexpr.h" #include "utils/CSR_Matrix.h" /* Helper function to initialize a linear operator expr (can be used with derived diff --git a/include/elementwise_univariate.h b/include/elementwise_univariate.h index d190720..82eb371 100644 --- a/include/elementwise_univariate.h +++ b/include/elementwise_univariate.h @@ -2,6 +2,7 @@ #define ELEMENTWISE_H #include "expr.h" +#include "subexpr.h" /* Helper function to initialize an elementwise expr (can be used with derived types) */ diff --git a/include/expr.h b/include/expr.h index a20f4d4..44dc474 100644 --- a/include/expr.h +++ b/include/expr.h @@ -8,11 +8,8 @@ #define JAC_IDXS_NOT_SET -1 -/* Forward declarations */ -struct expr; -struct int_double_pair; - /* Function pointer types */ +struct expr; typedef void (*forward_fn)(struct expr *node, const double *u); typedef void (*jacobian_init_fn)(struct expr *node); typedef void (*wsum_hess_init_fn)(struct expr *node); @@ -50,7 +47,6 @@ typedef struct expr // ------------------------------------------------------------------------ // other things // ------------------------------------------------------------------------ - CSR_Matrix *CSR_work; is_affine_fn is_affine; local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/ local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/ @@ -58,46 +54,6 @@ typedef struct expr } expr; -/* Type-specific expression structures that "inherit" from expr */ - -/* Linear operator: y = A * x */ -typedef struct linear_op_expr -{ - expr base; /* MUST be first member for casting to work */ - CSC_Matrix *A_csc; - CSR_Matrix *A_csr; -} linear_op_expr; - -/* Power: y = x^p */ -typedef struct power_expr -{ - expr base; /* MUST be first member for casting to work */ - int p; -} power_expr; - -/* Quadratic form: y = x'*Q*x */ -typedef struct quad_form_expr -{ - expr base; /* MUST be first member for casting to work */ - CSR_Matrix *Q; -} quad_form_expr; - -/* Sum reduction along an axis */ -typedef struct sum_expr -{ - expr base; /* MUST be first member for casting to work */ - int axis; - struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ -} sum_expr; - -/* Horizontal stack (concatenate) */ -typedef struct hstack_expr -{ - expr base; /* MUST be first member for casting to work */ - expr **args; - int n_args; -} hstack_expr; - expr *new_expr(int d1, int d2, int n_vars); void free_expr(expr *node); diff --git a/include/other.h b/include/other.h index 01ba411..1666c92 100644 --- a/include/other.h +++ b/include/other.h @@ -2,6 +2,7 @@ #define OTHER_H #include "expr.h" +#include "subexpr.h" #include "utils/CSR_Matrix.h" /* Helper function to initialize a quad_form expr (can be used with derived types) */ diff --git a/include/subexpr.h b/include/subexpr.h new file mode 100644 index 0000000..b1f37ab --- /dev/null +++ b/include/subexpr.h @@ -0,0 +1,51 @@ +#ifndef SUBEXPR_H +#define SUBEXPR_H + +#include "expr.h" +#include "utils/CSC_Matrix.h" +#include "utils/CSR_Matrix.h" + +/* Forward declaration */ +struct int_double_pair; + +/* Type-specific expression structures that "inherit" from expr */ + +/* Linear operator: y = A * x */ +typedef struct linear_op_expr +{ + expr base; + CSC_Matrix *A_csc; + CSR_Matrix *A_csr; +} linear_op_expr; + +/* Power: y = x^p */ +typedef struct power_expr +{ + expr base; + int p; +} power_expr; + +/* Quadratic form: y = x'*Q*x */ +typedef struct quad_form_expr +{ + expr base; + CSR_Matrix *Q; +} quad_form_expr; + +/* Sum reduction along an axis */ +typedef struct sum_expr +{ + expr base; + int axis; + struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ +} sum_expr; + +/* Horizontal stack (concatenate) */ +typedef struct hstack_expr +{ + expr base; + expr **args; + int n_args; +} hstack_expr; + +#endif /* SUBEXPR_H */ diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index 6636cea..b3eb358 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -39,4 +39,9 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A); */ void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); +/* C = z^T A where A is in CSC format and C is assumed to have one row. + * C must have column indices pre-computed. Fills in values of C only. + */ +void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C); + #endif /* CSC_MATRIX_H */ diff --git a/src/affine/hstack.c b/src/affine/hstack.c index f1b269e..066e294 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -108,7 +108,6 @@ void init_hstack(expr *node, int d1, int d2, int n_vars) node->value = (double *) calloc(node->size, sizeof(double)); node->jacobian = NULL; node->wsum_hess = NULL; - node->CSR_work = NULL; node->jacobian_init = jacobian_init; node->wsum_hess_init = NULL; node->eval_jacobian = eval_jacobian; diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index a6f289c..40a6d34 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -39,7 +39,6 @@ void init_linear_op(expr *node, expr *child, int d1, int d2) node->value = (double *) calloc(node->size, sizeof(double)); node->jacobian = NULL; node->wsum_hess = NULL; - node->CSR_work = NULL; node->jacobian_init = NULL; node->wsum_hess_init = NULL; node->eval_jacobian = NULL; diff --git a/src/affine/sum.c b/src/affine/sum.c index fe68336..7fddf68 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -161,7 +161,6 @@ void init_sum(expr *node, expr *child, int d1) node->value = (double *) calloc(node->size, sizeof(double)); node->jacobian = NULL; node->wsum_hess = NULL; - node->CSR_work = NULL; node->jacobian_init = jacobian_init; node->wsum_hess_init = wsum_hess_init; node->eval_jacobian = eval_jacobian; diff --git a/src/bivariate/quad_over_lin.c b/src/bivariate/quad_over_lin.c index ac31dfb..4149ade 100644 --- a/src/bivariate/quad_over_lin.c +++ b/src/bivariate/quad_over_lin.c @@ -1,4 +1,6 @@ #include "bivariate.h" +#include "subexpr.h" +#include "utils/CSC_Matrix.h" #include #include #include @@ -58,14 +60,15 @@ static void jacobian_init(expr *node) } } } - else /* left node is not a variable */ + else /* left node is not a variable (guaranteed to be a linear operator) */ { + linear_op_expr *lin_x = (linear_op_expr *) x; node->dwork = (double *) malloc(x->d1 * sizeof(double)); /* compute required allocation and allocate jacobian */ bool *col_nz = (bool *) calloc( node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/ - int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz); + int nonzero_cols = count_nonzero_cols(lin_x->base.jacobian, col_nz); node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1); /* precompute column indices */ @@ -88,11 +91,8 @@ static void jacobian_init(expr *node) node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; - /* store A^T of child's A to simplify chain rule computation */ - node->iwork = (int *) malloc(x->jacobian->n * sizeof(int)); - node->CSR_work = transpose(x->jacobian, node->iwork); - /* find position where y should be inserted */ + node->iwork = (int *) malloc(sizeof(int)); for (int j = 0; j < node->jacobian->nnz; j++) { if (node->jacobian->i[j] == y->var_id) @@ -132,14 +132,16 @@ static void eval_jacobian(expr *node) } else /* x is not a variable */ { + CSC_Matrix *A_csc = ((linear_op_expr *) x)->A_csc; + /* local jacobian */ for (int j = 0; j < x->d1; j++) { node->dwork[j] = (2.0 * x->value[j]) / y->value[0]; } - /* chain rule (no derivative wrt y) */ - csr_matvec_fill_values(node->CSR_work, node->dwork, node->jacobian); + /* chain rule (no derivative wrt y) using CSC format */ + csc_matvec_fill_values(A_csc, node->dwork, node->jacobian); /* insert derivative wrt y at right place (for correctness this assumes that y does not appear in the denominator, but this will always be diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 3026d69..43c65dc 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -118,7 +118,6 @@ void init_elementwise(expr *node, expr *child) node->iwork = NULL; node->jacobian = NULL; node->wsum_hess = NULL; - node->CSR_work = NULL; node->jacobian_init = jacobian_init_elementwise; node->wsum_hess_init = wsum_hess_init_elementwise; node->eval_jacobian = eval_jacobian_elementwise; diff --git a/src/expr.c b/src/expr.c index 8906360..90884e8 100644 --- a/src/expr.c +++ b/src/expr.c @@ -40,7 +40,6 @@ void free_expr(expr *node) free(node->value); free_csr_matrix(node->jacobian); free_csr_matrix(node->wsum_hess); - free_csr_matrix(node->CSR_work); free(node->dwork); free(node->iwork); diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 2647beb..b32c5e4 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -1,4 +1,6 @@ #include "other.h" +#include "subexpr.h" +#include "utils/CSC_Matrix.h" #include #include #include @@ -13,9 +15,8 @@ static void forward(expr *node, const double *u) x->forward(x, u); /* local forward pass */ - quad_form_expr *qnode = (quad_form_expr *) node; - csr_matvec(qnode->Q, x->value, node->dwork, 0); - + CSR_Matrix *Q = ((quad_form_expr *) node)->Q; + csr_matvec(Q, x->value, node->dwork, 0); node->value[0] = 0.0; for (int i = 0; i < x->d1; i++) @@ -65,17 +66,15 @@ static void jacobian_init(expr *node) node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; - /* store A^T of child's A to simplify chain rule computation */ - node->iwork = (int *) malloc(x->jacobian->n * sizeof(int)); - node->CSR_work = transpose(x->jacobian, node->iwork); + /* Cast x to linear operator to use its A_csc in eval_jacobian */ + node->iwork = (int *) malloc(sizeof(int)); } } static void eval_jacobian(expr *node) { expr *x = node->left; - quad_form_expr *qnode = (quad_form_expr *) node; - CSR_Matrix *Q = qnode->Q; + CSR_Matrix *Q = ((quad_form_expr *) node)->Q; /* if x is a variable */ if (x->var_id != -1) @@ -89,6 +88,8 @@ static void eval_jacobian(expr *node) } else /* x is not a variable */ { + linear_op_expr *lin_x = (linear_op_expr *) x; + /* local jacobian */ csr_matvec(Q, x->value, node->dwork, 0); @@ -97,8 +98,8 @@ static void eval_jacobian(expr *node) node->dwork[j] *= 2.0; } - /* chain rule */ - csr_matvec_fill_values(node->CSR_work, node->dwork, node->jacobian); + /* chain rule using CSC format */ + csc_matvec_fill_values(lin_x->A_csc, node->dwork, node->jacobian); } } @@ -118,7 +119,6 @@ void init_quad_form(expr *node, expr *child) node->value = (double *) calloc(node->size, sizeof(double)); node->jacobian = NULL; node->wsum_hess = NULL; - node->CSR_work = NULL; node->jacobian_init = jacobian_init; node->wsum_hess_init = NULL; node->eval_jacobian = eval_jacobian; @@ -127,29 +127,17 @@ void init_quad_form(expr *node, expr *child) node->local_wsum_hess = NULL; node->is_affine = NULL; node->forward = forward; - node->free_type_data = NULL; /* Q is owned by caller */ + node->free_type_data = NULL; expr_retain(child); } expr *new_quad_form(expr *left, CSR_Matrix *Q) { - /* Allocate the type-specific struct */ quad_form_expr *qnode = (quad_form_expr *) malloc(sizeof(quad_form_expr)); - if (!qnode) return NULL; - expr *node = &qnode->base; - - /* Initialize base quad_form fields */ init_quad_form(node, left); - /* Check if allocation succeeded */ - if (!node->value) - { - free(qnode); - return NULL; - } - /* Set type-specific field */ qnode->Q = Q; diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index 9ef3b8d..fc94b27 100644 --- a/src/utils/CSC_Matrix.c +++ b/src/utils/CSC_Matrix.c @@ -196,3 +196,31 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A) free(count); return C; } +void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C) +{ + /* Compute C = z^T * A where A is in CSC format + * C is a single-row CSR matrix with column indices pre-computed + * This fills in the values of C only + */ + + for (int col = 0; col < A->n; col++) + { + double val = 0; + for (int j = A->p[col]; j < A->p[col + 1]; j++) + { + val += z[A->i[j]] * A->x[j]; + } + + if (A->p[col + 1] - A->p[col] == 0) continue; + + /* find position in C and fill value */ + for (int k = 0; k < C->nnz; k++) + { + if (C->i[k] == col) + { + C->x[k] = val; + break; + } + } + } +} From 5527b177cc5bd07c30ee3fa0668a45050212eaca Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 22:52:34 -0800 Subject: [PATCH 04/10] refactor hstack to precompute sparsity pattern + new init_expr to reuse functioanlity --- include/affine.h | 3 - include/elementwise_univariate.h | 5 +- include/expr.h | 8 ++- src/affine/add.c | 10 +--- src/affine/constant.c | 8 +-- src/affine/hstack.c | 78 ++++++++++--------------- src/affine/linear_op.c | 2 +- src/affine/sum.c | 4 +- src/affine/variable.c | 2 +- src/elementwise_univariate/common.c | 6 +- src/elementwise_univariate/entr.c | 2 +- src/elementwise_univariate/exp.c | 2 +- src/elementwise_univariate/hyperbolic.c | 8 +-- src/elementwise_univariate/log.c | 2 +- src/elementwise_univariate/logistic.c | 2 +- src/elementwise_univariate/power.c | 3 +- src/elementwise_univariate/trig.c | 6 +- src/elementwise_univariate/xexp.c | 2 +- src/expr.c | 18 +++--- 19 files changed, 70 insertions(+), 101 deletions(-) diff --git a/include/affine.h b/include/affine.h index 30c1da6..321d2fd 100644 --- a/include/affine.h +++ b/include/affine.h @@ -16,9 +16,6 @@ expr *new_add(expr *left, expr *right); /* Helper function to initialize a sum expr (can be used with derived types) */ void init_sum(expr *node, expr *child, int d1); -/* Helper function to initialize an hstack expr (can be used with derived types) */ -void init_hstack(expr *node, int d1, int d2, int n_vars); - expr *new_sum(expr *child, int axis); expr *new_hstack(expr **args, int n_args, int n_vars); diff --git a/include/elementwise_univariate.h b/include/elementwise_univariate.h index 82eb371..4e22892 100644 --- a/include/elementwise_univariate.h +++ b/include/elementwise_univariate.h @@ -2,7 +2,6 @@ #define ELEMENTWISE_H #include "expr.h" -#include "subexpr.h" /* Helper function to initialize an elementwise expr (can be used with derived types) */ @@ -27,11 +26,11 @@ expr *new_xexp(expr *child); void jacobian_init_elementwise(expr *node); void eval_jacobian_elementwise(expr *node); void wsum_hess_init_elementwise(expr *node); -void eval_wsum_hess_elementwise(expr *node, double *w); +void eval_wsum_hess_elementwise(expr *node, const double *w); expr *new_elementwise(expr *child); /* no elementwise atoms are affine according to our convention, so we can have a common implementation */ -bool is_affine_elementwise(expr *node); +bool is_affine_elementwise(const expr *node); #endif /* ELEMENTWISE_UNIVARIATE_H */ diff --git a/include/expr.h b/include/expr.h index 44dc474..6755ccb 100644 --- a/include/expr.h +++ b/include/expr.h @@ -14,10 +14,10 @@ typedef void (*forward_fn)(struct expr *node, const double *u); typedef void (*jacobian_init_fn)(struct expr *node); typedef void (*wsum_hess_init_fn)(struct expr *node); typedef void (*eval_jacobian_fn)(struct expr *node); -typedef void (*wsum_hess_fn)(struct expr *node, double *w); +typedef void (*wsum_hess_fn)(struct expr *node, const double *w); typedef void (*local_jacobian_fn)(struct expr *node, double *out); -typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w); -typedef bool (*is_affine_fn)(struct expr *node); +typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w); +typedef bool (*is_affine_fn)(const struct expr *node); typedef void (*free_type_data_fn)(struct expr *node); /* Base expression node structure - contains only common fields */ @@ -54,6 +54,8 @@ typedef struct expr } expr; +void init_expr(expr *node, int d1, int d2, int n_vars); + expr *new_expr(int d1, int d2, int n_vars); void free_expr(expr *node); diff --git a/src/affine/add.c b/src/affine/add.c index 8133162..9f6b2d1 100644 --- a/src/affine/add.c +++ b/src/affine/add.c @@ -45,7 +45,7 @@ static void wsum_hess_init(expr *node) node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max); } -static void eval_wsum_hess(expr *node, double *w) +static void eval_wsum_hess(expr *node, const double *w) { /* evaluate children's wsum_hess */ node->left->eval_wsum_hess(node->left, w); @@ -55,20 +55,14 @@ static void eval_wsum_hess(expr *node, double *w) sum_csr_matrices(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left) && node->right->is_affine(node->right); } expr *new_add(expr *left, expr *right) { - if (!left || !right) return NULL; - if (left->d1 != right->d1) return NULL; - if (left->d2 != right->d2) return NULL; - expr *node = new_expr(left->d1, left->d2, left->n_vars); - if (!node) return NULL; - node->left = left; node->right = right; expr_retain(left); diff --git a/src/affine/constant.c b/src/affine/constant.c index f328137..d427ec9 100644 --- a/src/affine/constant.c +++ b/src/affine/constant.c @@ -8,20 +8,16 @@ static void forward(expr *node, const double *u) (void) u; } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { (void) node; - return true; /* constant is affine */ + return true; } expr *new_constant(int d1, int d2, int n_vars, const double *values) { expr *node = new_expr(d1, d2, n_vars); - if (!node) return NULL; - - /* Copy constant values */ memcpy(node->value, values, d1 * d2 * sizeof(double)); - node->forward = forward; node->is_affine = is_affine; diff --git a/src/affine/hstack.c b/src/affine/hstack.c index 066e294..457ce3f 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -36,13 +36,8 @@ static void jacobian_init(expr *node) } node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); -} - -static void eval_jacobian(expr *node) -{ - hstack_expr *hnode = (hstack_expr *) node; - /* evaluate children's jacobians */ + /* precompute sparsity pattern of this jacobian's node */ int row_offset = 0; CSR_Matrix *A = node->jacobian; A->nnz = 0; @@ -50,11 +45,9 @@ static void eval_jacobian(expr *node) for (int i = 0; i < hnode->n_args; i++) { expr *child = hnode->args[i]; - child->eval_jacobian(child); CSR_Matrix *B = child->jacobian; - /* copy columns and values */ - memcpy(A->x + A->nnz, B->x, B->nnz * sizeof(double)); + /* copy columns */ memcpy(A->i + A->nnz, B->i, B->nnz * sizeof(int)); /* set row pointers */ @@ -69,9 +62,27 @@ static void eval_jacobian(expr *node) A->p[node->size] = A->nnz; } -static bool is_affine(expr *node) +static void eval_jacobian(expr *node) { hstack_expr *hnode = (hstack_expr *) node; + CSR_Matrix *A = node->jacobian; + A->nnz = 0; + + for (int i = 0; i < hnode->n_args; i++) + { + expr *child = hnode->args[i]; + child->eval_jacobian(child); + + /* copy values */ + memcpy(A->x + A->nnz, child->jacobian->x, + child->jacobian->nnz * sizeof(double)); + A->nnz += child->jacobian->nnz; + } +} + +static bool is_affine(const expr *node) +{ + const hstack_expr *hnode = (const hstack_expr *) node; for (int i = 0; i < hnode->n_args; i++) { @@ -92,33 +103,6 @@ static void free_type_data(expr *node) } } -/* Helper function to initialize an hstack expr */ -void init_hstack(expr *node, int d1, int d2, int n_vars) -{ - node->d1 = d1; - node->d2 = d2; - node->size = d1 * d2; - node->n_vars = n_vars; - node->var_id = -1; - node->refcount = 1; - node->left = NULL; - node->right = NULL; - node->dwork = NULL; - node->iwork = NULL; - node->value = (double *) calloc(node->size, sizeof(double)); - node->jacobian = NULL; - node->wsum_hess = NULL; - node->jacobian_init = jacobian_init; - node->wsum_hess_init = NULL; - node->eval_jacobian = eval_jacobian; - node->eval_wsum_hess = NULL; - node->local_jacobian = NULL; - node->local_wsum_hess = NULL; - node->forward = forward; - node->is_affine = is_affine; - node->free_type_data = free_type_data; -} - expr *new_hstack(expr **args, int n_args, int n_vars) { /* compute second dimension */ @@ -129,20 +113,18 @@ expr *new_hstack(expr **args, int n_args, int n_vars) } /* Allocate the type-specific struct */ - hstack_expr *hnode = (hstack_expr *) malloc(sizeof(hstack_expr)); - if (!hnode) return NULL; - + hstack_expr *hnode = (hstack_expr *) calloc(1, sizeof(hstack_expr)); expr *node = &hnode->base; - /* Initialize base hstack fields */ - init_hstack(node, args[0]->d1, d2, n_vars); + /* Initialize basic fields */ + init_expr(node, args[0]->d1, d2, n_vars); - /* Check if allocation succeeded */ - if (!node->value) - { - free(hnode); - return NULL; - } + /* Set function pointers */ + node->forward = forward; + node->jacobian_init = jacobian_init; + node->eval_jacobian = eval_jacobian; + node->is_affine = is_affine; + node->free_type_data = free_type_data; /* Set type-specific fields */ hnode->args = args; diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 40a6d34..b55fec1 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -11,7 +11,7 @@ static void forward(expr *node, const double *u) csr_matvec(node->jacobian, x->value, node->value, x->var_id); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left); } diff --git a/src/affine/sum.c b/src/affine/sum.c index 7fddf68..9480ad6 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -109,7 +109,7 @@ static void wsum_hess_init(expr *node) node->dwork = malloc(x->size * sizeof(double)); } -static void eval_wsum_hess(expr *node, double *w) +static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; sum_expr *snode = (sum_expr *) node; @@ -134,7 +134,7 @@ static void eval_wsum_hess(expr *node, double *w) copy_csr_matrix(x->wsum_hess, node->wsum_hess); } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { return node->left->is_affine(node->left); } diff --git a/src/affine/variable.c b/src/affine/variable.c index 122db20..5fca186 100644 --- a/src/affine/variable.c +++ b/src/affine/variable.c @@ -20,7 +20,7 @@ static void jacobian_init(expr *node) node->jacobian->p[size] = size; } -static bool is_affine(expr *node) +static bool is_affine(const expr *node) { (void) node; return true; diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 43c65dc..1befdda 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -1,4 +1,6 @@ #include "elementwise_univariate.h" +#include "expr.h" +#include "subexpr.h" #include void jacobian_init_elementwise(expr *node) @@ -72,7 +74,7 @@ void wsum_hess_init_elementwise(expr *node) } } -void eval_wsum_hess_elementwise(expr *node, double *w) +void eval_wsum_hess_elementwise(expr *node, const double *w) { expr *child = node->left; @@ -89,7 +91,7 @@ void eval_wsum_hess_elementwise(expr *node, double *w) } } -bool is_affine_elementwise(expr *node) +bool is_affine_elementwise(const expr *node) { (void) node; return false; diff --git a/src/elementwise_univariate/entr.c b/src/elementwise_univariate/entr.c index c0d0240..282541a 100644 --- a/src/elementwise_univariate/entr.c +++ b/src/elementwise_univariate/entr.c @@ -24,7 +24,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/elementwise_univariate/exp.c b/src/elementwise_univariate/exp.c index b91f1c7..b78fcff 100644 --- a/src/elementwise_univariate/exp.c +++ b/src/elementwise_univariate/exp.c @@ -19,7 +19,7 @@ static void local_jacobian(expr *node, double *vals) memcpy(vals, node->value, node->size * sizeof(double)); } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) diff --git a/src/elementwise_univariate/hyperbolic.c b/src/elementwise_univariate/hyperbolic.c index 65ae85f..e7ce823 100644 --- a/src/elementwise_univariate/hyperbolic.c +++ b/src/elementwise_univariate/hyperbolic.c @@ -21,7 +21,7 @@ static void sinh_local_jacobian(expr *node, double *vals) } } -static void sinh_local_wsum_hess(expr *node, double *out, double *w) +static void sinh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) @@ -59,7 +59,7 @@ static void tanh_local_jacobian(expr *node, double *vals) } } -static void tanh_local_wsum_hess(expr *node, double *out, double *w) +static void tanh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) @@ -97,7 +97,7 @@ static void asinh_local_jacobian(expr *node, double *vals) } } -static void asinh_local_wsum_hess(expr *node, double *out, double *w) +static void asinh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) @@ -135,7 +135,7 @@ static void atanh_local_jacobian(expr *node, double *vals) } } -static void atanh_local_wsum_hess(expr *node, double *out, double *w) +static void atanh_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; for (int j = 0; j < node->size; j++) diff --git a/src/elementwise_univariate/log.c b/src/elementwise_univariate/log.c index 89337d2..637dde3 100644 --- a/src/elementwise_univariate/log.c +++ b/src/elementwise_univariate/log.c @@ -22,7 +22,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/elementwise_univariate/logistic.c b/src/elementwise_univariate/logistic.c index da57fac..f5479a3 100644 --- a/src/elementwise_univariate/logistic.c +++ b/src/elementwise_univariate/logistic.c @@ -40,7 +40,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; double *sigmas; diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index 32be97d..fc501d9 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -1,4 +1,5 @@ #include "elementwise_univariate.h" +#include "subexpr.h" #include #include @@ -28,7 +29,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; double p = ((power_expr *) node)->p; diff --git a/src/elementwise_univariate/trig.c b/src/elementwise_univariate/trig.c index 20902e3..c1e892c 100644 --- a/src/elementwise_univariate/trig.c +++ b/src/elementwise_univariate/trig.c @@ -20,7 +20,7 @@ static void sin_local_jacobian(expr *node, double *vals) } } -static void sin_local_wsum_hess(expr *node, double *out, double *w) +static void sin_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; @@ -58,7 +58,7 @@ static void cos_local_jacobian(expr *node, double *vals) } } -static void cos_local_wsum_hess(expr *node, double *out, double *w) +static void cos_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; @@ -97,7 +97,7 @@ static void tan_local_jacobian(expr *node, double *vals) } } -static void tan_local_wsum_hess(expr *node, double *out, double *w) +static void tan_local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/elementwise_univariate/xexp.c b/src/elementwise_univariate/xexp.c index 9bf7a7f..88c314f 100644 --- a/src/elementwise_univariate/xexp.c +++ b/src/elementwise_univariate/xexp.c @@ -23,7 +23,7 @@ static void local_jacobian(expr *node, double *vals) } } -static void local_wsum_hess(expr *node, double *out, double *w) +static void local_wsum_hess(expr *node, double *out, const double *w) { double *x = node->left->value; diff --git a/src/expr.c b/src/expr.c index 90884e8..0a31aa6 100644 --- a/src/expr.c +++ b/src/expr.c @@ -3,25 +3,21 @@ #include #include -expr *new_expr(int d1, int d2, int n_vars) +void init_expr(expr *node, int d1, int d2, int n_vars) { - expr *node = (expr *) calloc(1, sizeof(expr)); - if (!node) return NULL; - node->d1 = d1; node->d2 = d2; node->size = d1 * d2; node->n_vars = n_vars; node->refcount = 1; node->value = (double *) calloc(d1 * d2, sizeof(double)); - if (!node->value) - { - free(node); - return NULL; - } - node->var_id = -1; +} +expr *new_expr(int d1, int d2, int n_vars) +{ + expr *node = (expr *) calloc(1, sizeof(expr)); + init_expr(node, d1, d2, n_vars); return node; } @@ -59,7 +55,7 @@ void expr_retain(expr *node) node->refcount++; } -bool is_affine(expr *node) +bool is_affine(const expr *node) { bool left_affine = true; bool right_affine = true; From 618efc24d4cc91198e919ad7d0f9828d707384e4 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 23:02:16 -0800 Subject: [PATCH 05/10] refactored to remove some redundant init functions --- include/affine.h | 4 --- include/expr.h | 4 ++- include/other.h | 3 --- src/affine/hstack.c | 12 ++------- src/affine/linear_op.c | 40 ++++++----------------------- src/elementwise_univariate/common.c | 32 ++++++----------------- src/elementwise_univariate/power.c | 2 +- src/expr.c | 11 ++++++-- src/other/quad_form.c | 40 +++++++---------------------- 9 files changed, 40 insertions(+), 108 deletions(-) diff --git a/include/affine.h b/include/affine.h index 321d2fd..dde9af7 100644 --- a/include/affine.h +++ b/include/affine.h @@ -5,10 +5,6 @@ #include "subexpr.h" #include "utils/CSR_Matrix.h" -/* Helper function to initialize a linear operator expr (can be used with derived - * types) */ -void init_linear_op(expr *node, expr *child, int d1, int d2); - expr *new_linear(expr *u, const CSR_Matrix *A); expr *new_add(expr *left, expr *right); diff --git a/include/expr.h b/include/expr.h index 6755ccb..a3154d7 100644 --- a/include/expr.h +++ b/include/expr.h @@ -54,7 +54,9 @@ typedef struct expr } expr; -void init_expr(expr *node, int d1, int d2, int n_vars); +void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, + jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian, + is_affine_fn is_affine, free_type_data_fn free_type_data); expr *new_expr(int d1, int d2, int n_vars); void free_expr(expr *node); diff --git a/include/other.h b/include/other.h index 1666c92..95fb035 100644 --- a/include/other.h +++ b/include/other.h @@ -5,9 +5,6 @@ #include "subexpr.h" #include "utils/CSR_Matrix.h" -/* Helper function to initialize a quad_form expr (can be used with derived types) */ -void init_quad_form(expr *node, expr *child); - expr *new_quad_form(expr *child, CSR_Matrix *Q); #endif /* OTHER_H */ diff --git a/src/affine/hstack.c b/src/affine/hstack.c index 457ce3f..a00863c 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -115,16 +115,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars) /* Allocate the type-specific struct */ hstack_expr *hnode = (hstack_expr *) calloc(1, sizeof(hstack_expr)); expr *node = &hnode->base; - - /* Initialize basic fields */ - init_expr(node, args[0]->d1, d2, n_vars); - - /* Set function pointers */ - node->forward = forward; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; - node->is_affine = is_affine; - node->free_type_data = free_type_data; + init_expr(node, args[0]->d1, d2, n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); /* Set type-specific fields */ hnode->args = args; diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index b55fec1..a671d86 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -23,45 +23,21 @@ static void free_type_data(expr *node) free_csc_matrix(lin_node->A_csc); } -/* Helper function to initialize a linear operator expr */ -void init_linear_op(expr *node, expr *child, int d1, int d2) -{ - node->d1 = d1; - node->d2 = d2; - node->size = d1 * d2; - node->n_vars = child->n_vars; - node->var_id = -1; - node->refcount = 1; - node->left = child; - node->right = NULL; - node->dwork = NULL; - node->iwork = NULL; - node->value = (double *) calloc(node->size, sizeof(double)); - node->jacobian = NULL; - node->wsum_hess = NULL; - node->jacobian_init = NULL; - node->wsum_hess_init = NULL; - node->eval_jacobian = NULL; - node->eval_wsum_hess = NULL; - node->local_jacobian = NULL; - node->local_wsum_hess = NULL; - node->forward = forward; - node->is_affine = is_affine; - node->free_type_data = free_type_data; - - expr_retain(child); -} - expr *new_linear(expr *u, const CSR_Matrix *A) { /* Allocate the type-specific struct */ - linear_op_expr *lin_node = (linear_op_expr *) malloc(sizeof(linear_op_expr)); + linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr)); if (!lin_node) return NULL; expr *node = &lin_node->base; - /* Initialize base linear operator fields */ - init_linear_op(node, u, A->m, 1); + /* Initialize base fields */ + init_expr(node, A->m, 1, u->n_vars, forward, NULL, NULL, is_affine, + free_type_data); + + /* Set left child */ + node->left = u; + expr_retain(u); /* Check if allocation succeeded */ if (!node->value) diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 1befdda..c0fa30e 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -103,39 +103,23 @@ bool is_affine_elementwise(const expr *node) */ void init_elementwise(expr *node, expr *child) { - /* Initialize dimensions from child */ - node->d1 = child->d1; - node->d2 = child->d2; - node->size = child->d1 * child->d2; - node->n_vars = child->n_vars; - node->var_id = -1; - node->refcount = 1; + /* Initialize base fields */ + init_expr(node, child->d1, child->d2, child->n_vars, NULL, + jacobian_init_elementwise, eval_jacobian_elementwise, + is_affine_elementwise, NULL); - /* Allocate the value array */ - node->value = (double *) calloc(node->size, sizeof(double)); - - node->left = child; - node->right = NULL; - node->dwork = NULL; - node->iwork = NULL; - node->jacobian = NULL; - node->wsum_hess = NULL; - node->jacobian_init = jacobian_init_elementwise; + /* Set wsum_hess functions */ node->wsum_hess_init = wsum_hess_init_elementwise; - node->eval_jacobian = eval_jacobian_elementwise; node->eval_wsum_hess = eval_wsum_hess_elementwise; - node->local_jacobian = NULL; - node->local_wsum_hess = NULL; - node->is_affine = is_affine_elementwise; - node->forward = NULL; - node->free_type_data = NULL; + /* Set left child */ + node->left = child; expr_retain(child); } expr *new_elementwise(expr *child) { - expr *node = (expr *) malloc(sizeof(expr)); + expr *node = (expr *) calloc(1, sizeof(expr)); if (!node) return NULL; init_elementwise(node, child); diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index fc501d9..16c0e80 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -43,7 +43,7 @@ static void local_wsum_hess(expr *node, double *out, const double *w) expr *new_power(expr *child, int p) { /* Allocate the type-specific struct */ - power_expr *pnode = (power_expr *) malloc(sizeof(power_expr)); + power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr)); expr *node = &pnode->base; /* Initialize base elementwise fields */ diff --git a/src/expr.c b/src/expr.c index 0a31aa6..d235862 100644 --- a/src/expr.c +++ b/src/expr.c @@ -3,7 +3,9 @@ #include #include -void init_expr(expr *node, int d1, int d2, int n_vars) +void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, + jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian, + is_affine_fn is_affine, free_type_data_fn free_type_data) { node->d1 = d1; node->d2 = d2; @@ -12,12 +14,17 @@ void init_expr(expr *node, int d1, int d2, int n_vars) node->refcount = 1; node->value = (double *) calloc(d1 * d2, sizeof(double)); node->var_id = -1; + node->forward = forward; + node->jacobian_init = jacobian_init; + node->eval_jacobian = eval_jacobian; + node->is_affine = is_affine; + node->free_type_data = free_type_data; } expr *new_expr(int d1, int d2, int n_vars) { expr *node = (expr *) calloc(1, sizeof(expr)); - init_expr(node, d1, d2, n_vars); + init_expr(node, d1, d2, n_vars, NULL, NULL, NULL, NULL, NULL); return node; } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index b32c5e4..eefa5df 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -103,40 +103,18 @@ static void eval_jacobian(expr *node) } } -/* Helper function to initialize a quad_form expr */ -void init_quad_form(expr *node, expr *child) -{ - node->d1 = child->d1; - node->d2 = 1; - node->size = child->d1 * 1; - node->n_vars = child->n_vars; - node->var_id = -1; - node->refcount = 1; - node->left = child; - node->right = NULL; - node->dwork = NULL; - node->iwork = NULL; - node->value = (double *) calloc(node->size, sizeof(double)); - node->jacobian = NULL; - node->wsum_hess = NULL; - node->jacobian_init = jacobian_init; - node->wsum_hess_init = NULL; - node->eval_jacobian = eval_jacobian; - node->eval_wsum_hess = NULL; - node->local_jacobian = NULL; - node->local_wsum_hess = NULL; - node->is_affine = NULL; - node->forward = forward; - node->free_type_data = NULL; - - expr_retain(child); -} - expr *new_quad_form(expr *left, CSR_Matrix *Q) { - quad_form_expr *qnode = (quad_form_expr *) malloc(sizeof(quad_form_expr)); + quad_form_expr *qnode = (quad_form_expr *) calloc(1, sizeof(quad_form_expr)); expr *node = &qnode->base; - init_quad_form(node, left); + + /* Initialize base fields */ + init_expr(node, left->d1, 1, left->n_vars, forward, jacobian_init, eval_jacobian, + NULL, NULL); + + /* Set left child */ + node->left = left; + expr_retain(left); /* Set type-specific field */ qnode->Q = Q; From 2ed525be1d8b31c6213ee51c58183471ba8ad0e4 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 23:21:19 -0800 Subject: [PATCH 06/10] refactor of affine atoms done --- include/affine.h | 3 --- src/affine/linear_op.c | 27 +++++++---------------- src/affine/sum.c | 49 +++++++----------------------------------- src/affine/variable.c | 3 --- src/expr.c | 7 ++++++ 5 files changed, 23 insertions(+), 66 deletions(-) diff --git a/include/affine.h b/include/affine.h index dde9af7..e1b66ac 100644 --- a/include/affine.h +++ b/include/affine.h @@ -9,9 +9,6 @@ expr *new_linear(expr *u, const CSR_Matrix *A); expr *new_add(expr *left, expr *right); -/* Helper function to initialize a sum expr (can be used with derived types) */ -void init_sum(expr *node, expr *child, int d1); - expr *new_sum(expr *child, int axis); expr *new_hstack(expr **args, int n_args, int n_vars); diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index a671d86..5e4729f 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -4,6 +4,7 @@ static void forward(expr *node, const double *u) { expr *x = node->left; + /* child's forward pass */ node->left->forward(node->left, u); @@ -19,42 +20,30 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { linear_op_expr *lin_node = (linear_op_expr *) node; - free_csr_matrix(lin_node->A_csr); + /* memory pointing to by A_csr will already be freed when the + jacobian is freed, so free_csr_matrix(lin_node->A_csr) should + be commented out */ + // free_csr_matrix(lin_node->A_csr); free_csc_matrix(lin_node->A_csc); + lin_node->A_csr = NULL; + lin_node->A_csc = NULL; } expr *new_linear(expr *u, const CSR_Matrix *A) { /* Allocate the type-specific struct */ linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr)); - if (!lin_node) return NULL; - expr *node = &lin_node->base; - - /* Initialize base fields */ init_expr(node, A->m, 1, u->n_vars, forward, NULL, NULL, is_affine, free_type_data); - - /* Set left child */ node->left = u; expr_retain(u); - /* Check if allocation succeeded */ - if (!node->value) - { - free(lin_node); - return NULL; - } - - /* allocate jacobian and copy A into it */ - // TODO: this should eventually be removed - node->jacobian = new_csr_matrix(A->m, A->n, A->nnz); - copy_csr_matrix(A, node->jacobian); - /* Initialize type-specific fields */ lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); copy_csr_matrix(A, lin_node->A_csr); lin_node->A_csc = csr_to_csc(A); + node->jacobian = lin_node->A_csr; return node; } diff --git a/src/affine/sum.c b/src/affine/sum.c index 9480ad6..ee9988b 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -145,35 +145,6 @@ static void free_type_data(expr *node) free_int_double_pair_array(snode->int_double_pairs); } -/* Helper function to initialize a sum expr */ -void init_sum(expr *node, expr *child, int d1) -{ - node->d1 = d1; - node->d2 = 1; - node->size = d1 * 1; - node->n_vars = child->n_vars; - node->var_id = -1; - node->refcount = 1; - node->left = child; - node->right = NULL; - node->dwork = NULL; - node->iwork = NULL; - node->value = (double *) calloc(node->size, sizeof(double)); - node->jacobian = NULL; - node->wsum_hess = NULL; - node->jacobian_init = jacobian_init; - node->wsum_hess_init = wsum_hess_init; - node->eval_jacobian = eval_jacobian; - node->eval_wsum_hess = eval_wsum_hess; - node->local_jacobian = NULL; - node->local_wsum_hess = NULL; - node->is_affine = is_affine; - node->forward = forward; - node->free_type_data = free_type_data; - - expr_retain(child); -} - expr *new_sum(expr *child, int axis) { int d1 = 0; @@ -195,20 +166,16 @@ expr *new_sum(expr *child, int axis) } /* Allocate the type-specific struct */ - sum_expr *snode = (sum_expr *) malloc(sizeof(sum_expr)); - if (!snode) return NULL; - + sum_expr *snode = (sum_expr *) calloc(1, sizeof(sum_expr)); expr *node = &snode->base; + init_expr(node, d1, 1, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); + node->left = child; + expr_retain(child); - /* Initialize base sum fields */ - init_sum(node, child, d1); - - /* Check if allocation succeeded */ - if (!node->value) - { - free(snode); - return NULL; - } + /* hessian function pointers */ + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; /* Set type-specific fields */ snode->axis = axis; diff --git a/src/affine/variable.c b/src/affine/variable.c index 5fca186..46e2446 100644 --- a/src/affine/variable.c +++ b/src/affine/variable.c @@ -29,13 +29,10 @@ static bool is_affine(const expr *node) expr *new_variable(int d1, int d2, int var_id, int n_vars) { expr *node = new_expr(d1, d2, n_vars); - if (!node) return NULL; - node->forward = forward; node->var_id = var_id; node->is_affine = is_affine; node->jacobian_init = jacobian_init; - // node->jacobian = NULL; return node; } diff --git a/src/expr.c b/src/expr.c index d235862..fa30213 100644 --- a/src/expr.c +++ b/src/expr.c @@ -38,6 +38,8 @@ void free_expr(expr *node) /* recursively free children */ free_expr(node->left); free_expr(node->right); + node->left = NULL; + node->right = NULL; /* free value array and jacobian */ free(node->value); @@ -45,6 +47,11 @@ void free_expr(expr *node) free_csr_matrix(node->wsum_hess); free(node->dwork); free(node->iwork); + node->value = NULL; + node->jacobian = NULL; + node->wsum_hess = NULL; + node->dwork = NULL; + node->iwork = NULL; /* free type-specific data */ if (node->free_type_data) From 891f1279615a4711f1b7fc2521211fc2b9030948 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 23:44:24 -0800 Subject: [PATCH 07/10] refactored elementwise univariate to resuse already computed values --- include/expr.h | 1 + src/bivariate/multiply.c | 2 +- src/bivariate/quad_over_lin.c | 4 ++-- src/elementwise_univariate/common.c | 14 +++++++------- src/elementwise_univariate/exp.c | 3 +-- src/elementwise_univariate/hyperbolic.c | 9 ++++----- src/elementwise_univariate/logistic.c | 18 +----------------- src/elementwise_univariate/power.c | 3 --- src/elementwise_univariate/trig.c | 18 +++++++----------- src/elementwise_univariate/xexp.c | 4 ++-- src/expr.c | 2 +- src/other/quad_form.c | 4 ++-- 12 files changed, 29 insertions(+), 53 deletions(-) diff --git a/include/expr.h b/include/expr.h index a3154d7..f6381a3 100644 --- a/include/expr.h +++ b/include/expr.h @@ -7,6 +7,7 @@ #include #define JAC_IDXS_NOT_SET -1 +#define NOT_A_VARIABLE -1 /* Function pointer types */ struct expr; diff --git a/src/bivariate/multiply.c b/src/bivariate/multiply.c index 0bed4e9..eae64ed 100644 --- a/src/bivariate/multiply.c +++ b/src/bivariate/multiply.c @@ -32,7 +32,7 @@ static void jacobian_init(expr *node) node->left->jacobian_init(node->left); } - if (node->right->var_id != -1) + if (node->right->var_id != NOT_A_VARIABLE) { node->right->jacobian_init(node->right); } diff --git a/src/bivariate/quad_over_lin.c b/src/bivariate/quad_over_lin.c index 4149ade..20be94f 100644 --- a/src/bivariate/quad_over_lin.c +++ b/src/bivariate/quad_over_lin.c @@ -36,7 +36,7 @@ static void jacobian_init(expr *node) expr *y = node->right; /* if left node is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(1, node->n_vars, x->d1 + 1); node->jacobian->p[0] = 0; @@ -110,7 +110,7 @@ static void eval_jacobian(expr *node) expr *y = node->right; /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { /* if x has lower idx than y*/ if (x->var_id < y->var_id) diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index c0fa30e..986d26a 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -8,7 +8,7 @@ void jacobian_init_elementwise(expr *node) expr *child = node->left; /* if the variable is a child */ - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(node->size, node->n_vars, node->size); for (int j = 0; j < node->size; j++) @@ -18,7 +18,7 @@ void jacobian_init_elementwise(expr *node) } node->jacobian->p[node->size] = node->size; } - /* otherwise it should be a linear operator */ + /* otherwise it will be a linear operator */ else { node->jacobian = new_csr_matrix(child->jacobian->m, child->jacobian->n, @@ -31,13 +31,13 @@ void eval_jacobian_elementwise(expr *node) { expr *child = node->left; - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->local_jacobian(node, node->jacobian->x); } else { - /* Child must be a linear operator */ + /* Child will be a linear operator */ linear_op_expr *lin_child = (linear_op_expr *) child; node->local_jacobian(node, node->dwork); diag_csr_mult(node->dwork, lin_child->A_csr, node->jacobian); @@ -51,7 +51,7 @@ void wsum_hess_init_elementwise(expr *node) int i; /* if the variable is a child*/ - if (id != -1) + if (id != NOT_A_VARIABLE) { node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size); @@ -78,13 +78,13 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) { expr *child = node->left; - if (child->var_id != -1) + if (child->var_id != NOT_A_VARIABLE) { node->local_wsum_hess(node, node->wsum_hess->x, w); } else { - /* Child must be a linear operator */ + /* Child will be a linear operator */ linear_op_expr *lin_child = (linear_op_expr *) child; node->local_wsum_hess(node, node->dwork, w); ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess); diff --git a/src/elementwise_univariate/exp.c b/src/elementwise_univariate/exp.c index b78fcff..3b27c15 100644 --- a/src/elementwise_univariate/exp.c +++ b/src/elementwise_univariate/exp.c @@ -21,10 +21,9 @@ static void local_jacobian(expr *node, double *vals) static void local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; for (int j = 0; j < node->size; j++) { - out[j] = w[j] * exp(x[j]); + out[j] = w[j] * node->value[j]; } } diff --git a/src/elementwise_univariate/hyperbolic.c b/src/elementwise_univariate/hyperbolic.c index e7ce823..c0bd9fa 100644 --- a/src/elementwise_univariate/hyperbolic.c +++ b/src/elementwise_univariate/hyperbolic.c @@ -23,10 +23,9 @@ static void sinh_local_jacobian(expr *node, double *vals) static void sinh_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; for (int j = 0; j < node->size; j++) { - out[j] = w[j] * sinh(x[j]); + out[j] = w[j] * node->value[j]; } } @@ -51,10 +50,10 @@ static void tanh_forward(expr *node, const double *u) static void tanh_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - double c = cosh(child->value[j]); + double c = cosh(x[j]); vals[j] = 1.0 / (c * c); } } @@ -65,7 +64,7 @@ static void tanh_local_wsum_hess(expr *node, double *out, const double *w) for (int j = 0; j < node->size; j++) { double c = cosh(x[j]); - out[j] = w[j] * (-2.0 * tanh(x[j]) / (c * c)); + out[j] = w[j] * (-2.0 * node->value[j] / (c * c)); } } diff --git a/src/elementwise_univariate/logistic.c b/src/elementwise_univariate/logistic.c index f5479a3..ba7a64d 100644 --- a/src/elementwise_univariate/logistic.c +++ b/src/elementwise_univariate/logistic.c @@ -42,10 +42,9 @@ static void local_jacobian(expr *node, double *vals) static void local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; double *sigmas; - if (node->left->var_id != -1) + if (node->left->var_id != NOT_A_VARIABLE) { sigmas = node->jacobian->x; } @@ -54,23 +53,8 @@ static void local_wsum_hess(expr *node, double *out, const double *w) sigmas = node->dwork; } - // double sigma; - for (int j = 0; j < node->size; j++) { - /* - if (x[j] >= 0) - { - sigma = 1.0 / (1.0 + exp(-x[j])); - } - else - { - double exp_x = exp(x[j]); - sigma = exp_x / (1.0 + exp_x); - } - - out[j] = w[j] * sigma * (1.0 - sigma); - */ out[j] = sigmas[j] * (1.0 - sigmas[j]) * w[j]; } } diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index 16c0e80..c8463e8 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -45,10 +45,7 @@ expr *new_power(expr *child, int p) /* Allocate the type-specific struct */ power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr)); expr *node = &pnode->base; - - /* Initialize base elementwise fields */ init_elementwise(node, child); - node->forward = forward; node->local_jacobian = local_jacobian; node->local_wsum_hess = local_wsum_hess; diff --git a/src/elementwise_univariate/trig.c b/src/elementwise_univariate/trig.c index c1e892c..900dd2b 100644 --- a/src/elementwise_univariate/trig.c +++ b/src/elementwise_univariate/trig.c @@ -13,20 +13,18 @@ static void sin_forward(expr *node, const double *u) static void sin_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = cos(child->value[j]); + vals[j] = cos(x[j]); } } static void sin_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; - for (int j = 0; j < node->size; j++) { - out[j] = -w[j] * sin(x[j]); + out[j] = -w[j] * node->value[j]; } } @@ -51,20 +49,18 @@ static void cos_forward(expr *node, const double *u) static void cos_local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = -sin(child->value[j]); + vals[j] = -sin(x[j]); } } static void cos_local_wsum_hess(expr *node, double *out, const double *w) { - double *x = node->left->value; - for (int j = 0; j < node->size; j++) { - out[j] = -w[j] * cos(x[j]); + out[j] = -w[j] * node->value[j]; } } @@ -104,7 +100,7 @@ static void tan_local_wsum_hess(expr *node, double *out, const double *w) for (int j = 0; j < node->size; j++) { double c = cos(x[j]); - out[j] = 2.0 * w[j] * tan(x[j]) / (c * c); + out[j] = 2.0 * w[j] * node->value[j] / (c * c); } } diff --git a/src/elementwise_univariate/xexp.c b/src/elementwise_univariate/xexp.c index 88c314f..cb3fdb6 100644 --- a/src/elementwise_univariate/xexp.c +++ b/src/elementwise_univariate/xexp.c @@ -16,10 +16,10 @@ static void forward(expr *node, const double *u) static void local_jacobian(expr *node, double *vals) { - expr *child = node->left; + double *x = node->left->value; for (int j = 0; j < node->size; j++) { - vals[j] = (1.0 + child->value[j]) * exp(child->value[j]); + vals[j] = (1.0 + x[j]) * exp(x[j]); } } diff --git a/src/expr.c b/src/expr.c index fa30213..0815825 100644 --- a/src/expr.c +++ b/src/expr.c @@ -13,7 +13,7 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, node->n_vars = n_vars; node->refcount = 1; node->value = (double *) calloc(d1 * d2, sizeof(double)); - node->var_id = -1; + node->var_id = NOT_A_VARIABLE; node->forward = forward; node->jacobian_init = jacobian_init; node->eval_jacobian = eval_jacobian; diff --git a/src/other/quad_form.c b/src/other/quad_form.c index eefa5df..13cf9c8 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -31,7 +31,7 @@ static void jacobian_init(expr *node) node->dwork = (double *) malloc(x->d1 * sizeof(double)); /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { node->jacobian = new_csr_matrix(1, node->n_vars, x->d1); node->jacobian->p[0] = 0; @@ -77,7 +77,7 @@ static void eval_jacobian(expr *node) CSR_Matrix *Q = ((quad_form_expr *) node)->Q; /* if x is a variable */ - if (x->var_id != -1) + if (x->var_id != NOT_A_VARIABLE) { csr_matvec(Q, x->value, node->jacobian->x, 0); From 7062b8fc0f6029181e93c9810e4e715753339778 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Jan 2026 23:46:25 -0800 Subject: [PATCH 08/10] removed COO --- include/utils/COO_Matrix.h | 16 ---------------- src/utils/COO_Matrix.c | 33 --------------------------------- 2 files changed, 49 deletions(-) delete mode 100644 include/utils/COO_Matrix.h delete mode 100644 src/utils/COO_Matrix.c diff --git a/include/utils/COO_Matrix.h b/include/utils/COO_Matrix.h deleted file mode 100644 index 6e922e3..0000000 --- a/include/utils/COO_Matrix.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef MATRIX_H -#define MATRIX_H - -/* Coordinate (COO) format sparse matrix */ -typedef struct COO_Matrix -{ - int nnz; - int *rows; - int *cols; - double *vals; -} COO_Matrix; - -COO_Matrix *new_coo_matrix(int nnz); -void free_coo_matrix(COO_Matrix *matrix); - -#endif /* MATRIX_H */ diff --git a/src/utils/COO_Matrix.c b/src/utils/COO_Matrix.c deleted file mode 100644 index 9162bb0..0000000 --- a/src/utils/COO_Matrix.c +++ /dev/null @@ -1,33 +0,0 @@ -#include "utils/COO_Matrix.h" -#include - -COO_Matrix *new_coo_matrix(int nnz) -{ - COO_Matrix *matrix = (COO_Matrix *) malloc(sizeof(COO_Matrix)); - if (!matrix) return NULL; - - matrix->nnz = nnz; - matrix->rows = (int *) calloc(nnz, sizeof(int)); - matrix->cols = (int *) calloc(nnz, sizeof(int)); - matrix->vals = (double *) calloc(nnz, sizeof(double)); - - if (!matrix->rows || !matrix->cols || !matrix->vals) - { - free(matrix->rows); - free(matrix->cols); - free(matrix->vals); - free(matrix); - return NULL; - } - - return matrix; -} - -void free_coo_matrix(COO_Matrix *matrix) -{ - if (!matrix) return; - free(matrix->rows); - free(matrix->cols); - free(matrix->vals); - free(matrix); -} From f8f3fdce4a183fdc6831c71af22117cc909b3ac5 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 3 Jan 2026 00:07:30 -0800 Subject: [PATCH 09/10] finished cleaning up elementwise univariate and affine atoms --- README.md | 4 +--- include/elementwise_univariate.h | 2 +- include/subexpr.h | 2 +- src/elementwise_univariate/power.c | 2 +- src/other/quad_form.c | 11 +++-------- tests/wsum_hess/test_power.h | 2 +- 6 files changed, 8 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 8e8b174..d1d7b1f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ -1. power should be double -2. can we reuse calculations, like in hessian of logistic 3. more tests for chain rule elementwise univariate hessian 4. in the refactor, add consts 5. multiply with one constant vector/scalar argument 6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen. -7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code? \ No newline at end of file +7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code! This requires new infrastructure, I think. \ No newline at end of file diff --git a/include/elementwise_univariate.h b/include/elementwise_univariate.h index 4e22892..f0ed791 100644 --- a/include/elementwise_univariate.h +++ b/include/elementwise_univariate.h @@ -18,7 +18,7 @@ expr *new_tanh(expr *child); expr *new_asinh(expr *child); expr *new_atanh(expr *child); expr *new_logistic(expr *child); -expr *new_power(expr *child, int p); +expr *new_power(expr *child, double p); expr *new_xexp(expr *child); /* the jacobian and wsum_hess for elementwise univariate atoms are always diff --git a/include/subexpr.h b/include/subexpr.h index b1f37ab..4f56ea4 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -22,7 +22,7 @@ typedef struct linear_op_expr typedef struct power_expr { expr base; - int p; + double p; } power_expr; /* Quadratic form: y = x'*Q*x */ diff --git a/src/elementwise_univariate/power.c b/src/elementwise_univariate/power.c index c8463e8..c977749 100644 --- a/src/elementwise_univariate/power.c +++ b/src/elementwise_univariate/power.c @@ -40,7 +40,7 @@ static void local_wsum_hess(expr *node, double *out, const double *w) } } -expr *new_power(expr *child, int p) +expr *new_power(expr *child, double p) { /* Allocate the type-specific struct */ power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr)); diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 13cf9c8..9eb74a3 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -5,13 +5,11 @@ #include #include -/* Note: Q is not freed here because it's owned by the caller */ - static void forward(expr *node, const double *u) { expr *x = node->left; - /* children's forward passes */ + /* child's forward pass */ x->forward(x, u); /* local forward pass */ @@ -45,8 +43,7 @@ static void jacobian_init(expr *node) else /* x is not a variable */ { /* compute required allocation and allocate jacobian */ - bool *col_nz = (bool *) calloc( - node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/ + bool *col_nz = (bool *) calloc(node->n_vars, sizeof(bool)); int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz); node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1); @@ -65,9 +62,6 @@ static void jacobian_init(expr *node) node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; - - /* Cast x to linear operator to use its A_csc in eval_jacobian */ - node->iwork = (int *) malloc(sizeof(int)); } } @@ -109,6 +103,7 @@ expr *new_quad_form(expr *left, CSR_Matrix *Q) expr *node = &qnode->base; /* Initialize base fields */ + assert(left->d2 == 1); init_expr(node, left->d1, 1, left->n_vars, forward, jacobian_init, eval_jacobian, NULL, NULL); diff --git a/tests/wsum_hess/test_power.h b/tests/wsum_hess/test_power.h index a8f82c5..f201e8c 100644 --- a/tests/wsum_hess/test_power.h +++ b/tests/wsum_hess/test_power.h @@ -15,7 +15,7 @@ const char *test_wsum_hess_power() double w[3] = {1.0, 2.0, 3.0}; expr *x = new_variable(3, 1, 0, 3); - expr *power_node = new_power(x, 3); + expr *power_node = new_power(x, 3.0); power_node->forward(power_node, u_vals); power_node->wsum_hess_init(power_node); power_node->eval_wsum_hess(power_node, w); From 3a23c26263256f0b2120621978c499b3eaa4897a Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 3 Jan 2026 00:13:02 -0800 Subject: [PATCH 10/10] added comment --- src/affine/linear_op.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 5e4729f..4e9caa8 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -43,6 +43,8 @@ expr *new_linear(expr *u, const CSR_Matrix *A) lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); copy_csr_matrix(A, lin_node->A_csr); lin_node->A_csc = csr_to_csc(A); + + /* what if we have A @ phi(x). Then I don't think this is correct. */ node->jacobian = lin_node->A_csr; return node;