diff --git a/include/affine.h b/include/affine.h index 88b1f7b..c7a75cf 100644 --- a/include/affine.h +++ b/include/affine.h @@ -21,6 +21,7 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars); expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs); expr *new_reshape(expr *child, int d1, int d2); expr *new_broadcast(expr *child, int target_d1, int target_d2); +expr *new_diag_vec(expr *child); expr *new_transpose(expr *child); #endif /* AFFINE_H */ diff --git a/python/atoms/diag_vec.h b/python/atoms/diag_vec.h new file mode 100644 index 0000000..b1056d6 --- /dev/null +++ b/python/atoms/diag_vec.h @@ -0,0 +1,32 @@ +#ifndef ATOM_DIAG_VEC_H +#define ATOM_DIAG_VEC_H + +#include "common.h" + +static PyObject *py_make_diag_vec(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + return NULL; + } + + expr *node = new_diag_vec(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create diag_vec node"); + return NULL; + } + + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_DIAG_VEC_H */ diff --git a/python/bindings.c b/python/bindings.c index 71d702a..27e916a 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -11,6 +11,7 @@ #include "atoms/const_vector_mult.h" #include "atoms/constant.h" #include "atoms/cos.h" +#include "atoms/diag_vec.h" #include "atoms/entr.h" #include "atoms/exp.h" #include "atoms/getters.h" @@ -96,6 +97,7 @@ static PyMethodDef DNLPMethods[] = { "Create prod_axis_one node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, + {"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, {"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"}, {"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"}, diff --git a/src/affine/diag_vec.c b/src/affine/diag_vec.c new file mode 100644 index 0000000..5d2a303 --- /dev/null +++ b/src/affine/diag_vec.c @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "affine.h" +#include +#include +#include + +/* diag_vec: converts a vector of size n into an n×n diagonal matrix. + * In Fortran (column-major) order, element i of the input maps to + * position i*(n+1) in the flattened output (the diagonal positions). */ + +static void forward(expr *node, const double *u) +{ + expr *x = node->left; + int n = x->size; + + /* child's forward pass */ + x->forward(x, u); + + /* zero-initialize output */ + memset(node->value, 0, node->size * sizeof(double)); + + /* place input elements on the diagonal */ + for (int i = 0; i < n; i++) + { + node->value[i * (n + 1)] = x->value[i]; + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + int n = x->size; + x->jacobian_init(x); + + CSR_Matrix *Jx = x->jacobian; + CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz); + + /* Output has n² rows but only n diagonal positions are non-empty. + * Diagonal position i is at row i*(n+1) in Fortran order. */ + int nnz = 0; + int next_diag = 0; + for (int row = 0; row < node->size; row++) + { + J->p[row] = nnz; + if (row == next_diag) + { + int child_row = row / (n + 1); + int len = Jx->p[child_row + 1] - Jx->p[child_row]; + memcpy(J->i + nnz, Jx->i + Jx->p[child_row], len * sizeof(int)); + nnz += len; + next_diag += n + 1; + } + } + J->p[node->size] = nnz; + + node->jacobian = J; +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + int n = x->size; + x->eval_jacobian(x); + + CSR_Matrix *J = node->jacobian; + CSR_Matrix *Jx = x->jacobian; + + /* Copy values from child row i to output diagonal row i*(n+1) */ + for (int i = 0; i < n; i++) + { + int out_row = i * (n + 1); + int len = J->p[out_row + 1] - J->p[out_row]; + memcpy(J->x + J->p[out_row], Jx->x + Jx->p[i], len * sizeof(double)); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + + /* initialize child's wsum_hess */ + x->wsum_hess_init(x); + + /* workspace for extracting diagonal weights */ + node->dwork = (double *) calloc(x->size, sizeof(double)); + + /* Copy child's Hessian structure (diag_vec is linear, so its own Hessian is zero) */ + CSR_Matrix *Hx = x->wsum_hess; + node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz); + memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int)); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + int n = x->size; + + /* Extract weights from diagonal positions of w (which has n² elements) */ + for (int i = 0; i < n; i++) + { + node->dwork[i] = w[i * (n + 1)]; + } + + /* Evaluate child's Hessian with extracted weights */ + x->eval_wsum_hess(x, node->dwork); + memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double)); +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +expr *new_diag_vec(expr *child) +{ + /* child must be a vector: either column (n, 1) or row (1, n) */ + assert(child->d1 == 1 || child->d2 == 1); + + /* n is the number of elements (works for both row and column vectors) */ + int n = child->size; + expr *node = (expr *) calloc(1, sizeof(expr)); + init_expr(node, n, n, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, wsum_hess_init, eval_wsum_hess, NULL); + node->left = child; + expr_retain(child); + + return node; +} diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index b1ab670..fae5fff 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -20,6 +20,7 @@ "make_promote", "make_index", "make_reshape", + "make_diag_vec", "make_log", "make_exp", "make_power",