diff --git a/CMakeLists.txt b/CMakeLists.txt
index d604fe0..f00ac25 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,5 @@
 cmake_minimum_required(VERSION 3.15)
 project(DNLP_Diff_Engine C)
-
 set(CMAKE_C_STANDARD 99)
 
 #Set default build type to Release if not specified
@@ -32,6 +31,7 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
 
 # Source files - automatically gather all .c files from src/
 file(GLOB_RECURSE SOURCES "src/*.c")
+
 # Create core library
 add_library(dnlp_diff ${SOURCES})
 target_link_libraries(dnlp_diff m)
diff --git a/include/bivariate.h b/include/bivariate.h
index 582a4ac..8a28e8e 100644
--- a/include/bivariate.h
+++ b/include/bivariate.h
@@ -10,6 +10,9 @@ expr *new_quad_over_lin(expr *left, expr *right);
 expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
 expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
 
+/* Matrix multiplication: Z = X @ Y */
+expr *new_matmul(expr *x, expr *y);
+
 /* Left matrix multiplication: A @ f(x) where A is a constant matrix */
 expr *new_left_matmul(expr *u, const CSR_Matrix *A);
 
diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h
index 1710978..b29c673 100644
--- a/include/utils/mini_numpy.h
+++ b/include/utils/mini_numpy.h
@@ -20,4 +20,7 @@ void tile_int(int *result, const int *a, int len, int tiles);
  */
 void scaled_ones(double *result, int size, double value);
 
+/* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */
+void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n);
+
 #endif /* MINI_NUMPY_H */
diff --git a/python/atoms/matmul.h b/python/atoms/matmul.h
new file mode 100644
index 0000000..9bfe859
--- /dev/null
+++ b/python/atoms/matmul.h
@@ -0,0 +1,33 @@
+#ifndef ATOM_MATMUL_H
+#define ATOM_MATMUL_H
+
+#include "bivariate.h"
+#include "common.h"
+
+/* Matrix multiplication: Z = X @ Y */
+static PyObject *py_make_matmul(PyObject *self, PyObject *args)
+{
+    PyObject *left_capsule, *right_capsule;
+    if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule))
+    {
+        return NULL;
+    }
+    expr *left = (expr *) PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME);
+    expr *right = (expr *) PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME);
+    if (!left || !right)
+    {
+        PyErr_SetString(PyExc_ValueError, "invalid child capsule");
+        return NULL;
+    }
+
+    expr *node = new_matmul(left, right);
+    if (!node)
+    {
+        PyErr_SetString(PyExc_RuntimeError, "failed to create matmul node");
+        return NULL;
+    }
+    expr_retain(node); /* Capsule owns a reference */
+    return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
+}
+
+#endif /* ATOM_MATMUL_H */
diff --git a/python/bindings.c b/python/bindings.c
index 0cd703e..3f46eae 100644
--- a/python/bindings.c
+++ b/python/bindings.c
@@ -19,6 +19,7 @@
 #include "atoms/linear.h"
 #include "atoms/log.h"
 #include "atoms/logistic.h"
+#include "atoms/matmul.h"
 #include "atoms/multiply.h"
 #include "atoms/neg.h"
 #include "atoms/power.h"
@@ -73,6 +74,8 @@ static PyMethodDef DNLPMethods[] = {
     {"make_promote", py_make_promote, METH_VARARGS, "Create promote node"},
     {"make_multiply", py_make_multiply, METH_VARARGS,
      "Create elementwise multiply node"},
+    {"make_matmul", py_make_matmul, METH_VARARGS,
+     "Create matrix multiplication node (Z = X @ Y)"},
     {"make_const_scalar_mult", py_make_const_scalar_mult, METH_VARARGS,
      "Create constant scalar multiplication node (a * f(x))"},
     {"make_const_vector_mult", py_make_const_vector_mult, METH_VARARGS,
diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c
new file mode 100644
index 0000000..617e539
--- /dev/null
+++ b/src/bivariate/matmul.c
@@ -0,0 +1,368 @@
+#include "bivariate.h"
+#include "subexpr.h"
+#include "utils/mini_numpy.h"
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+// ------------------------------------------------------------------------------
+// Implementation of matrix multiplication: Z = X @ Y
+// where X is m x k and Y is k x n, producing Z which is m x n
+// All matrices are stored in column-major order
+//
+// Both operands must be distinct variables (checked in new_matmul), so entry
+// (row, l) of X is variable x->var_id + row + l * m and entry (l, col) of Y
+// is variable y->var_id + l + col * k.
+// ------------------------------------------------------------------------------
+
+/* Forward pass: evaluate both children at u, then form Z = X @ Y. */
+static void forward(expr *node, const double *u)
+{
+    expr *x = node->left;
+    expr *y = node->right;
+
+    /* children's forward passes */
+    x->forward(x, u);
+    y->forward(y, u);
+
+    /* local forward pass */
+    mat_mat_mult(x->value, y->value, node->value, x->d1, x->d2, y->d2);
+}
+
+/* A product of two variables is bilinear, so the node is never affine. */
+static bool is_affine(const expr *node)
+{
+    (void) node;
+    return false;
+}
+
+/* Allocate the Jacobian and fill its sparsity pattern.
+ *
+ * Row i belongs to Z[row, col] with i = row + col * m. Because
+ * Z[row, col] = sum_l X[row, l] * Y[l, col], the row has 2k entries: the k
+ * entries of X[row, :] and the k entries of Y[:, col]. The operand with the
+ * lower var_id is emitted first. */
+static void jacobian_init(expr *node)
+{
+    expr *x = node->left;
+    expr *y = node->right;
+
+    /* dimensions: X is m x k, Y is k x n, Z is m x n */
+    int m = x->d1;
+    int k = x->d2;
+    int n = y->d2;
+    int nnz = m * n * 2 * k;
+    node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
+
+    /* fill sparsity pattern */
+    int nnz_idx = 0;
+    for (int i = 0; i < node->size; i++)
+    {
+        /* Convert flat index to (row, col) in Z */
+        int row = i % m;
+        int col = i / m;
+
+        node->jacobian->p[i] = nnz_idx;
+
+        /* X has lower var_id */
+        if (x->var_id < y->var_id)
+        {
+            /* sparsity pattern of kron(YT, I) for this row */
+            for (int j = 0; j < k; j++)
+            {
+                node->jacobian->i[nnz_idx++] = x->var_id + row + j * m;
+            }
+
+            /* sparsity pattern of kron(I, X) for this row */
+            for (int j = 0; j < k; j++)
+            {
+                node->jacobian->i[nnz_idx++] = y->var_id + col * k + j;
+            }
+        }
+        else /* Y has lower var_id */
+        {
+            /* sparsity pattern of kron(I, X) for this row */
+            for (int j = 0; j < k; j++)
+            {
+                node->jacobian->i[nnz_idx++] = y->var_id + col * k + j;
+            }
+
+            /* sparsity pattern of kron(YT, I) for this row */
+            for (int j = 0; j < k; j++)
+            {
+                node->jacobian->i[nnz_idx++] = x->var_id + row + j * m;
+            }
+        }
+    }
+    node->jacobian->p[node->size] = nnz_idx;
+    assert(nnz_idx == nnz);
+}
+
+/* Fill the Jacobian values in the layout chosen by jacobian_init.
+ *
+ * dZ[row, col] / dX[row, l] = Y[l, col], which is contiguous in column col of
+ * Y, and dZ[row, col] / dY[l, col] = X[row, l], which is strided by m. */
+static void eval_jacobian(expr *node)
+{
+    expr *x = node->left;
+    expr *y = node->right;
+
+    /* dimensions: X is m x k, Y is k x n, Z is m x n */
+    int m = x->d1;
+    int k = x->d2;
+    double *Jx = node->jacobian->x;
+
+    /* fill values row-by-row */
+    for (int i = 0; i < node->size; i++)
+    {
+        int row = i % m; /* row in Z */
+        int col = i / m; /* col in Z */
+        int pos = node->jacobian->p[i];
+
+        if (x->var_id < y->var_id)
+        {
+            /* entries in the X columns: column col of Y */
+            memcpy(Jx + pos, y->value + col * k, k * sizeof(double));
+
+            /* entries in the Y columns: row `row` of X */
+            for (int j = 0; j < k; j++)
+            {
+                Jx[pos + k + j] = x->value[row + j * m];
+            }
+        }
+        else
+        {
+            /* entries in the Y columns: row `row` of X */
+            for (int j = 0; j < k; j++)
+            {
+                Jx[pos + j] = x->value[row + j * m];
+            }
+
+            /* entries in the X columns: column col of Y */
+            memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double));
+        }
+    }
+}
+
+/* Allocate the weighted-sum Hessian and fill its sparsity pattern.
+ *
+ * The only nonzero second derivatives couple X[row, l] with Y[l, col], so the
+ * row of X[row, l] holds the n entries of Y[l, :] and the row of Y[l, col]
+ * holds the m entries of X[:, l]. Every other row is empty. */
+static void wsum_hess_init(expr *node)
+{
+    expr *x = node->left;
+    expr *y = node->right;
+
+    /* dimensions: X is m x k, Y is k x n, Z is m x n */
+    int m = x->d1;
+    int k = x->d2;
+    int n = y->d2;
+    int total_nnz = 2 * m * k * n;
+    node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, total_nnz);
+    int nnz = 0;
+    int *Hi = node->wsum_hess->i;
+    int *Hp = node->wsum_hess->p;
+    int start, i;
+
+    /* fill rows before the first operand explicitly instead of relying on
+     * new_csr_matrix to zero the row pointers */
+    int first_var = (x->var_id < y->var_id) ? x->var_id : y->var_id;
+    for (i = 0; i < first_var; i++)
+    {
+        Hp[i] = nnz;
+    }
+
+    if (x->var_id < y->var_id)
+    {
+        /* fill rows corresponding to x */
+        for (i = 0; i < x->size; i++)
+        {
+            Hp[x->var_id + i] = nnz;
+            start = y->var_id + (i / m);
+            for (int col = 0; col < n; col++)
+            {
+                Hi[nnz++] = start + col * k;
+            }
+        }
+
+        /* fill rows between x and y */
+        for (i = x->var_id + x->size; i < y->var_id; i++)
+        {
+            Hp[i] = nnz;
+        }
+
+        /* fill rows corresponding to y */
+        for (i = 0; i < y->size; i++)
+        {
+            Hp[y->var_id + i] = nnz;
+            start = x->var_id + (i % k) * m;
+            for (int row = 0; row < m; row++)
+            {
+                Hi[nnz++] = start + row;
+            }
+        }
+
+        /* fill rows after y */
+        for (i = y->var_id + y->size; i <= node->n_vars; i++)
+        {
+            Hp[i] = nnz;
+        }
+    }
+    else
+    {
+        /* Y has lower var_id than X */
+        /* fill rows corresponding to y */
+        for (i = 0; i < y->size; i++)
+        {
+            Hp[y->var_id + i] = nnz;
+            start = x->var_id + (i % k) * m;
+            for (int row = 0; row < m; row++)
+            {
+                Hi[nnz++] = start + row;
+            }
+        }
+
+        /* fill rows between y and x */
+        for (i = y->var_id + y->size; i < x->var_id; i++)
+        {
+            Hp[i] = nnz;
+        }
+
+        /* fill rows corresponding to x */
+        for (i = 0; i < x->size; i++)
+        {
+            Hp[x->var_id + i] = nnz;
+            start = y->var_id + (i / m);
+            for (int col = 0; col < n; col++)
+            {
+                Hi[nnz++] = start + col * k;
+            }
+        }
+
+        /* fill rows after x */
+        for (i = x->var_id + x->size; i <= node->n_vars; i++)
+        {
+            Hp[i] = nnz;
+        }
+    }
+
+    Hp[node->n_vars] = nnz;
+    assert(nnz == total_nnz);
+}
+
+/* Fill the Hessian values for the weights w (length m * n, column-major).
+ *
+ * The coefficient coupling X[row, l] with Y[l, col] is w[row + col * m]; the
+ * values are written in the same order wsum_hess_init laid out the rows. */
+static void eval_wsum_hess(expr *node, const double *w)
+{
+    expr *x = node->left;
+    expr *y = node->right;
+
+    int m = x->d1;
+    int k = x->d2;
+    int n = y->d2;
+    int offset = 0;
+
+    double *Hx = node->wsum_hess->x;
+    const double *w_temp;
+
+    if (x->var_id < y->var_id)
+    {
+        /* rows corresponding to x */
+        for (int k_idx = 0; k_idx < k; k_idx++)
+        {
+            for (int row = 0; row < m; row++)
+            {
+                for (int col = 0; col < n; col++)
+                {
+                    Hx[offset++] = w[row + col * m];
+                }
+            }
+        }
+
+        /* rows corresponding to y */
+        for (int col = 0; col < n; col++)
+        {
+            w_temp = w + col * m;
+            for (int k_idx = 0; k_idx < k; k_idx++)
+            {
+                memcpy(Hx + offset, w_temp, m * sizeof(double));
+                offset += m;
+            }
+        }
+    }
+    else
+    {
+        /* rows corresponding to y */
+        for (int col = 0; col < n; col++)
+        {
+            w_temp = w + col * m;
+            for (int k_idx = 0; k_idx < k; k_idx++)
+            {
+                memcpy(Hx + offset, w_temp, m * sizeof(double));
+                offset += m;
+            }
+        }
+
+        /* rows corresponding to x */
+        for (int k_idx = 0; k_idx < k; k_idx++)
+        {
+            for (int row = 0; row < m; row++)
+            {
+                for (int col = 0; col < n; col++)
+                {
+                    Hx[offset++] = w[row + col * m];
+                }
+            }
+        }
+    }
+}
+
+/* Create the node Z = X @ Y for two distinct variable expressions.
+ *
+ * Prints a message and exits on a dimension mismatch or a non-variable
+ * operand. Returns NULL if the node cannot be allocated. The new node retains
+ * both children. */
+expr *new_matmul(expr *x, expr *y)
+{
+    /* Verify dimensions: x->d2 must equal y->d1 */
+    if (x->d2 != y->d1)
+    {
+        fprintf(stderr,
+                "Error in new_matmul: dimension mismatch. "
+                "X is %d x %d, Y is %d x %d. X.d2 (%d) must equal Y.d1 (%d)\n",
+                x->d1, x->d2, y->d1, y->d2, x->d2, y->d1);
+        exit(1);
+    }
+
+    /* verify both are variables and not the same variable */
+    if (x->var_id == NOT_A_VARIABLE || y->var_id == NOT_A_VARIABLE ||
+        x->var_id == y->var_id)
+    {
+        fprintf(stderr, "Error in new_matmul: operands must be variables and not "
+                        "the same variable\n");
+        exit(1);
+    }
+
+    /* Allocate the expression node */
+    expr *node = (expr *) calloc(1, sizeof(expr));
+    if (!node)
+    {
+        return NULL;
+    }
+
+    /* Initialize with d1 = x->d1, d2 = y->d2 (result is m x n) */
+    init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init, eval_jacobian,
+              is_affine, wsum_hess_init, eval_wsum_hess, NULL);
+
+    /* Set children */
+    node->left = x;
+    node->right = y;
+    expr_retain(x);
+    expr_retain(y);
+
+    return node;
+}
diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c
index 7dbb02f..fa78b5d 100644
--- a/src/utils/mini_numpy.c
+++ b/src/utils/mini_numpy.c
@@ -36,3 +36,18 @@ void scaled_ones(double *result, int size, double value)
         result[i] = value;
     }
 }
+
+void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n)
+{
+    for (int j = 0; j < n; ++j)
+    {
+        for (int i = 0; i < m; ++i)
+        {
+            Z[i + j * m] = 0.0;
+            for (int l = 0; l < k; ++l)
+            {
+                Z[i + j * m] += X[i + l * m] * Y[l + j * k];
+            }
+        }
+    }
+}
diff --git a/tests/all_tests.c b/tests/all_tests.c
index 5de1b00..8756b13 100644
--- a/tests/all_tests.c
+++ b/tests/all_tests.c
@@ -14,6 +14,7 @@
 #include "forward_pass/composite/test_composite.h"
 #include "forward_pass/elementwise/test_exp.h"
 #include "forward_pass/elementwise/test_log.h"
+#include "forward_pass/test_matmul.h"
 #include "forward_pass/test_prod_axis_one.h"
 #include "forward_pass/test_prod_axis_zero.h"
 #include "jacobian_tests/test_broadcast.h"
@@ -25,6 +26,7 @@
 #include "jacobian_tests/test_index.h"
 #include "jacobian_tests/test_left_matmul.h"
 #include "jacobian_tests/test_log.h"
+#include "jacobian_tests/test_matmul.h"
 #include "jacobian_tests/test_neg.h"
 #include "jacobian_tests/test_prod.h"
 #include "jacobian_tests/test_prod_axis_one.h"
@@ -55,6 +57,7 @@
 #include "wsum_hess/test_hstack.h"
 #include "wsum_hess/test_index.h"
 #include "wsum_hess/test_left_matmul.h"
+#include "wsum_hess/test_matmul.h"
 #include "wsum_hess/test_multiply.h"
 #include "wsum_hess/test_prod.h"
 #include "wsum_hess/test_prod_axis_one.h"
@@ -94,6 +97,7 @@ int main(void)
     mu_run_test(test_broadcast_matrix, tests_run);
     mu_run_test(test_forward_prod_axis_zero, tests_run);
     mu_run_test(test_forward_prod_axis_one, tests_run);
+    mu_run_test(test_matmul, tests_run);
 
     printf("\n--- Jacobian Tests ---\n");
     mu_run_test(test_neg_jacobian, tests_run);
@@ -156,6 +160,7 @@ int main(void)
     mu_run_test(test_jacobian_left_matmul_log_composite, tests_run);
     mu_run_test(test_jacobian_right_matmul_log, tests_run);
     mu_run_test(test_jacobian_right_matmul_log_vector, tests_run);
+    mu_run_test(test_jacobian_matmul, tests_run);
 
     printf("\n--- Weighted Sum of Hessian Tests ---\n");
     mu_run_test(test_wsum_hess_log, tests_run);
@@ -210,6 +215,8 @@ int main(void)
     mu_run_test(test_wsum_hess_left_matmul, tests_run);
     mu_run_test(test_wsum_hess_left_matmul_matrix, tests_run);
     mu_run_test(test_wsum_hess_left_matmul_composite, tests_run);
+    mu_run_test(test_wsum_hess_matmul, tests_run);
+    mu_run_test(test_wsum_hess_matmul_yx, tests_run);
     mu_run_test(test_wsum_hess_right_matmul, tests_run);
     mu_run_test(test_wsum_hess_right_matmul_vector, tests_run);
     mu_run_test(test_wsum_hess_broadcast_row, tests_run);
@@ -250,7 +257,7 @@ int main(void)
     mu_run_test(test_problem_constraint_forward, tests_run);
     mu_run_test(test_problem_hessian, tests_run);
 
-    printf("\n=== All %d tests passed ===\n", tests_run);
+    printf("\n=== All %d tests passed ===\n", tests_run);
 
     return 0;
 }
diff --git a/tests/forward_pass/test_matmul.h b/tests/forward_pass/test_matmul.h
new file mode 100644
index 0000000..b2731d1
--- /dev/null
+++ b/tests/forward_pass/test_matmul.h
@@ -0,0 +1,62 @@
+#include <math.h> /* NOTE(review): system header names were lost; confirm */
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "bivariate.h"
+#include "expr.h"
+#include "minunit.h"
+#include "test_helpers.h"
+
+const char *test_matmul()
+{
+    /* Test: Z = X @ Y where
+     * X is 3x2: [1 4]     Y is 2x4: [7  9 11 13]
+     *           [2 5]               [8 10 12 14]
+     *           [3 6]
+     *
+     * Expected Z (3x4): [39  49  59  69]
+     *                   [54  68  82  96]
+     *                   [69  87 105 123]
+     */
+
+    /* Create X variable (3 x 2) - stored in column-major order */
+    double u_x[6] = {1.0, 2.0, 3.0,  /* first column */
+                     4.0, 5.0, 6.0}; /* second column */
+    expr *X = new_variable(3, 2, 0, 20);
+
+    /* Create Y variable (2 x 4) - stored in column-major order */
+    double u_y[8] = {7.0,  8.0,   /* first column */
+                     9.0,  10.0,  /* second column */
+                     11.0, 12.0,  /* third column */
+                     13.0, 14.0}; /* fourth column */
+    expr *Y = new_variable(2, 4, 6, 20);
+
+    /* Create matmul expression */
+    expr *Z = new_matmul(X, Y);
+
+    /* Concatenate parameter vectors */
+    double u[14];
+    for (int i = 0; i < 6; i++) u[i] = u_x[i];
+    for (int i = 0; i < 8; i++) u[6 + i] = u_y[i];
+
+    /* Evaluate forward pass */
+    Z->forward(Z, u);
+
+    /* Expected result (3 x 4) in column-major order */
+    double expected[12] = {39.0, 54.0, 69.0,   /* first column */
+                           49.0, 68.0, 87.0,   /* second column */
+                           59.0, 82.0, 105.0,  /* third column */
+                           69.0, 96.0, 123.0}; /* fourth column */
+
+    /* Verify dimensions */
+    mu_assert("matmul result should have d1=3", Z->d1 == 3);
+    mu_assert("matmul result should have d2=4", Z->d2 == 4);
+    mu_assert("matmul result should have size=12", Z->size == 12);
+
+    /* Verify values */
+    mu_assert("Matmul forward pass test failed",
+              cmp_double_array(Z->value, expected, 12));
+
+    free_expr(Z);
+    return 0;
+}
diff --git a/tests/jacobian_tests/test_matmul.h b/tests/jacobian_tests/test_matmul.h
new file mode 100644
index 0000000..16a23fc
--- /dev/null
+++ b/tests/jacobian_tests/test_matmul.h
@@ -0,0 +1,89 @@
+#include "bivariate.h"
+#include "expr.h"
+#include "minunit.h"
+#include "test_helpers.h"
+#include <math.h> /* NOTE(review): system header names were lost; confirm */
+#include <stdio.h>
+
+const char *test_jacobian_matmul()
+{
+    /* Test: Z = X @ Y where X is 2x3, Y is 3x4
+     * var = (X, Y) where X starts at index 0 (size 6), Y starts at index 6 (size 12)
+     * Total n_vars = 18
+     * Z is 2x4 (size 8)
+     *
+     * Each row of Jacobian has 6 nonzeros (3 from X, 3 from Y)
+     * Total nnz = 8 * 6 = 48
+     */
+
+    int m = 2, k = 3, n = 4;
+    int x_size = m * k;           // 6
+    int y_size = k * n;           // 12
+    int z_size = m * n;           // 8
+    int n_vars = x_size + y_size; // 18
+
+    /* X values (2x3) in column-major order */
+    double x_vals[6] = {1.0, 2.0,  /* col 0 */
+                        3.0, 4.0,  /* col 1 */
+                        5.0, 6.0}; /* col 2 */
+
+    /* Y values (3x4) in column-major order */
+    double y_vals[12] = {7.0,  8.0,  9.0,   /* col 0 */
+                         10.0, 11.0, 12.0,  /* col 1 */
+                         13.0, 14.0, 15.0,  /* col 2 */
+                         16.0, 17.0, 18.0}; /* col 3 */
+
+    /* Combined parameter vector */
+    double u_vals[18];
+    for (int i = 0; i < 6; i++) u_vals[i] = x_vals[i];
+    for (int i = 0; i < 12; i++) u_vals[6 + i] = y_vals[i];
+
+    /* Create variables */
+    expr *X = new_variable(m, k, 0, n_vars);
+    expr *Y = new_variable(k, n, x_size, n_vars);
+    expr *Z = new_matmul(X, Y);
+
+    /* Forward pass and jacobian initialization */
+    Z->forward(Z, u_vals);
+    Z->jacobian_init(Z);
+    Z->eval_jacobian(Z);
+
+    /* Verify sparsity pattern */
+    mu_assert("Jacobian should have 8 rows", Z->jacobian->m == z_size);
+    mu_assert("Jacobian should have 18 columns", Z->jacobian->n == n_vars);
+    mu_assert("Jacobian should have 48 nonzeros", Z->jacobian->nnz == 48);
+
+    /* Check row pointers: each row should have 6 entries */
+    int expected_p[9] = {0, 6, 12, 18, 24, 30, 36, 42, 48};
+    mu_assert("Row pointers incorrect",
+              cmp_int_array(Z->jacobian->p, expected_p, 9));
+
+    int expected_i[48] = {0, 2, 4, 6,  7,  8,   /* row 0 */
+                          1, 3, 5, 6,  7,  8,   /* row 1 */
+                          0, 2, 4, 9,  10, 11,  /* row 2 */
+                          1, 3, 5, 9,  10, 11,  /* row 3 */
+                          0, 2, 4, 12, 13, 14,  /* row 4 */
+                          1, 3, 5, 12, 13, 14,  /* row 5 */
+                          0, 2, 4, 15, 16, 17,  /* row 6 */
+                          1, 3, 5, 15, 16, 17}; /* row 7 */
+    mu_assert("Column indices incorrect",
+              cmp_int_array(Z->jacobian->i, expected_i, 48));
+
+    /* Verify Jacobian values row-wise: for each row, values are
+       [Y^T row for the column, X row values] since X has lower var_id */
+    double expected_x[48] = {
+        /* row 0 (col 0) */ 7.0,  8.0,  9.0,  1.0, 3.0, 5.0,
+        /* row 1 (col 0) */ 7.0,  8.0,  9.0,  2.0, 4.0, 6.0,
+        /* row 2 (col 1) */ 10.0, 11.0, 12.0, 1.0, 3.0, 5.0,
+        /* row 3 (col 1) */ 10.0, 11.0, 12.0, 2.0, 4.0, 6.0,
+        /* row 4 (col 2) */ 13.0, 14.0, 15.0, 1.0, 3.0, 5.0,
+        /* row 5 (col 2) */ 13.0, 14.0, 15.0, 2.0, 4.0, 6.0,
+        /* row 6 (col 3) */ 16.0, 17.0, 18.0, 1.0, 3.0, 5.0,
+        /* row 7 (col 3) */ 16.0, 17.0, 18.0, 2.0, 4.0, 6.0};
+
+    mu_assert("Jacobian values incorrect",
+              cmp_double_array(Z->jacobian->x, expected_x, 48));
+
+    free_expr(Z);
+    return 0;
+}
diff --git a/tests/wsum_hess/test_matmul.h b/tests/wsum_hess/test_matmul.h
new file mode 100644
index 0000000..a6b74df
--- /dev/null
+++ b/tests/wsum_hess/test_matmul.h
@@ -0,0 +1,217 @@
+#include "bivariate.h"
+#include "expr.h"
+#include "minunit.h"
+#include "test_helpers.h"
+#include <math.h> /* NOTE(review): system header names were lost; confirm */
+#include <stdio.h>
+
+const char *test_wsum_hess_matmul()
+{
+    /* Test: Z = X @ Y where X is 2x3, Y is 3x4
+     * var = (X, Y) where X starts at index 0 (size 6), Y starts at index 6 (size 12)
+     * Total n_vars = 18
+     * Z is 2x4 (size 8)
+     */
+
+    int m = 2, k = 3, n = 4;
+    int x_size = m * k;           // 6
+    int y_size = k * n;           // 12
+    int n_vars = x_size + y_size; // 18
+
+    /* X values (2x3) in column-major order */
+    double x_vals[6] = {1.0, 2.0,  /* col 0 */
+                        3.0, 4.0,  /* col 1 */
+                        5.0, 6.0}; /* col 2 */
+
+    /* Y values (3x4) in column-major order */
+    double y_vals[12] = {7.0,  8.0,  9.0,   /* col 0 */
+                         10.0, 11.0, 12.0,  /* col 1 */
+                         13.0, 14.0, 15.0,  /* col 2 */
+                         16.0, 17.0, 18.0}; /* col 3 */
+
+    /* Combined parameter vector */
+    double u_vals[18];
+    for (int i = 0; i < 6; i++) u_vals[i] = x_vals[i];
+    for (int i = 0; i < 12; i++) u_vals[6 + i] = y_vals[i];
+
+    /* Weights for the 8 outputs */
+    double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
+
+    /* Create variables */
+    expr *X = new_variable(m, k, 0, n_vars);
+    expr *Y = new_variable(k, n, x_size, n_vars);
+    expr *Z = new_matmul(X, Y);
+
+    /* Forward pass and Hessian initialization */
+    Z->forward(Z, u_vals);
+    Z->wsum_hess_init(Z);
+    Z->eval_wsum_hess(Z, w);
+
+    /* Verify Hessian dimensions and sparsity */
+    mu_assert("Hessian should be 18x18", Z->wsum_hess->m == n_vars);
+    mu_assert("Hessian should be 18x18", Z->wsum_hess->n == n_vars);
+    mu_assert("Hessian should have 48 nonzeros", Z->wsum_hess->nnz == 48);
+
+    int expected_p[19] = {0,  4,  8,  12, 16, 20, 24, 26, 28, 30,
+                          32, 34, 36, 38, 40, 42, 44, 46, 48};
+
+    mu_assert("Row pointers incorrect",
+              cmp_int_array(Z->wsum_hess->p, expected_p, 19));
+
+    int expected_i[48] = {6, 9,  12, 15, /* row 0 */
+                          6, 9,  12, 15, /* row 1 */
+                          7, 10, 13, 16, /* row 2 */
+                          7, 10, 13, 16, /* row 3 */
+                          8, 11, 14, 17, /* row 4 */
+                          8, 11, 14, 17, /* row 5 */
+                          0, 1,          /* row 6 */
+                          2, 3,          /* row 7 */
+                          4, 5,          /* row 8 */
+                          0, 1,          /* row 9 */
+                          2, 3,          /* row 10 */
+                          4, 5,          /* row 11 */
+                          0, 1,          /* row 12 */
+                          2, 3,          /* row 13 */
+                          4, 5,          /* row 14 */
+                          0, 1,          /* row 15 */
+                          2, 3,          /* row 16 */
+                          4, 5};         /* row 17 */
+
+    mu_assert("Column indices incorrect",
+              cmp_int_array(Z->wsum_hess->i, expected_i, 48));
+
+    double expected_x[48] = {1.0, 3.0, 5.0, 7.0, /* row 0 */
+                             2.0, 4.0, 6.0, 8.0, /* row 1 */
+                             1.0, 3.0, 5.0, 7.0, /* row 2 */
+                             2.0, 4.0, 6.0, 8.0, /* row 3 */
+                             1.0, 3.0, 5.0, 7.0, /* row 4 */
+                             2.0, 4.0, 6.0, 8.0, /* row 5 */
+                             1.0, 2.0,           /* row 6 */
+                             1.0, 2.0,           /* row 7 */
+                             1.0, 2.0,           /* row 8 */
+                             3.0, 4.0,           /* row 9 */
+                             3.0, 4.0,           /* row 10 */
+                             3.0, 4.0,           /* row 11 */
+                             5.0, 6.0,           /* row 12 */
+                             5.0, 6.0,           /* row 13 */
+                             5.0, 6.0,           /* row 14 */
+                             7.0, 8.0,           /* row 15 */
+                             7.0, 8.0,           /* row 16 */
+                             7.0, 8.0};          /* row 17 */
+
+    mu_assert("Hessian values incorrect",
+              cmp_double_array(Z->wsum_hess->x, expected_x, 48));
+
+    free_expr(Z);
+    return 0;
+}
+
+const char *test_wsum_hess_matmul_yx()
+{
+    /* Test: Z = X @ Y where X is 2x3, Y is 3x4
+     * var = (Y, X) where Y starts at index 0 (size 12), X starts at index 12
+     * (size 6). Total n_vars = 18, Z is 2x4 (size 8)
+     */
+
+    int m = 2, k = 3, n = 4;
+    int y_size = k * n;           // 12
+    int x_size = m * k;           // 6
+    int n_vars = x_size + y_size; // 18
+
+    /* X values (2x3) in column-major order */
+    double x_vals[6] = {1.0, 2.0,  /* col 0 */
+                        3.0, 4.0,  /* col 1 */
+                        5.0, 6.0}; /* col 2 */
+
+    /* Y values (3x4) in column-major order */
+    double y_vals[12] = {7.0,  8.0,  9.0,   /* col 0 */
+                         10.0, 11.0, 12.0,  /* col 1 */
+                         13.0, 14.0, 15.0,  /* col 2 */
+                         16.0, 17.0, 18.0}; /* col 3 */
+
+    /* Combined parameter vector: Y first, then X */
+    double u_vals[18];
+    for (int i = 0; i < 12; i++) u_vals[i] = y_vals[i];
+    for (int i = 0; i < 6; i++) u_vals[12 + i] = x_vals[i];
+
+    /* Weights for the 8 outputs */
+    double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
+
+    /* Create variables with Y first */
+    expr *Y = new_variable(k, n, 0, n_vars);
+    expr *X = new_variable(m, k, y_size, n_vars);
+    expr *Z = new_matmul(X, Y);
+
+    /* Forward pass and Hessian initialization */
+    Z->forward(Z, u_vals);
+    Z->wsum_hess_init(Z);
+    Z->eval_wsum_hess(Z, w);
+
+    /* Verify Hessian dimensions and sparsity */
+    mu_assert("Hessian should be 18x18", Z->wsum_hess->m == n_vars);
+    mu_assert("Hessian should be 18x18", Z->wsum_hess->n == n_vars);
+    mu_assert("Hessian should have 48 nonzeros", Z->wsum_hess->nnz == 48);
+
+    /* Row pointers when Y < X:
+     * Rows 0-11 (Y variables): each couples with m=2 X variables
+     * Rows 12-17 (X variables): each couples with n=4 Y variables
+     */
+    int expected_p[19] = {0,  2,  4,  6,  8,  10, 12, 14, 16, 18,
+                          20, 22, 24, 28, 32, 36, 40, 44, 48};
+
+    mu_assert("Row pointers incorrect",
+              cmp_int_array(Z->wsum_hess->p, expected_p, 19));
+
+    /* Column indices when Y < X:
+     * Y[k_idx, col] couples with X[row, k_idx] for all row
+     *   X variable index = 12 + row + k_idx*m
+     * X[row, k_idx] couples with Y[k_idx, col] for all col
+     *   Y variable index = 0 + k_idx + col*k
+     */
+    int expected_i[48] = {12, 13,       /* row 0: Y[0,0] */
+                          14, 15,       /* row 1: Y[1,0] */
+                          16, 17,       /* row 2: Y[2,0] */
+                          12, 13,       /* row 3: Y[0,1] */
+                          14, 15,       /* row 4: Y[1,1] */
+                          16, 17,       /* row 5: Y[2,1] */
+                          12, 13,       /* row 6: Y[0,2] */
+                          14, 15,       /* row 7: Y[1,2] */
+                          16, 17,       /* row 8: Y[2,2] */
+                          12, 13,       /* row 9: Y[0,3] */
+                          14, 15,       /* row 10: Y[1,3] */
+                          16, 17,       /* row 11: Y[2,3] */
+                          0,  3, 6, 9,  /* row 12: X[0,0] */
+                          0,  3, 6, 9,  /* row 13: X[1,0] */
+                          1,  4, 7, 10, /* row 14: X[0,1] */
+                          1,  4, 7, 10, /* row 15: X[1,1] */
+                          2,  5, 8, 11, /* row 16: X[0,2] */
+                          2,  5, 8, 11}; /* row 17: X[1,2] */
+
+    mu_assert("Column indices incorrect",
+              cmp_int_array(Z->wsum_hess->i, expected_i, 48));
+
+    double expected_x[48] = {1.0, 2.0,           /* row 0 */
+                             1.0, 2.0,           /* row 1 */
+                             1.0, 2.0,           /* row 2 */
+                             3.0, 4.0,           /* row 3 */
+                             3.0, 4.0,           /* row 4 */
+                             3.0, 4.0,           /* row 5 */
+                             5.0, 6.0,           /* row 6 */
+                             5.0, 6.0,           /* row 7 */
+                             5.0, 6.0,           /* row 8 */
+                             7.0, 8.0,           /* row 9 */
+                             7.0, 8.0,           /* row 10 */
+                             7.0, 8.0,           /* row 11 */
+                             1.0, 3.0, 5.0, 7.0, /* row 12 */
+                             2.0, 4.0, 6.0, 8.0, /* row 13 */
+                             1.0, 3.0, 5.0, 7.0, /* row 14 */
+                             2.0, 4.0, 6.0, 8.0, /* row 15 */
+                             1.0, 3.0, 5.0, 7.0, /* row 16 */
+                             2.0, 4.0, 6.0, 8.0}; /* row 17 */
+
+    mu_assert("Hessian values incorrect",
+              cmp_double_array(Z->wsum_hess->x, expected_x, 48));
+
+    free_expr(Z);
+    return 0;
+}