diff --git a/include/subexpr.h b/include/subexpr.h index 033fd7f..8c5661b 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -49,4 +49,12 @@ typedef struct hstack_expr CSR_Matrix *CSR_work; /* for summing Hessians of children */ } hstack_expr; +/* Elementwise multiplication */ +typedef struct elementwise_mult_expr +{ + expr base; + CSR_Matrix *CSR_work1; + CSR_Matrix *CSR_work2; +} elementwise_mult_expr; + #endif /* SUBEXPR_H */ diff --git a/include/utils/CSC_Matrix.h b/include/utils/CSC_Matrix.h index b3eb358..51e6f6a 100644 --- a/include/utils/CSC_Matrix.h +++ b/include/utils/CSC_Matrix.h @@ -34,10 +34,15 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A); /* Allocate sparsity pattern for C = A^T D A for diagonal D */ CSR_Matrix *ATA_alloc(const CSC_Matrix *A); -/* Compute values for C = A^T D A - * C must have precomputed sparsity pattern - */ -void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); +/* Allocate sparsity pattern for C = B^T D A for diagonal D */ +CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B); + +/* Compute values for C = A^T D A. C must have precomputed sparsity pattern */ +void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C); + +/* Compute values for C = B^T D A. C must have precomputed sparsity pattern */ +void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, + CSR_Matrix *C); /* C = z^T A where A is in CSC format and C is assumed to have one row. * C must have column indices pre-computed. Fills in values of C only. diff --git a/include/utils/CSR_Matrix.h b/include/utils/CSR_Matrix.h index 89f83f2..ca4b0a2 100644 --- a/include/utils/CSR_Matrix.h +++ b/include/utils/CSR_Matrix.h @@ -84,6 +84,10 @@ void insert_idx(int idx, int *arr, int len); double csr_get_value(const CSR_Matrix *A, int row, int col); CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork); +CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork); + +/* Fill values of A^T given sparsity pattern is already computed */ +void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork); /* Expand symmetric CSR matrix A to full matrix C. A is assumed to store only upper triangle. C must be pre-allocated with sufficient nnz */ diff --git a/python/bindings.c b/python/bindings.c index 68e4b76..e6f4f38 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -244,7 +244,8 @@ static PyMethodDef DNLPMethods[] = { {"make_add", py_make_add, METH_VARARGS, "Create add node"}, {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, {"forward", py_forward, METH_VARARGS, "Run forward pass and return values"}, - {"jacobian", py_jacobian, METH_VARARGS, "Compute jacobian and return CSR components"}, + {"jacobian", py_jacobian, METH_VARARGS, + "Compute jacobian and return CSR components"}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef dnlp_module = {PyModuleDef_HEAD_INIT, "DNLP_diff_engine", diff --git a/src/bivariate/multiply.c b/src/bivariate/multiply.c index a7f81c1..e50882b 100644 --- a/src/bivariate/multiply.c +++ b/src/bivariate/multiply.c @@ -1,6 +1,10 @@ #include "bivariate.h" +#include "subexpr.h" +#include #include +#include #include +#include // ------------------------------------------------------------------------------ // Implementation of elementwise multiplication when both arguments are vectors. @@ -52,15 +56,151 @@ static void eval_jacobian(expr *node) x->value); } +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + expr *y = node->right; + + /* for correctness x and y must be (1) different variables, + or (2) both must be linear operators */ +#ifndef DEBUG + if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && + x->var_id == y->var_id) + { + fprintf(stderr, "Error: elementwise multiplication of a variable by itself " + "not supported.\n"); + exit(1); + } + else if ((x->var_id != NOT_A_VARIABLE && y->var_id == NOT_A_VARIABLE) || + (x->var_id == NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE)) + { + fprintf(stderr, "Error: elementwise multiplication of a variable by a " + "non-variable is not supported. (Both must be inserted " + "as linear operators)\n"); + exit(1); + } +#endif + + /* both x and y are variables*/ + if (x->var_id != NOT_A_VARIABLE) + { + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 2 * node->size); + + int i, var1_id, var2_id; + + if (x->var_id < y->var_id) + { + var1_id = x->var_id; + var2_id = y->var_id; + } + else + { + var1_id = y->var_id; + var2_id = x->var_id; + } + + /* var1 rows of Hessian */ + for (i = 0; i < node->size; i++) + { + node->wsum_hess->p[var1_id + i] = i; + node->wsum_hess->i[i] = var2_id + i; + } + + int nnz = node->size; + + /* rows between var1 and var2 */ + for (i = var1_id + node->size; i < var2_id; i++) + { + node->wsum_hess->p[i] = nnz; + } + + /* var2 rows of Hessian */ + for (i = 0; i < node->size; i++) + { + node->wsum_hess->p[var2_id + i] = nnz + i; + node->wsum_hess->i[nnz + i] = var1_id + i; + } + + /* remaining rows */ + nnz += node->size; + for (i = var2_id + node->size; i <= node->n_vars; i++) + { + node->wsum_hess->p[i] = nnz; + } + } + else + { + /* both are linear operators */ + CSC_Matrix *A = ((linear_op_expr *) x)->A_csc; + CSC_Matrix *B = ((linear_op_expr *) y)->A_csc; + + /* Allocate workspace for Hessian computation */ + elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; + CSR_Matrix *C; /* C = B^T diag(w) A */ + C = BTA_alloc(A, B); + node->iwork = (int *) malloc(C->m * sizeof(int)); + + CSR_Matrix *CT = AT_alloc(C, node->iwork); + mul_node->CSR_work1 = C; + mul_node->CSR_work2 = CT; + + /* Hessian is H = C + C^T where both are B->n x A->n, and can't be more than + * 2 * nnz(C) */ + assert(C->m == node->n_vars && C->n == node->n_vars); + node->wsum_hess = new_csr_matrix(C->m, C->n, 2 * C->nnz); + } +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + expr *y = node->right; + + /* both x and y are variables*/ + if (x->var_id != NOT_A_VARIABLE) + { + memcpy(node->wsum_hess->x, w, node->size * sizeof(double)); + memcpy(node->wsum_hess->x + node->size, w, node->size * sizeof(double)); + } + else + { + /* both are linear operators */ + CSC_Matrix *A = ((linear_op_expr *) x)->A_csc; + CSC_Matrix *B = ((linear_op_expr *) y)->A_csc; + CSR_Matrix *C = ((elementwise_mult_expr *) node)->CSR_work1; + CSR_Matrix *CT = ((elementwise_mult_expr *) node)->CSR_work2; + + /* Compute C = B^T diag(w) A */ + BTDA_fill_values(A, B, w, C); + + /* Compute CT = C^T = A^T diag(w) B */ + AT_fill_values(C, CT, node->iwork); + + /* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */ + sum_csr_matrices(C, CT, node->wsum_hess); + } +} + +static void free_type_data(expr *node) +{ + free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work1); + free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work2); +} + expr *new_elementwise_mult(expr *left, expr *right) { - expr *node = new_expr(left->d1, 1, left->n_vars); + elementwise_mult_expr *mul_node = + (elementwise_mult_expr *) calloc(1, sizeof(elementwise_mult_expr)); + expr *node = &mul_node->base; + + init_expr(node, left->d1, left->d2, left->n_vars, forward, jacobian_init, + eval_jacobian, NULL, free_type_data); + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; node->left = left; node->right = right; expr_retain(left); expr_retain(right); - node->forward = forward; - node->jacobian_init = jacobian_init; - node->eval_jacobian = eval_jacobian; + return node; } diff --git a/src/bivariate/rel_entr.c b/src/bivariate/rel_entr.c index 59f385d..1360a1a 100644 --- a/src/bivariate/rel_entr.c +++ b/src/bivariate/rel_entr.c @@ -24,6 +24,8 @@ static void forward_vector_args(expr *node, const double *u) } } +/* TODO: this probably doesn't work for matrices because we use node->d1 + instead of node->size */ static void jacobian_init_vectors_args(expr *node) { node->jacobian = new_csr_matrix(node->d1, node->n_vars, 2 * node->d1); @@ -80,6 +82,8 @@ static void eval_jacobian_vector_args(expr *node) } } +/* TODO: this probably doesn't work for matrices because we use node->d1 + instead of node->size */ static void wsum_hess_init_vector_args(expr *node) { node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1); diff --git a/src/elementwise_univariate/common.c b/src/elementwise_univariate/common.c index 986d26a..752f107 100644 --- a/src/elementwise_univariate/common.c +++ b/src/elementwise_univariate/common.c @@ -87,7 +87,7 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) /* Child will be a linear operator */ linear_op_expr *lin_child = (linear_op_expr *) child; node->local_wsum_hess(node, node->dwork, w); - ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess); + ATDA_fill_values(lin_child->A_csc, node->dwork, node->wsum_hess); } } diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index fc94b27..841c2b3 100644 --- a/src/utils/CSC_Matrix.c +++ b/src/utils/CSC_Matrix.c @@ -122,7 +122,7 @@ static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz, return sum; } -void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C) +void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C) { int i, j, ii, jj; for (i = 0; i < C->m; i++) @@ -196,6 +196,64 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A) free(count); return C; } +CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B) +{ + /* A is m x n, B is m x p, C = B^T A is p x n */ + int n = A->n; + int p = B->n; + int m = A->m; + int nnz = 0; + int i, j, ii, jj; + + /* row ptr and column idxs for C = B^T A */ + int *Cp = (int *) malloc((p + 1) * sizeof(int)); + iVec *Ci = iVec_new(n); + Cp[0] = 0; + + /* compute sparsity pattern */ + for (i = 0; i < p; i++) + { + /* check if Cij != 0 for each column j of A */ + for (j = 0; j < n; j++) + { + ii = B->p[i]; + jj = A->p[j]; + + /* check if row i of B^T (column i of B) has common row with column j of + * A */ + while (ii < B->p[i + 1] && jj < A->p[j + 1]) + { + if (B->i[ii] == A->i[jj]) + { + nnz++; + iVec_append(Ci, j); + break; + } + else if (B->i[ii] < A->i[jj]) + { + ii++; + } + else + { + jj++; + } + } + } + Cp[i + 1] = Ci->len; + } + + /* Allocate C */ + CSR_Matrix *C = new_csr_matrix(p, n, nnz); + memcpy(C->p, Cp, (p + 1) * sizeof(int)); + memcpy(C->i, Ci->data, nnz * sizeof(int)); + + /* free workspace */ + free(Cp); + iVec_free(Ci); + + return C; +} + void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C) { /* Compute C = z^T * A where A is in CSC format @@ -224,3 +282,25 @@ void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C) } } } + +void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d, + CSR_Matrix *C) +{ + int i, j, jj; + for (i = 0; i < C->m; i++) + { + for (jj = C->p[i]; jj < C->p[i + 1]; jj++) + { + j = C->i[jj]; + + int nnz_bi = B->p[i + 1] - B->p[i]; + int nnz_aj = A->p[j + 1] - A->p[j]; + + /* compute Cij = weighted inner product of col i of B and col j of A */ + double sum = sparse_wdot(B->x + B->p[i], B->i + B->p[i], nnz_bi, + A->x + A->p[j], A->i + A->p[j], nnz_aj, d); + + C->x[jj] = sum; + } + } +} diff --git a/src/utils/CSR_Matrix.c b/src/utils/CSR_Matrix.c index 620180f..8ee83e5 100644 --- a/src/utils/CSR_Matrix.c +++ b/src/utils/CSR_Matrix.c @@ -442,6 +442,69 @@ CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork) return AT; } +CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork) +{ + /* Allocate A^T and compute sparsity pattern without filling values */ + CSR_Matrix *AT = new_csr_matrix(A->n, A->m, A->nnz); + + int i, j; + int *count = iwork; + memset(count, 0, A->n * sizeof(int)); + + // ------------------------------------------------------------------- + // compute nnz in each column of A + // ------------------------------------------------------------------- + for (i = 0; i < A->m; ++i) + { + for (j = A->p[i]; j < A->p[i + 1]; ++j) + { + count[A->i[j]]++; + } + } + + // ------------------------------------------------------------------ + // compute row pointers + // ------------------------------------------------------------------ + AT->p[0] = 0; + for (i = 0; i < A->n; i++) + { + AT->p[i + 1] = AT->p[i] + count[i]; + count[i] = AT->p[i]; + } + + // ------------------------------------------------------------------ + // fill column indices + // ------------------------------------------------------------------ + for (i = 0; i < A->m; ++i) + { + for (j = A->p[i]; j < A->p[i + 1]; j++) + { + AT->i[count[A->i[j]]] = i; + count[A->i[j]]++; + } + } + + return AT; +} + +void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork) +{ + /* Fill values of A^T given sparsity pattern is already computed */ + int i, j; + int *count = iwork; + memcpy(count, AT->p, A->n * sizeof(int)); + + /* Fill values by placing each element of A into its transposed position */ + for (i = 0; i < A->m; ++i) + { + for (j = A->p[i]; j < A->p[i + 1]; j++) + { + AT->x[count[A->i[j]]] = A->x[j]; + count[A->i[j]]++; + } + } +} + /**/ void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C) { diff --git a/tests/all_tests.c b/tests/all_tests.c index d57c183..071d5bf 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -30,6 +30,7 @@ #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" #include "wsum_hess/test_hstack.h" +#include "wsum_hess/test_multiply.h" #include "wsum_hess/test_rel_entr.h" #include "wsum_hess/test_sum.h" @@ -78,6 +79,8 @@ int main(void) mu_run_test(test_jacobian_sum_log_axis_1, tests_run); mu_run_test(test_jacobian_hstack_vectors, tests_run); mu_run_test(test_jacobian_hstack_matrix, tests_run); + mu_run_test(test_wsum_hess_multiply_1, tests_run); + mu_run_test(test_wsum_hess_multiply_2, tests_run); printf("\n--- Weighted Sum of Hessian Tests ---\n"); mu_run_test(test_wsum_hess_log, tests_run); @@ -101,12 +104,15 @@ int main(void) mu_run_test(test_wsum_hess_rel_entr_2, tests_run); mu_run_test(test_wsum_hess_hstack, tests_run); mu_run_test(test_wsum_hess_hstack_matrix, tests_run); + mu_run_test(test_wsum_hess_multiply_linear_ops, tests_run); + mu_run_test(test_wsum_hess_multiply_sparse_random, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_diag_csr_mult, tests_run); mu_run_test(test_csr_sum, tests_run); mu_run_test(test_csr_sum2, tests_run); mu_run_test(test_transpose, tests_run); + mu_run_test(test_AT_alloc_and_fill, tests_run); mu_run_test(test_csr_to_csc1, tests_run); mu_run_test(test_csr_to_csc2, tests_run); mu_run_test(test_csr_vecmat_values_sparse, tests_run); @@ -117,6 +123,7 @@ int main(void) mu_run_test(test_ATA_alloc_diagonal_like, tests_run); mu_run_test(test_ATA_alloc_random, tests_run); mu_run_test(test_ATA_alloc_random2, tests_run); + mu_run_test(test_BTA_alloc_and_BTDA_fill, tests_run); printf("\n=== All %d tests passed ===\n", tests_run); diff --git a/tests/utils/multiply_test.py b/tests/utils/multiply_test.py new file mode 100644 index 0000000..e3fa1b1 --- /dev/null +++ b/tests/utils/multiply_test.py @@ -0,0 +1,40 @@ +import numpy as np +import scipy.sparse as sp + +# test 1 +A = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0], [0.0, 6.0, 0.0]]) +B = np.array([[1.0, 0.0, 4], [0.0, 2.0, 7], [3.0, 0.0, 2], [0.0, 4.0, -1]]) +w = np.array([1.0, 2.0, 3.0, 4.0]) +H = A.T @ np.diag(w) @ B + B.T @ np.diag(w) @ A +print(H) +H_csr = sp.csr_matrix(H) +print("H in CSR format:") +print("data:", H_csr.data) +print("indices:", H_csr.indices) +print("indptr:", H_csr.indptr) + +# test 2 +m = 5 +n = 10 +density = 0.2 +A = sp.random(m, n, density=density, format="csr", data_rvs=np.random.randn) +B = sp.random(m, n, density=density, format="csr", data_rvs=np.random.randn) +w = np.random.rand(m) +H = A.T @ sp.diags(w) @ B + B.T @ sp.diags(w) @ A +H = H.tocsr() +print("Random sparse H in CSR format:") +print("data:", H.data) +print("indices:", H.indices) +print("indptr:", H.indptr) + +print("A in csr:") +print("data:", A.data) +print("indices:", A.indices) +print("indptr:", A.indptr) + +print("B in csr:") +print("data:", B.data) +print("indices:", B.indices) +print("indptr:", B.indptr) + +print("w:", w) diff --git a/tests/utils/test_csc_matrix.h b/tests/utils/test_csc_matrix.h index be2d489..803f5e6 100644 --- a/tests/utils/test_csc_matrix.h +++ b/tests/utils/test_csc_matrix.h @@ -180,7 +180,7 @@ const char *test_ATA_alloc_random() double d[10] = {2, 8, 6, 2, 5, 1, 6, 9, 1, 3}; - ATDA_values(A, d, C); + ATDA_fill_values(A, d, C); double Cx_correct[38] = { 49., 21., 491., 56., 240., 416., 144., 288., 56., 98., 56., 21., 9., @@ -218,7 +218,7 @@ const char *test_ATA_alloc_random2() double d[15] = {-0.6, -0.23, -0.29, -1.36, 0.4, 0.36, 0.11, -0.13, -1.32, -0.32, -0.24, -0.7, -0.06, 0.5, 1.99}; - ATDA_values(A, d, C); + ATDA_fill_values(A, d, C); double Cx_correct[17] = {-0.362232, -0.189896, 0.06656, -0.228888, -0.025732, -0.016146, 0.032857, 0.06656, -1.004802, 0.1505, @@ -231,3 +231,60 @@ const char *test_ATA_alloc_random2() return 0; } +const char *test_BTA_alloc_and_BTDA_fill() +{ + /* Create A: 4x3 CSC matrix + * [1.0 0.0 2.0] + * [0.0 3.0 0.0] + * [4.0 0.0 5.0] + * [0.0 6.0 0.0] + */ + int m = 4; + int n = 3; + CSC_Matrix *A = new_csc_matrix(m, n, 6); + int Ap_A[4] = {0, 2, 4, 6}; + int Ai_A[6] = {0, 2, 1, 3, 0, 2}; + double Ax_A[6] = {1.0, 4.0, 3.0, 6.0, 2.0, 5.0}; + memcpy(A->p, Ap_A, 4 * sizeof(int)); + memcpy(A->i, Ai_A, 6 * sizeof(int)); + memcpy(A->x, Ax_A, 6 * sizeof(double)); + + /* Create B: 4x2 CSC matrix + * [1.0 0.0] + * [0.0 2.0] + * [3.0 0.0] + * [0.0 4.0] + */ + int p = 2; + CSC_Matrix *B = new_csc_matrix(m, p, 4); + int Bp[3] = {0, 2, 4}; + int Bi[4] = {0, 2, 1, 3}; + double Bx[4] = {1.0, 3.0, 2.0, 4.0}; + memcpy(B->p, Bp, 3 * sizeof(int)); + memcpy(B->i, Bi, 4 * sizeof(int)); + memcpy(B->x, Bx, 4 * sizeof(double)); + + /* Allocate C = B^T A (should be 2x3) */ + CSR_Matrix *C = BTA_alloc(A, B); + + /* Sparsity pattern check before filling values */ + int expected_p[3] = {0, 2, 3}; + int expected_i[3] = {0, 2, 1}; + mu_assert("C dimensions incorrect", C->m == 2 && C->n == 3); + mu_assert("C nnz incorrect", C->nnz == 3); + mu_assert("C->p incorrect", cmp_int_array(C->p, expected_p, 3)); + mu_assert("C->i incorrect", cmp_int_array(C->i, expected_i, 3)); + + /* Fill values with diagonal weights d */ + double d[4] = {1.0, 2.0, 3.0, 4.0}; + BTDA_fill_values(A, B, d, C); + + double expected_x[3] = {37.0, 47.0, 108.0}; + mu_assert("C values incorrect", cmp_double_array(C->x, expected_x, 3)); + + free_csr_matrix(C); + free_csc_matrix(A); + free_csc_matrix(B); + + return 0; +} diff --git a/tests/utils/test_csr_matrix.h b/tests/utils/test_csr_matrix.h index f55c722..b2cfced 100644 --- a/tests/utils/test_csr_matrix.h +++ b/tests/utils/test_csr_matrix.h @@ -391,3 +391,47 @@ const char *test_sum_evenly_spaced_rows_csr() return 0; } +const char *test_AT_alloc_and_fill() +{ + /* Create a 3x4 CSR matrix A: + * [1.0 0.0 2.0 0.0] + * [0.0 3.0 0.0 4.0] + * [5.0 0.0 6.0 0.0] + */ + CSR_Matrix *A = new_csr_matrix(3, 4, 6); + double Ax[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + int Ai[6] = {0, 2, 1, 3, 0, 2}; + int Ap[4] = {0, 2, 4, 6}; + memcpy(A->x, Ax, 6 * sizeof(double)); + memcpy(A->i, Ai, 6 * sizeof(int)); + memcpy(A->p, Ap, 4 * sizeof(int)); + + /* Allocate A^T (should be 4x3) */ + int *iwork = (int *) malloc(A->n * sizeof(int)); + CSR_Matrix *AT = AT_alloc(A, iwork); + + /* Fill values of A^T */ + AT_fill_values(A, AT, iwork); + + /* Expected A^T: + * [1.0 0.0 5.0] + * [0.0 3.0 0.0] + * [2.0 0.0 6.0] + * [0.0 4.0 0.0] + */ + double ATx_correct[6] = {1.0, 5.0, 3.0, 2.0, 6.0, 4.0}; + int ATi_correct[6] = {0, 2, 1, 0, 2, 1}; + int ATp_correct[5] = {0, 2, 3, 5, 6}; + + mu_assert("AT dimensions incorrect", AT->m == 4 && AT->n == 3); + mu_assert("AT nnz incorrect", AT->nnz == 6); + mu_assert("AT vals incorrect", cmp_double_array(AT->x, ATx_correct, 6)); + mu_assert("AT cols incorrect", cmp_int_array(AT->i, ATi_correct, 6)); + mu_assert("AT rows incorrect", cmp_int_array(AT->p, ATp_correct, 5)); + + free_csr_matrix(A); + free_csr_matrix(AT); + free(iwork); + + return 0; +} diff --git a/tests/wsum_hess/test_multiply.h b/tests/wsum_hess/test_multiply.h new file mode 100644 index 0000000..05e075f --- /dev/null +++ b/tests/wsum_hess/test_multiply.h @@ -0,0 +1,234 @@ +#include "affine.h" +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_wsum_hess_multiply_1() +{ + // Total 12 variables: [?, ?, ?, x0, x1, x2, ?, ?, y0, y1, y2, ?] + // x has var_id = 3, y has var_id = 8 + // x = [1, 2, 3], y = [4, 5, 6] + // w = [1, 2, 3] + // Hessian structure: [0, diag(w); diag(w), 0] + + double u_vals[12] = {0, 0, 0, 1.0, 2.0, 3.0, 0, 0, 4.0, 5.0, 6.0, 0}; + double w[3] = {1.0, 2.0, 3.0}; + + expr *x = new_variable(3, 1, 3, 12); + expr *y = new_variable(3, 1, 8, 12); + expr *node = new_elementwise_mult(x, y); + + node->forward(node, u_vals); + node->wsum_hess_init(node); + node->eval_wsum_hess(node, w); + + int expected_p[13] = {0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6}; + int expected_i[6] = {8, 9, 10, 3, 4, 5}; + double expected_x[6] = {1.0, 2.0, 3.0, 1.0, 2.0, 3.0}; + + mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 13)); + mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 6)); + mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 6)); + + free_expr(node); + free_expr(x); + free_expr(y); + return 0; +} + +const char *test_wsum_hess_multiply_sparse_random() +{ + /* Test with larger random sparse matrices + * A: 5x10 CSR matrix + * B: 5x10 CSR matrix + * x: 10-dimensional variable with var_id = 0, n_vars = 10 + * Expected Hessian: 10x10 sparse matrix + */ + + /* Create A matrix (5x10) */ + CSR_Matrix *A = new_csr_matrix(5, 10, 10); + double Ax[10] = {-1.44165273, -1.13687223, 0.55892257, 0.24912193, 0.84959744, + -0.23998915, 0.5913356, -1.21627912, -0.50379166, 0.41531801}; + int Ai[10] = {1, 2, 4, 8, 2, 3, 8, 9, 1, 2}; + int Ap[6] = {0, 3, 3, 4, 8, 10}; + memcpy(A->x, Ax, 10 * sizeof(double)); + memcpy(A->i, Ai, 10 * sizeof(int)); + memcpy(A->p, Ap, 6 * sizeof(int)); + + /* Create B matrix (5x10) */ + CSR_Matrix *B = new_csr_matrix(5, 10, 10); + double Bx[10] = {1.27549062, 0.04194731, -0.4356034, 0.405574, 1.34670487, + -0.57738638, 0.9411464, -0.31563179, 1.90831766, -0.89802958}; + int Bi[10] = {0, 3, 5, 7, 0, 5, 0, 3, 7, 9}; + int Bp[6] = {0, 1, 4, 6, 8, 10}; + memcpy(B->x, Bx, 10 * sizeof(double)); + memcpy(B->i, Bi, 10 * sizeof(int)); + memcpy(B->p, Bp, 6 * sizeof(int)); + + /* Create variable and linear operators */ + expr *x = new_variable(10, 1, 0, 10); + expr *Ax_node = new_linear(x, A); + expr *Bx_node = new_linear(x, B); + + /* Create multiply node */ + expr *mult_node = new_elementwise_mult(Ax_node, Bx_node); + + /* Forward pass with unit input */ + double u_vals[10] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + mult_node->forward(mult_node, u_vals); + + /* Initialize and evaluate Hessian */ + mult_node->wsum_hess_init(mult_node); + double w[5] = {0.50646339, 0.44756224, 0.67295241, 0.16424956, 0.03031469}; + mult_node->eval_wsum_hess(mult_node, w); + + /* Expected Hessian in CSR format (10x10) */ + int expected_p[11] = {0, 6, 9, 13, 18, 19, 20, 20, 22, 25, 29}; + int expected_i[29] = {1, 2, 3, 4, 8, 9, 0, 7, 9, 0, 3, 7, 9, 0, 2, + 3, 8, 9, 0, 8, 1, 2, 0, 3, 5, 0, 1, 2, 3}; + double expected_x[29] = { + -0.93129225, -0.60307409, -0.03709821, 0.361058, 0.31718166, -0.18801593, + -0.93129225, -0.02914438, 0.01371497, -0.60307409, -0.04404516, 0.02402617, + -0.01130641, -0.03709821, -0.04404516, 0.02488322, -0.03065625, 0.06305481, + 0.361058, -0.09679721, -0.02914438, 0.02402617, 0.31718166, -0.03065625, + -0.09679721, -0.18801593, 0.01371497, -0.01130641, 0.06305481}; + + mu_assert("p array fails", + cmp_int_array(mult_node->wsum_hess->p, expected_p, 11)); + mu_assert("i array fails", + cmp_int_array(mult_node->wsum_hess->i, expected_i, 29)); + mu_assert("x array fails", + cmp_double_array(mult_node->wsum_hess->x, expected_x, 29)); + + /* Cleanup */ + free_expr(mult_node); + free_expr(Ax_node); + free_expr(Bx_node); + free_expr(x); + free_csr_matrix(A); + free_csr_matrix(B); + + return 0; +} + +const char *test_wsum_hess_multiply_linear_ops() +{ + /* Test Hessian for mult(Ax, Bx) where A, B are 4x3 linear operators + * A = [[1.0, 0.0, 2.0], + * [0.0, 3.0, 0.0], + * [4.0, 0.0, 5.0], + * [0.0, 6.0, 0.0]] + * B = [[1.0, 0.0, 4.0], + * [0.0, 2.0, 7.0], + * [3.0, 0.0, 2.0], + * [0.0, 4.0, -1.0]] + * x has var_id = 0, n_vars = 3 + * w = [1.0, 2.0, 3.0, 4.0] + * Hessian = A^T diag(w) B + B^T diag(w) A + * Expected (from numpy): + * [[ 74. 0. 75.] + * [ 0. 216. 18.] + * [ 75. 18. 76.]] + */ + + /* Create CSR matrix A */ + CSR_Matrix *A = new_csr_matrix(4, 3, 6); + double Ax[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + int Ai[6] = {0, 2, 1, 0, 2, 1}; + int Ap[5] = {0, 2, 3, 5, 6}; + memcpy(A->x, Ax, 6 * sizeof(double)); + memcpy(A->i, Ai, 6 * sizeof(int)); + memcpy(A->p, Ap, 5 * sizeof(int)); + + /* Create CSR matrix B */ + CSR_Matrix *B = new_csr_matrix(4, 3, 8); + double Bx[8] = {1.0, 4.0, 2.0, 7.0, 3.0, 2.0, 4.0, -1.0}; + int Bi[8] = {0, 2, 1, 2, 0, 2, 1, 2}; + int Bp[5] = {0, 2, 4, 6, 8}; + memcpy(B->x, Bx, 8 * sizeof(double)); + memcpy(B->i, Bi, 8 * sizeof(int)); + memcpy(B->p, Bp, 5 * sizeof(int)); + + /* Create linear operator expressions */ + expr *x = new_variable(3, 1, 0, 3); + expr *Ax_node = new_linear(x, A); + expr *Bx_node = new_linear(x, B); + + /* Create elementwise multiply node */ + expr *mult_node = new_elementwise_mult(Ax_node, Bx_node); + + /* Forward pass */ + double u_vals[3] = {1.0, 1.0, 1.0}; + mult_node->forward(mult_node, u_vals); + + /* Initialize Hessian structure */ + mult_node->wsum_hess_init(mult_node); + + /* Evaluate Hessian with weights */ + double w[4] = {1.0, 2.0, 3.0, 4.0}; + mult_node->eval_wsum_hess(mult_node, w); + + /* Check sparsity pattern and values */ + /* Expected CSR format: + * indptr: [0, 2, 4, 7] + * indices: [0, 2, 1, 2, 0, 1, 2] + * data: [74.0, 75.0, 216.0, 18.0, 75.0, 18.0, 76.0] + */ + int expected_p[4] = {0, 2, 4, 7}; + int expected_i[7] = {0, 2, 1, 2, 0, 1, 2}; + double expected_x[7] = {74.0, 75.0, 216.0, 18.0, 75.0, 18.0, 76.0}; + + mu_assert("p array fails", + cmp_int_array(mult_node->wsum_hess->p, expected_p, 4)); + mu_assert("i array fails", + cmp_int_array(mult_node->wsum_hess->i, expected_i, 7)); + mu_assert("x array fails", + cmp_double_array(mult_node->wsum_hess->x, expected_x, 7)); + + /* Cleanup */ + free_expr(mult_node); + free_expr(Ax_node); + free_expr(Bx_node); + free_expr(x); + free_csr_matrix(A); + free_csr_matrix(B); + + return 0; +} + +const char *test_wsum_hess_multiply_2() +{ + // Total 12 variables: [?, ?, ?, y0, y1, y2, ?, ?, x0, x1, x2, ?] + // y has var_id = 3, x has var_id = 8 + // x = [1, 2, 3], y = [4, 5, 6] + // w = [1, 2, 3] + // Hessian structure: [0, diag(w); diag(w), 0] + // Should get the same result as test 1 + + double u_vals[12] = {0, 0, 0, 4.0, 5.0, 6.0, 0, 0, 1.0, 2.0, 3.0, 0}; + double w[3] = {1.0, 2.0, 3.0}; + + expr *x = new_variable(3, 1, 8, 12); + expr *y = new_variable(3, 1, 3, 12); + expr *node = new_elementwise_mult(x, y); + + node->forward(node, u_vals); + node->wsum_hess_init(node); + node->eval_wsum_hess(node, w); + + int expected_p[13] = {0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6}; + int expected_i[6] = {8, 9, 10, 3, 4, 5}; + double expected_x[6] = {1.0, 2.0, 3.0, 1.0, 2.0, 3.0}; + + mu_assert("p array fails", cmp_int_array(node->wsum_hess->p, expected_p, 13)); + mu_assert("i array fails", cmp_int_array(node->wsum_hess->i, expected_i, 6)); + mu_assert("x array fails", cmp_double_array(node->wsum_hess->x, expected_x, 6)); + + free_expr(node); + free_expr(x); + free_expr(y); + return 0; +}