Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/affine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars);
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
expr *new_reshape(expr *child, int d1, int d2);
expr *new_broadcast(expr *child, int target_d1, int target_d2);
expr *new_transpose(expr *child);

#endif /* AFFINE_H */
1 change: 0 additions & 1 deletion python/atoms/trace.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#ifndef ATOM_TRACE_H

#define ATOM_TRACE_H

#include "common.h"
Expand Down
32 changes: 32 additions & 0 deletions python/atoms/transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef PYTHON_ATOMS_TRANSPOSE_H
#define PYTHON_ATOMS_TRANSPOSE_H

#include "common.h"

// Python binding for the transpose atom
static PyObject *py_make_transpose(PyObject *self, PyObject *args)
{
PyObject *child_capsule;
if (!PyArg_ParseTuple(args, "O", &child_capsule))
{
return NULL;
}
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);

if (!child)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

expr *node = new_transpose(child);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create trace node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif // PYTHON_ATOMS_TRANSPOSE_H
5 changes: 3 additions & 2 deletions python/bindings.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "atoms/tan.h"
#include "atoms/tanh.h"
#include "atoms/trace.h"
#include "atoms/transpose.h"
#include "atoms/variable.h"
#include "atoms/xexp.h"

Expand Down Expand Up @@ -71,8 +72,8 @@ static PyMethodDef DNLPMethods[] = {
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
{"make_index", py_make_index, METH_VARARGS, "Create index node"},
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
{"make_trace", py_make_trace, METH_VARARGS,
"Create trace node from an expr capsule (make_trace(child))"},
{"make_trace", py_make_trace, METH_VARARGS, "Create trace node"},
{"make_transpose", py_make_transpose, METH_VARARGS, "Create transpose node"},
{"make_hstack", py_make_hstack, METH_VARARGS,
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
"e2, ...], n_vars))"},
Expand Down
123 changes: 123 additions & 0 deletions src/affine/transpose.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "affine.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>

// Forward pass for transpose atom
static void forward(expr *node, const double *u)
{
/* forward pass for child */
node->left->forward(node->left, u);

/* local forward pass */
int d1 = node->d1;
int d2 = node->d2;
double *X = node->left->value;
double *XT = node->value;
for (int i = 0; i < d1; ++i)
{
for (int j = 0; j < d2; ++j)
{
XT[j * d1 + i] = X[i * d2 + j];
}
}
}

static void jacobian_init(expr *node)
{
expr *child = node->left;
child->jacobian_init(child);
CSR_Matrix *Jc = child->jacobian;
node->jacobian = new_csr_matrix(node->size, node->n_vars, Jc->nnz);

/* fill sparsity */
CSR_Matrix *J = node->jacobian;
int d1 = node->d1;
int d2 = node->d2;
int nnz = 0;
J->p[0] = 0;

/* 'k' is the old row that gets swapped to 'row'*/
int k, len;
for (int row = 0; row < J->m; ++row)
{
k = (row / d1) + (row % d1) * d2;
len = Jc->p[k + 1] - Jc->p[k];
memcpy(J->i + nnz, Jc->i + Jc->p[k], len * sizeof(int));
nnz += len;
J->p[row + 1] = nnz;
}
}

static void eval_jacobian(expr *node)
{
expr *child = node->left;
child->eval_jacobian(child);
CSR_Matrix *Jc = child->jacobian;
CSR_Matrix *J = node->jacobian;

int d1 = node->d1;
int d2 = node->d2;
int nnz = 0;
for (int row = 0; row < J->m; ++row)
{
int k = (row / d1) + (row % d1) * d2;
int len = Jc->p[k + 1] - Jc->p[k];
memcpy(J->x + nnz, Jc->x + Jc->p[k], len * sizeof(double));
nnz += len;
}
}

static void wsum_hess_init(expr *node)
{
/* initialize child */
expr *x = node->left;
x->wsum_hess_init(x);

/* same sparsity pattern as child */
CSR_Matrix *H = node->wsum_hess;
H = new_csr_matrix(x->wsum_hess->m, node->n_vars, x->wsum_hess->nnz);
memcpy(H->p, x->wsum_hess->p, (H->m + 1) * sizeof(int));
memcpy(H->i, x->wsum_hess->i, H->nnz * sizeof(int));
node->wsum_hess = H;

/* for computing Kw where K is the commutation matrix */
node->dwork = (double *) malloc(node->size * sizeof(double));
}
static void eval_wsum_hess(expr *node, const double *w)
{
int d2 = node->d2;
int d1 = node->d1;
// TODO: meaybe more efficient to do this with memcpy first

/* evaluate hessian of child at Kw */
for (int i = 0; i < d2; ++i)
{
for (int j = 0; j < d1; ++j)
{
node->dwork[j * d2 + i] = w[i * d1 + j];
}
}

node->left->eval_wsum_hess(node->left, node->dwork);

/* copy to this node's hessian */
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
node->wsum_hess->nnz * sizeof(double));
}

static bool is_affine(const expr *node)
{
return node->left->is_affine(node->left);
}

expr *new_transpose(expr *child)
{
expr *node = (expr *) calloc(1, sizeof(expr));
init_expr(node, child->d2, child->d1, child->n_vars, forward, jacobian_init,
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL);
node->left = child;
expr_retain(child);

return node;
}
4 changes: 4 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "jacobian_tests/test_right_matmul.h"
#include "jacobian_tests/test_sum.h"
#include "jacobian_tests/test_trace.h"
#include "jacobian_tests/test_transpose.h"
#include "problem/test_problem.h"
#include "utils/test_csc_matrix.h"
#include "utils/test_csr_matrix.h"
Expand Down Expand Up @@ -70,6 +71,7 @@
#include "wsum_hess/test_right_matmul.h"
#include "wsum_hess/test_sum.h"
#include "wsum_hess/test_trace.h"
#include "wsum_hess/test_transpose.h"

int main(void)
{
Expand Down Expand Up @@ -161,6 +163,7 @@ int main(void)
mu_run_test(test_jacobian_right_matmul_log, tests_run);
mu_run_test(test_jacobian_right_matmul_log_vector, tests_run);
mu_run_test(test_jacobian_matmul, tests_run);
mu_run_test(test_jacobian_transpose, tests_run);

printf("\n--- Weighted Sum of Hessian Tests ---\n");
mu_run_test(test_wsum_hess_log, tests_run);
Expand Down Expand Up @@ -225,6 +228,7 @@ int main(void)
mu_run_test(test_wsum_hess_trace_variable, tests_run);
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
mu_run_test(test_wsum_hess_trace_composite, tests_run);
mu_run_test(test_wsum_hess_transpose, tests_run);

printf("\n--- Utility Tests ---\n");
mu_run_test(test_diag_csr_mult, tests_run);
Expand Down
47 changes: 47 additions & 0 deletions tests/jacobian_tests/test_transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

#ifndef TEST_TRANSPOSE_H
#define TEST_TRANSPOSE_H

#include "affine.h"
#include "minunit.h"
#include "test_helpers.h"
#include <math.h>
#include <stdio.h>

const char *test_jacobian_transpose()
{
// A = [1 2; 3 4]
CSR_Matrix *A = new_csr_matrix(2, 2, 4);
int A_p[3] = {0, 2, 4};
int A_i[4] = {0, 1, 0, 1};
double A_x[4] = {1, 2, 3, 4};
memcpy(A->p, A_p, 3 * sizeof(int));
memcpy(A->i, A_i, 4 * sizeof(int));
memcpy(A->x, A_x, 4 * sizeof(double));

// X = [1 2; 3 4] (columnwise: x = [1 3 2 4])
expr *X = new_variable(2, 2, 0, 4);
expr *AX = new_left_matmul(X, A);
expr *transpose_AX = new_transpose(AX);
double u[4] = {1, 3, 2, 4};
transpose_AX->forward(transpose_AX, u);
transpose_AX->jacobian_init(transpose_AX);
transpose_AX->eval_jacobian(transpose_AX);

// Jacobian of transpose_AX
double expected_x[8] = {1, 2, 1, 2, 3, 4, 3, 4};
int expected_p[5] = {0, 2, 4, 6, 8};
int expected_i[8] = {0, 1, 2, 3, 0, 1, 2, 3};

mu_assert("jacobian values fail",
cmp_double_array(transpose_AX->jacobian->x, expected_x, 8));
mu_assert("jacobian row ptr fail",
cmp_int_array(transpose_AX->jacobian->p, expected_p, 5));
mu_assert("jacobian col idx fail",
cmp_int_array(transpose_AX->jacobian->i, expected_i, 8));
free_expr(AX);
free_csr_matrix(A);
return 0;
}

#endif // TEST_TRANSPOSE_H
40 changes: 40 additions & 0 deletions tests/wsum_hess/test_transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef TEST_WSUM_HESS_TRANSPOSE_H
#define TEST_WSUM_HESS_TRANSPOSE_H

#include "affine.h"
#include "minunit.h"
#include "test_helpers.h"
#include <math.h>
#include <stdio.h>

const char *test_wsum_hess_transpose()
{

expr *X = new_variable(2, 2, 0, 8);
expr *Y = new_variable(2, 2, 4, 8);

expr *XY = new_matmul(X, Y);
expr *XYT = new_transpose(XY);

double u[8] = {1, 3, 2, 4, 5, 7, 6, 8};
XYT->forward(XYT, u);
XYT->wsum_hess_init(XYT);
double w[4] = {1, 2, 3, 4};
XYT->eval_wsum_hess(XYT, w);

double expected_x[16] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 3, 1, 3, 2, 4, 2, 4};
int expected_p[9] = {0, 2, 4, 6, 8, 10, 12, 14, 16};
int expected_i[16] = {4, 6, 4, 6, 5, 7, 5, 7, 0, 1, 2, 3, 0, 1, 2, 3};

mu_assert("hess values fail",
cmp_double_array(XYT->wsum_hess->x, expected_x, 8));
mu_assert("jacobian row ptr fail",
cmp_int_array(XYT->wsum_hess->p, expected_p, 5));
mu_assert("jacobian col idx fail",
cmp_int_array(XYT->wsum_hess->i, expected_i, 8));
free_expr(XYT);

return 0;
}

#endif // TEST_WSUM_HESS_TRANSPOSE_H
Loading