Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,12 @@ typedef struct hstack_expr
CSR_Matrix *CSR_work; /* for summing Hessians of children */
} hstack_expr;

/* Elementwise multiplication */
typedef struct elementwise_mult_expr
{
expr base;
CSR_Matrix *CSR_work1;
CSR_Matrix *CSR_work2;
} elementwise_mult_expr;

#endif /* SUBEXPR_H */
13 changes: 9 additions & 4 deletions include/utils/CSC_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A);
/* Allocate sparsity pattern for C = A^T D A for diagonal D */
CSR_Matrix *ATA_alloc(const CSC_Matrix *A);

/* Compute values for C = A^T D A
* C must have precomputed sparsity pattern
*/
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
/* Allocate sparsity pattern for C = B^T D A for diagonal D */
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B);

/* Compute values for C = A^T D A. C must have precomputed sparsity pattern */
void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);

/* Compute values for C = B^T D A. C must have precomputed sparsity pattern */
void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
CSR_Matrix *C);

/* C = z^T A where A is in CSC format and C is assumed to have one row.
* C must have column indices pre-computed. Fills in values of C only.
Expand Down
4 changes: 4 additions & 0 deletions include/utils/CSR_Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ void insert_idx(int idx, int *arr, int len);
double csr_get_value(const CSR_Matrix *A, int row, int col);

CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork);

/* Fill values of A^T given sparsity pattern is already computed */
void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork);

/* Expand symmetric CSR matrix A to full matrix C. A is assumed to store
only upper triangle. C must be pre-allocated with sufficient nnz */
Expand Down
3 changes: 2 additions & 1 deletion python/bindings.c
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ static PyMethodDef DNLPMethods[] = {
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
{"forward", py_forward, METH_VARARGS, "Run forward pass and return values"},
{"jacobian", py_jacobian, METH_VARARGS, "Compute jacobian and return CSR components"},
{"jacobian", py_jacobian, METH_VARARGS,
"Compute jacobian and return CSR components"},
{NULL, NULL, 0, NULL}};

static struct PyModuleDef dnlp_module = {PyModuleDef_HEAD_INIT, "DNLP_diff_engine",
Expand Down
148 changes: 144 additions & 4 deletions src/bivariate/multiply.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "bivariate.h"
#include "subexpr.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// ------------------------------------------------------------------------------
// Implementation of elementwise multiplication when both arguments are vectors.
Expand Down Expand Up @@ -52,15 +56,151 @@ static void eval_jacobian(expr *node)
x->value);
}

static void wsum_hess_init(expr *node)
{
expr *x = node->left;
expr *y = node->right;

/* for correctness x and y must be (1) different variables,
or (2) both must be linear operators */
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we not have (Ax) * y? as long as they have the same shape? So not necessarily both linear operators?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now we don't support this, because the second argument must be a "global" linear operator for the product rule to be correct. But we can form the B matrix representing y internally if only the left node is a linear operator. I'll add this as an issue so we do this later.

#ifndef DEBUG
if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE &&
x->var_id == y->var_id)
{
fprintf(stderr, "Error: elementwise multiplication of a variable by itself "
"not supported.\n");
exit(1);
}
else if ((x->var_id != NOT_A_VARIABLE && y->var_id == NOT_A_VARIABLE) ||
(x->var_id == NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE))
{
fprintf(stderr, "Error: elementwise multiplication of a variable by a "
"non-variable is not supported. (Both must be inserted "
"as linear operators)\n");
exit(1);
}
#endif

/* both x and y are variables*/
if (x->var_id != NOT_A_VARIABLE)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 2 * node->size);

int i, var1_id, var2_id;

if (x->var_id < y->var_id)
{
var1_id = x->var_id;
var2_id = y->var_id;
}
else
{
var1_id = y->var_id;
var2_id = x->var_id;
}

/* var1 rows of Hessian */
for (i = 0; i < node->size; i++)
{
node->wsum_hess->p[var1_id + i] = i;
node->wsum_hess->i[i] = var2_id + i;
}

int nnz = node->size;

/* rows between var1 and var2 */
for (i = var1_id + node->size; i < var2_id; i++)
{
node->wsum_hess->p[i] = nnz;
}

/* var2 rows of Hessian */
for (i = 0; i < node->size; i++)
{
node->wsum_hess->p[var2_id + i] = nnz + i;
node->wsum_hess->i[nnz + i] = var1_id + i;
}

/* remaining rows */
nnz += node->size;
for (i = var2_id + node->size; i <= node->n_vars; i++)
{
node->wsum_hess->p[i] = nnz;
}
}
else
{
/* both are linear operators */
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;

/* Allocate workspace for Hessian computation */
elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
CSR_Matrix *C; /* C = B^T diag(w) A */
C = BTA_alloc(A, B);
node->iwork = (int *) malloc(C->m * sizeof(int));

CSR_Matrix *CT = AT_alloc(C, node->iwork);
mul_node->CSR_work1 = C;
mul_node->CSR_work2 = CT;

/* Hessian is H = C + C^T where both are B->n x A->n, and can't be more than
* 2 * nnz(C) */
assert(C->m == node->n_vars && C->n == node->n_vars);
node->wsum_hess = new_csr_matrix(C->m, C->n, 2 * C->nnz);
}
}

static void eval_wsum_hess(expr *node, const double *w)
{
expr *x = node->left;
expr *y = node->right;

/* both x and y are variables*/
if (x->var_id != NOT_A_VARIABLE)
{
memcpy(node->wsum_hess->x, w, node->size * sizeof(double));
memcpy(node->wsum_hess->x + node->size, w, node->size * sizeof(double));
}
else
{
/* both are linear operators */
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;
CSR_Matrix *C = ((elementwise_mult_expr *) node)->CSR_work1;
CSR_Matrix *CT = ((elementwise_mult_expr *) node)->CSR_work2;

/* Compute C = B^T diag(w) A */
BTDA_fill_values(A, B, w, C);

/* Compute CT = C^T = A^T diag(w) B */
AT_fill_values(C, CT, node->iwork);

/* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */
sum_csr_matrices(C, CT, node->wsum_hess);
}
}

static void free_type_data(expr *node)
{
free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work1);
free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work2);
}

expr *new_elementwise_mult(expr *left, expr *right)
{
expr *node = new_expr(left->d1, 1, left->n_vars);
elementwise_mult_expr *mul_node =
(elementwise_mult_expr *) calloc(1, sizeof(elementwise_mult_expr));
expr *node = &mul_node->base;

init_expr(node, left->d1, left->d2, left->n_vars, forward, jacobian_init,
eval_jacobian, NULL, free_type_data);
node->wsum_hess_init = wsum_hess_init;
node->eval_wsum_hess = eval_wsum_hess;
node->left = left;
node->right = right;
expr_retain(left);
expr_retain(right);
node->forward = forward;
node->jacobian_init = jacobian_init;
node->eval_jacobian = eval_jacobian;

return node;
}
4 changes: 4 additions & 0 deletions src/bivariate/rel_entr.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ static void forward_vector_args(expr *node, const double *u)
}
}

/* TODO: this probably doesn't work for matrices because we use node->d1
instead of node->size */
static void jacobian_init_vectors_args(expr *node)
{
node->jacobian = new_csr_matrix(node->d1, node->n_vars, 2 * node->d1);
Expand Down Expand Up @@ -80,6 +82,8 @@ static void eval_jacobian_vector_args(expr *node)
}
}

/* TODO: this probably doesn't work for matrices because we use node->d1
instead of node->size */
static void wsum_hess_init_vector_args(expr *node)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1);
Expand Down
2 changes: 1 addition & 1 deletion src/elementwise_univariate/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void eval_wsum_hess_elementwise(expr *node, const double *w)
/* Child will be a linear operator */
linear_op_expr *lin_child = (linear_op_expr *) child;
node->local_wsum_hess(node, node->dwork, w);
ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess);
ATDA_fill_values(lin_child->A_csc, node->dwork, node->wsum_hess);
}
}

Expand Down
82 changes: 81 additions & 1 deletion src/utils/CSC_Matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz,
return sum;
}

void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C)
void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C)
{
int i, j, ii, jj;
for (i = 0; i < C->m; i++)
Expand Down Expand Up @@ -196,6 +196,64 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A)
free(count);
return C;
}
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B)
{
/* A is m x n, B is m x p, C = B^T A is p x n */
int n = A->n;
int p = B->n;
int m = A->m;
int nnz = 0;
int i, j, ii, jj;

/* row ptr and column idxs for C = B^T A */
int *Cp = (int *) malloc((p + 1) * sizeof(int));
iVec *Ci = iVec_new(n);
Cp[0] = 0;

/* compute sparsity pattern */
for (i = 0; i < p; i++)
{
/* check if Cij != 0 for each column j of A */
for (j = 0; j < n; j++)
{
ii = B->p[i];
jj = A->p[j];

/* check if row i of B^T (column i of B) has common row with column j of
* A */
while (ii < B->p[i + 1] && jj < A->p[j + 1])
{
if (B->i[ii] == A->i[jj])
{
nnz++;
iVec_append(Ci, j);
break;
}
else if (B->i[ii] < A->i[jj])
{
ii++;
}
else
{
jj++;
}
}
}
Cp[i + 1] = Ci->len;
}

/* Allocate C */
CSR_Matrix *C = new_csr_matrix(p, n, nnz);
memcpy(C->p, Cp, (p + 1) * sizeof(int));
memcpy(C->i, Ci->data, nnz * sizeof(int));

/* free workspace */
free(Cp);
iVec_free(Ci);

return C;
}

void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C)
{
/* Compute C = z^T * A where A is in CSC format
Expand Down Expand Up @@ -224,3 +282,25 @@ void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C)
}
}
}

void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
CSR_Matrix *C)
{
int i, j, jj;
for (i = 0; i < C->m; i++)
{
for (jj = C->p[i]; jj < C->p[i + 1]; jj++)
{
j = C->i[jj];

int nnz_bi = B->p[i + 1] - B->p[i];
int nnz_aj = A->p[j + 1] - A->p[j];

/* compute Cij = weighted inner product of col i of B and col j of A */
double sum = sparse_wdot(B->x + B->p[i], B->i + B->p[i], nnz_bi,
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);

C->x[jj] = sum;
}
}
}
Loading