Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
1. power should be double
2. can we reuse calculations, like in hessian of logistic
3. more tests for chain rule elementwise univariate hessian
4. in the refactor, add consts
Copy link
Collaborator

@Transurgeon Transurgeon Jan 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this done? Did you mean to add the type hints (I am not sure if that's what we call them) for the arguments to functions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed from "int p" to "double p" in power_expr, so now it's the correct data type

5. multiply with one constant vector/scalar argument
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code! This requires new infrastructure, I think.
2 changes: 2 additions & 0 deletions include/affine.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#define AFFINE_H

#include "expr.h"
#include "subexpr.h"
#include "utils/CSR_Matrix.h"

expr *new_linear(expr *u, const CSR_Matrix *A);

expr *new_add(expr *left, expr *right);

expr *new_sum(expr *child, int axis);
expr *new_hstack(expr **args, int n_args, int n_vars);

Expand Down
10 changes: 7 additions & 3 deletions include/elementwise_univariate.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

#include "expr.h"

/* Helper function to initialize an elementwise expr (can be used with derived types)
*/
void init_elementwise(expr *node, expr *child);

expr *new_exp(expr *child);
expr *new_log(expr *child);
expr *new_entr(expr *child);
Expand All @@ -14,19 +18,19 @@ expr *new_tanh(expr *child);
expr *new_asinh(expr *child);
expr *new_atanh(expr *child);
expr *new_logistic(expr *child);
expr *new_power(expr *child, int p);
expr *new_power(expr *child, double p);
expr *new_xexp(expr *child);

/* the jacobian and wsum_hess for elementwise univariate atoms are always
initialized in the same way and implement the chain rule in the same way */
void jacobian_init_elementwise(expr *node);
void eval_jacobian_elementwise(expr *node);
void wsum_hess_init_elementwise(expr *node);
void eval_wsum_hess_elementwise(expr *node, double *w);
void eval_wsum_hess_elementwise(expr *node, const double *w);
expr *new_elementwise(expr *child);

/* no elementwise atoms are affine according to our convention,
so we can have a common implementation */
bool is_affine_elementwise(expr *node);
bool is_affine_elementwise(const expr *node);

#endif /* ELEMENTWISE_UNIVARIATE_H */
54 changes: 21 additions & 33 deletions include/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,58 @@
#include <stddef.h>

#define JAC_IDXS_NOT_SET -1

/* Forward declarations */
struct expr;
struct int_double_pair;
#define NOT_A_VARIABLE -1

/* Function pointer types */
struct expr;
typedef void (*forward_fn)(struct expr *node, const double *u);
typedef void (*jacobian_init_fn)(struct expr *node);
typedef void (*wsum_hess_init_fn)(struct expr *node);
typedef void (*eval_jacobian_fn)(struct expr *node);
typedef void (*wsum_hess_fn)(struct expr *node, double *w);
typedef void (*wsum_hess_fn)(struct expr *node, const double *w);
typedef void (*local_jacobian_fn)(struct expr *node, double *out);
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w);
typedef bool (*is_affine_fn)(struct expr *node);

/* TODO: implement proper polymorphism */
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w);
typedef bool (*is_affine_fn)(const struct expr *node);
typedef void (*free_type_data_fn)(struct expr *node);

/* Expression node structure */
/* Base expression node structure - contains only common fields */
typedef struct expr
{
// ------------------------------------------------------------------------
// general quantities
// ------------------------------------------------------------------------
int d1, d2, size;
int n_vars;
int var_id;
int refcount;
int d1, d2, size, n_vars, refcount, var_id;
struct expr *left;
struct expr *right;
struct expr **args; /* hstack can have multiple arguments */
int n_args;
double *dwork;
int *iwork;
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
int p; /* power of power expression */
int axis; /* axis for sum or similar operations */
CSR_Matrix *Q; /* Q for quad_form */

// ------------------------------------------------------------------------
// forward pass related quantities
// oracle related quantities
// ------------------------------------------------------------------------
double *value;
forward_fn forward;

// ------------------------------------------------------------------------
// jacobian related quantities
// ------------------------------------------------------------------------
CSR_Matrix *jacobian;
CSR_Matrix *wsum_hess;
CSR_Matrix *CSR_work;
forward_fn forward;
jacobian_init_fn jacobian_init;
wsum_hess_init_fn wsum_hess_init;
eval_jacobian_fn eval_jacobian;
wsum_hess_fn eval_wsum_hess;
local_jacobian_fn local_jacobian;
local_wsum_hess_fn local_wsum_hess;
is_affine_fn is_affine;

// for every linear operator we store A in CSR and CSC
CSC_Matrix *A_csc;
CSR_Matrix *A_csr;
// ------------------------------------------------------------------------
// other things
// ------------------------------------------------------------------------
is_affine_fn is_affine;
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */

} expr;

void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian,
is_affine_fn is_affine, free_type_data_fn free_type_data);

expr *new_expr(int d1, int d2, int n_vars);
void free_expr(expr *node);

Expand Down
1 change: 1 addition & 0 deletions include/other.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define OTHER_H

#include "expr.h"
#include "subexpr.h"
#include "utils/CSR_Matrix.h"

expr *new_quad_form(expr *child, CSR_Matrix *Q);
Expand Down
51 changes: 51 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef SUBEXPR_H
#define SUBEXPR_H

#include "expr.h"
#include "utils/CSC_Matrix.h"
#include "utils/CSR_Matrix.h"

/* Forward declaration */
struct int_double_pair;

/* Type-specific expression structures that "inherit" from expr */

/* Linear operator: y = A * x */
typedef struct linear_op_expr
{
expr base;
CSC_Matrix *A_csc;
CSR_Matrix *A_csr;
} linear_op_expr;

/* Power: y = x^p */
typedef struct power_expr
{
expr base;
double p;
} power_expr;

/* Quadratic form: y = x'*Q*x */
typedef struct quad_form_expr
{
expr base;
CSR_Matrix *Q;
} quad_form_expr;

/* Sum reduction along an axis */
typedef struct sum_expr
{
expr base;
int axis;
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
} sum_expr;

/* Horizontal stack (concatenate) */
typedef struct hstack_expr
{
expr base;
expr **args;
int n_args;
} hstack_expr;

#endif /* SUBEXPR_H */
16 changes: 0 additions & 16 deletions include/utils/COO_Matrix.h

This file was deleted.

5 changes: 5 additions & 0 deletions include/utils/CSC_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,9 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
*/
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);

/* C = z^T A where A is in CSC format and C is assumed to have one row.
* C must have column indices pre-computed. Fills in values of C only.
*/
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);

#endif /* CSC_MATRIX_H */
10 changes: 2 additions & 8 deletions src/affine/add.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static void wsum_hess_init(expr *node)
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
}

static void eval_wsum_hess(expr *node, double *w)
static void eval_wsum_hess(expr *node, const double *w)
{
/* evaluate children's wsum_hess */
node->left->eval_wsum_hess(node->left, w);
Expand All @@ -55,20 +55,14 @@ static void eval_wsum_hess(expr *node, double *w)
sum_csr_matrices(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
}

static bool is_affine(expr *node)
static bool is_affine(const expr *node)
{
return node->left->is_affine(node->left) && node->right->is_affine(node->right);
}

expr *new_add(expr *left, expr *right)
{
if (!left || !right) return NULL;
if (left->d1 != right->d1) return NULL;
if (left->d2 != right->d2) return NULL;

expr *node = new_expr(left->d1, left->d2, left->n_vars);
if (!node) return NULL;

node->left = left;
node->right = right;
expr_retain(left);
Expand Down
8 changes: 2 additions & 6 deletions src/affine/constant.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@ static void forward(expr *node, const double *u)
(void) u;
}

static bool is_affine(expr *node)
static bool is_affine(const expr *node)
{
(void) node;
return true; /* constant is affine */
return true;
}

expr *new_constant(int d1, int d2, int n_vars, const double *values)
{
expr *node = new_expr(d1, d2, n_vars);
if (!node) return NULL;

/* Copy constant values */
memcpy(node->value, values, d1 * d2 * sizeof(double));

node->forward = forward;
node->is_affine = is_affine;

Expand Down
Loading