Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/bivariate.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ expr *new_quad_over_lin(expr *left, expr *right);
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);

/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
expr *new_left_matmul(expr *u, const CSR_Matrix *A);

/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
expr *new_right_matmul(expr *u, const CSR_Matrix *A);

#endif /* BIVARIATE_H */
22 changes: 22 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,26 @@ typedef struct elementwise_mult_expr
CSR_Matrix *CSR_work2;
} elementwise_mult_expr;

/* Left matrix multiplication: y = A * f(x) where f(x) is an expression. Note that
here A uses local column indices rather than global ones; this is an
important distinction compared to linear_op_expr. */
typedef struct left_matmul_expr
{
    expr base;            /* common expression-node header; must be first */
    CSR_Matrix *A;        /* A_kron = I_p kron A, applied in the forward pass */
    CSR_Matrix *AT;       /* transpose of A_kron, used to backpropagate Hessian weights */
    CSC_Matrix *CSC_work; /* CSC copy of the child's jacobian; NULL until jacobian_init runs */
} left_matmul_expr;

/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.
* f(x) has shape p x n, A has shape n x q, output y has shape p x q.
* Uses vec(y) = B * vec(f(x)) where B = A^T kron I_p. */
typedef struct right_matmul_expr
{
expr base;
CSR_Matrix *B; /* B = A^T kron I_p */
CSR_Matrix *BT; /* B^T for backpropagating Hessian weights */
CSC_Matrix *CSC_work;
} right_matmul_expr;

#endif /* SUBEXPR_H */
14 changes: 14 additions & 0 deletions include/utils/CSC_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,18 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
*/
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);

CSC_Matrix *csr_to_csc_fill_sparsity(const CSR_Matrix *A, int *iwork);
void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork);

/* Allocate CSR matrix for C = A @ B where A is CSR, B is CSC
* Precomputes sparsity pattern. No workspace required.
*/
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);

/* Fill values of C = A @ B where A is CSR, B is CSC
* C must have sparsity pattern already computed
*/
void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
CSR_Matrix *C);

#endif /* CSC_MATRIX_H */
10 changes: 10 additions & 0 deletions include/utils/CSR_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,18 @@ void free_csr_matrix(CSR_Matrix *matrix);
/* Copy CSR matrix A to C */
void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C);

/* Build block-diagonal repeat A_blk = I_p kron A. Returns newly allocated CSR
* matrix of size (p*A->m) x (p*A->n) with nnz = p*A->nnz. */
CSR_Matrix *block_diag_repeat_csr(const CSR_Matrix *A, int p);

/* Build left-repeated Kronecker A_kron = A kron I_p. Returns newly allocated CSR
* matrix of size (A->m * p) x (A->n * p) with nnz = A->nnz * p. */
CSR_Matrix *kron_identity_csr(const CSR_Matrix *A, int p);

/* matvec y = Ax, where A indices minus col_offset gives x indices. Returns y as
* dense. */
void csr_matvec(const CSR_Matrix *A, const double *x, double *y, int col_offset);
void csr_matvec_wo_offset(const CSR_Matrix *A, const double *x, double *y);

/* C = z^T A is assumed to have one row. C must have column indices pre-computed
and transposed matrix AT must be provided. Fills in values of C only.
Expand Down Expand Up @@ -141,6 +150,7 @@ void insert_idx(int idx, int *arr, int len);

double csr_get_value(const CSR_Matrix *A, int row, int col);

/* iwork must be of size A->n*/
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork);

Expand Down
26 changes: 17 additions & 9 deletions src/affine/linear_op.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "affine.h"
#include <assert.h>
#include <stdlib.h>

static void forward(expr *node, const double *u)
Expand All @@ -9,7 +10,7 @@ static void forward(expr *node, const double *u)
node->left->forward(node->left, u);

/* y = A * x */
csr_matvec(node->jacobian, x->value, node->value, x->var_id);
csr_matvec(((linear_op_expr *) node)->A_csr, x->value, node->value, x->var_id);
}

static bool is_affine(const expr *node)
Expand All @@ -20,21 +21,31 @@ static bool is_affine(const expr *node)
static void free_type_data(expr *node)
{
linear_op_expr *lin_node = (linear_op_expr *) node;
/* memory pointing to by A_csr will already be freed when the
jacobian is freed, so free_csr_matrix(lin_node->A_csr) should
be commented out */
// free_csr_matrix(lin_node->A_csr);
/* memory pointed to by A_csr will be freed when the jacobian is freed,
so if the jacobian is not null we must not free A_csr here. */

if (!node->jacobian)
{
free_csr_matrix(lin_node->A_csr);
}

free_csc_matrix(lin_node->A_csc);
lin_node->A_csr = NULL;
lin_node->A_csc = NULL;
}

static void jacobian_init(expr *node)
{
node->jacobian = ((linear_op_expr *) node)->A_csr;
}

expr *new_linear(expr *u, const CSR_Matrix *A)
{
assert(u->d2 == 1);
/* Allocate the type-specific struct */
linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr));
expr *node = &lin_node->base;
init_expr(node, A->m, 1, u->n_vars, forward, NULL, NULL, is_affine,
init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init, NULL, is_affine,
free_type_data);
node->left = u;
expr_retain(u);
Expand All @@ -44,8 +55,5 @@ expr *new_linear(expr *u, const CSR_Matrix *A)
copy_csr_matrix(A, lin_node->A_csr);
lin_node->A_csc = csr_to_csc(A);

/* what if we have A @ phi(x). Then I don't think this is correct. */
node->jacobian = lin_node->A_csr;

return node;
}
136 changes: 136 additions & 0 deletions src/bivariate/left_matmul.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "bivariate.h"
#include "subexpr.h"
#include <stdlib.h>
#include <string.h> /* memcpy in wsum_hess_init / eval_wsum_hess */

/* This file implements the atom 'left_matmul' corresponding to the operation y =
A @ f(x), where A is a given matrix and f(x) is an arbitrary expression.
Here, f(x) can be either a vector-valued or a matrix-valued
expression. The dimensions are A - m x n, f(x) - n x p, y - m x p.
Note that here A uses local column indices rather than global ones.
This is an important distinction compared to linear_op_expr.

* To compute the forward pass: vec(y) = A_kron @ vec(f(x)),
where A_kron = I_p kron A is a Kronecker product of size (m*p) x (n*p),
or more specifically, a block-diagonal matrix with p copies of A along the
diagonal.

* To compute the Jacobian: J_y = A_kron @ J_f(x), where J_f(x) is the
Jacobian of f(x) of size (n*p) x n_vars.

* To compute the contribution to the Lagrange Hessian: we form
w = A_kron^T @ lambda and then evaluate the hessian of f(x).

Working in terms of A_kron unifies the implementation of f(x) being
vector-valued or matrix-valued.


*/

// todo: put this in common somewhere
#define MAX(a, b) ((a) > (b) ? (a) : (b))

/* Forward pass: evaluate the child, then apply the block-diagonal
 * operator, i.e. vec(y) = A_kron @ vec(f(x)). */
static void forward(expr *node, const double *u)
{
    left_matmul_expr *self = (left_matmul_expr *) node;
    expr *child = node->left;

    /* child's forward pass */
    child->forward(child, u);

    /* y = A_kron @ vec(f(x)); A_kron carries local column indices */
    csr_matvec_wo_offset(self->A, child->value, node->value);
}

static bool is_affine(const expr *node)
{
return node->left->is_affine(node->left);
}

/* Release the matrices owned by this node. CSC_work is allocated lazily
 * in jacobian_init, so it may still be NULL here. */
static void free_type_data(expr *node)
{
    left_matmul_expr *self = (left_matmul_expr *) node;

    free_csr_matrix(self->A);
    self->A = NULL;

    free_csr_matrix(self->AT);
    self->AT = NULL;

    if (self->CSC_work != NULL)
    {
        free_csc_matrix(self->CSC_work);
        self->CSC_work = NULL;
    }
}

/* Allocate this node's jacobian. J_y = A_kron @ J_f(x), so we first build
 * the child's jacobian, cache a CSC view of it, and precompute the sparsity
 * of the CSR-times-CSC product. */
static void jacobian_init(expr *node)
{
    left_matmul_expr *self = (left_matmul_expr *) node;
    expr *child = node->left;

    /* initialize the child's jacobian and the CSC sparsity of it */
    child->jacobian_init(child);
    self->CSC_work = csr_to_csc_fill_sparsity(child->jacobian, node->iwork);

    /* allocate this node's jacobian with the product's sparsity pattern */
    node->jacobian = csr_csc_matmul_alloc(self->A, self->CSC_work);
}

/* Fill in the numeric values of J_y = A_kron @ J_f(x); sparsity was fixed
 * in jacobian_init. */
static void eval_jacobian(expr *node)
{
    left_matmul_expr *self = (left_matmul_expr *) node;
    expr *child = node->left;

    /* refresh the child's jacobian values and mirror them into CSC form */
    child->eval_jacobian(child);
    csr_to_csc_fill_values(child->jacobian, self->CSC_work, node->iwork);

    /* this node's jacobian = A_kron @ J_f(x) */
    csr_csc_matmul_fill_values(self->A, self->CSC_work, node->jacobian);
}

/* Allocate storage for the weighted-sum-of-Hessians of this node.
 * Because y = A_kron @ f(x) is linear in f(x), the weighted Hessian of y has
 * exactly the sparsity of the child's weighted Hessian, so we clone the
 * child's pattern (row pointers p and column indices i) and fill values
 * later in eval_wsum_hess. */
static void wsum_hess_init(expr *node)
{
/* initialize child's hessian */
expr *x = node->left;
x->wsum_hess_init(x);

/* allocate this node's hessian with the same sparsity as child's */
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
memcpy(node->wsum_hess->p, x->wsum_hess->p, (node->n_vars + 1) * sizeof(int));
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));

/* scratch buffer for the backpropagated weights w' = A_kron^T @ w,
   which have length A_kron->n (= n*p) */
int A_n = ((left_matmul_expr *) node)->A->n;
node->dwork = (double *) malloc(A_n * sizeof(double));
}

/* Evaluate the weighted sum of Hessians. Since y = A_kron @ f(x) is linear
 * in f(x), sum_k w_k Hess(y_k) = sum_j (A_kron^T w)_j Hess(f_j), so we
 * backpropagate the weights through A_kron^T and delegate to the child. */
static void eval_wsum_hess(expr *node, const double *w)
{
    left_matmul_expr *self = (left_matmul_expr *) node;
    expr *child = node->left;

    /* dwork = A_kron^T @ w */
    csr_matvec_wo_offset(self->AT, w, node->dwork);

    /* child's weighted Hessian with the backpropagated weights; this node's
       Hessian shares the child's sparsity, so only values are copied */
    child->eval_wsum_hess(child, node->dwork);
    memcpy(node->wsum_hess->x, child->wsum_hess->x,
           node->wsum_hess->nnz * sizeof(double));
}

/* Construct a left-matmul node y = A @ f(x).
 * A is m x n with local column indices; u = f(x) has d2 columns. The stored
 * operator is A_kron = I_p kron A (p = u->d2) so vector- and matrix-valued
 * children are handled uniformly. Retains a reference on u; copies A. */
expr *new_left_matmul(expr *u, const CSR_Matrix *A)
{
    left_matmul_expr *self =
        (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
    expr *node = &self->base;

    init_expr(node, A->m, u->d2, u->n_vars, forward, jacobian_init, eval_jacobian,
              is_affine, free_type_data);
    node->left = u;
    expr_retain(u);
    node->wsum_hess_init = wsum_hess_init;
    node->eval_wsum_hess = eval_wsum_hess;

    /* block-diagonal expansion A_kron = I_p kron A */
    self->A = block_diag_repeat_csr(A, node->d2);

    /* iwork is shared by transpose (needs A_kron->n slots) and the jacobian
       CSC conversion (needs n_vars slots): size it for the larger of the two */
    int iwork_len = MAX(self->A->n, node->n_vars);
    node->iwork = (int *) malloc(iwork_len * sizeof(int));

    self->AT = transpose(self->A, node->iwork);
    self->CSC_work = NULL; /* allocated lazily in jacobian_init */

    return node;
}
17 changes: 5 additions & 12 deletions src/bivariate/multiply.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,8 @@ static void forward(expr *node, const double *u)

static void jacobian_init(expr *node)
{
/* if a child is a variable we initialize its jacobian for a
short chain rule implementation */
if (node->left->var_id != NOT_A_VARIABLE)
{
node->left->jacobian_init(node->left);
}

if (node->right->var_id != NOT_A_VARIABLE)
{
node->right->jacobian_init(node->right);
}

node->left->jacobian_init(node->left);
node->right->jacobian_init(node->right);
node->dwork = (double *) malloc(2 * node->size * sizeof(double));
int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz;
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);
Expand Down Expand Up @@ -180,6 +170,9 @@ static void eval_wsum_hess(expr *node, const double *w)
/* Compute CT = C^T = A^T diag(w) B */
AT_fill_values(C, CT, node->iwork);

/* TODO: should fill sparsity before. Maybe we can map values from C to
* hessian directly instead?*/

/* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */
sum_csr_matrices(C, CT, node->wsum_hess);
}
Expand Down
2 changes: 1 addition & 1 deletion src/bivariate/quad_over_lin.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static void jacobian_init(expr *node)
/* compute required allocation and allocate jacobian */
bool *col_nz = (bool *) calloc(
node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/
int nonzero_cols = count_nonzero_cols(lin_x->base.jacobian, col_nz);
int nonzero_cols = count_nonzero_cols(lin_x->A_csr, col_nz);
node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1);

/* precompute column indices */
Expand Down
Loading
Loading