From 5d869c6c5205488e4562ed505fcd9216d2292e74 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 14 Jan 2026 12:49:11 -0800 Subject: [PATCH 1/2] some progress on broadcasting --- include/subexpr.h | 16 +++ src/affine/broadcast.c | 221 +++++++++++++++++++++++++++++++++++++++++ src/utils/mini_numpy.c | 1 + 3 files changed, 238 insertions(+) create mode 100644 src/affine/broadcast.c diff --git a/include/subexpr.h b/include/subexpr.h index 97a31bf..66df2ce 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -118,4 +118,20 @@ typedef struct index_expr bool has_duplicates; /* True if indices have duplicates (affects Hessian path) */ } index_expr; +/* Broadcast types */ +typedef enum +{ + BROADCAST_ROW, /* (1, n) -> (m, n) */ + BROADCAST_COL, /* (m, 1) -> (m, n) */ + BROADCAST_SCALAR /* (1, 1) -> (m, n) */ +} broadcast_type; + +typedef struct broadcast_expr +{ + expr base; + broadcast_type type; + int m; /* target rows */ + int n; /* target cols */ +} broadcast_expr; + #endif /* SUBEXPR_H */ diff --git a/src/affine/broadcast.c b/src/affine/broadcast.c new file mode 100644 index 0000000..c20df6d --- /dev/null +++ b/src/affine/broadcast.c @@ -0,0 +1,221 @@ +#include "affine.h" +#include "mini_numpy.h" +#include "subexpr.h" +#include +#include +#include +#include + +/* Broadcast expands an array to a larger shape by replicating along dimensions. + * Supports three types: + * 1. "row": (1, n) -> (m, n) - replicate rows + * 2. "col": (m, 1) -> (m, n) - replicate columns + * 3. 
"scalar": (1, 1) -> (m, n) - replicate in both dimensions + */ + +static void forward(expr *node, const double *u) +{ + expr *x = node->left; + broadcast_expr *bcast = (broadcast_expr *) node; + + x->forward(x, u); + + if (bcast->type == BROADCAST_ROW) + { + /* (1, n) -> (m, n): replicate row m times */ + for (int j = 0; j < bcast->n; j++) + { + for (int i = 0; i < bcast->m; i++) + { + node->value[i + j * bcast->m] = x->value[j]; + } + } + } + else if (bcast->type == BROADCAST_COL) + { + /* (m, 1) -> (m, n): replicate column n times */ + for (int j = 0; j < bcast->n; j++) + { + memcpy(node->value + j * bcast->m, x->value, bcast->m * sizeof(double)); + } + } + else + { + /* (1, 1) -> (m, n): fill with scalar value */ + for (int k = 0; k < node->size; k++) + { + node->value[k] = x->value[0]; + } + } +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + x->jacobian_init(x); + broadcast_expr *bcast = (broadcast_expr *) node; + int total_nnz; + + // -------------------------------------------------------------------- + // count number of nonzeros + // -------------------------------------------------------------------- + if (bcast->type == BROADCAST_ROW) + { + /* Row broadcast: (1, n) -> (m, n) */ + total_nnz = x->jacobian->nnz * bcast->m; + } + else if (bcast->type == BROADCAST_COL) + { + /* Column broadcast: (m, 1) -> (m, n) */ + total_nnz = x->jacobian->nnz * bcast->n; + } + else + { + /* Scalar broadcast: (1, 1) -> (m, n) */ + total_nnz = x->jacobian->nnz * bcast->m * bcast->n; + } + + node->jacobian = new_csr_matrix(node->size, node->n_vars, total_nnz); + + // --------------------------------------------------------------------- + // fill sparsity pattern + // --------------------------------------------------------------------- + CSR_Matrix *Jx = x->jacobian; + CSR_Matrix *J = node->jacobian; + J->nnz = 0; + + if (bcast->type == BROADCAST_ROW) + { + for (int i = 0; i < bcast->n; i++) + { + int nnz_in_row = Jx->p[i + 1] - Jx->p[i]; + + /* copy 
columns indices */ + tile(J->i + J->nnz, Jx->i + Jx->p[i], nnz_in_row, bcast->m); + + /* set row pointers */ + for (int rep = 0; rep < bcast->m; rep++) + { + J->p[i * bcast->m + rep] = J->nnz; + J->nnz += nnz_in_row; + } + } + } + else if (bcast->type == BROADCAST_COL) + { + + /* copy column indices */ + tile(J->i, Jx->i, Jx->nnz, bcast->n); + + /* set row pointers */ + int offset = 0; + for (int i = 0; i < bcast->n; i++) + { + for (int j = 0; j < bcast->m; j++) + { + J->p[i * bcast->m + j] = offset; + offset += Jx->p[1] - Jx->p[0]; + } + } + assert(offset == total_nnz) J->p[node->size] = total_nnz; + } + else + { + assert(false); + } +} + +static void eval_jacobian(expr *node) +{ + node->left->eval_jacobian(node->left); + + broadcast_expr *bcast = (broadcast_expr *) node; + CSR_Matrix *Jx = node->left->jacobian; + CSR_Matrix *J = node->jacobian; + J->nnz = 0; + + if (bcast->type == BROADCAST_ROW) + { + for (int i = 0; i < bcast->n; i++) + { + int nnz_in_row = Jx->p[i + 1] - Jx->p[i]; + tile(J->x + J->nnz, Jx->x + Jx->p[i], nnz_in_row, bcast->m); + J->nnz += nnz_in_row * bcast->m; + } + } + else if (bcast->type == BROADCAST_COL) + { + tile(J->x, Jx->x, Jx->nnz, bcast->n); + } + else + { + assert(false); + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + x->wsum_hess_init(x); + + /* Same sparsity as child - weights get summed */ + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz); + memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->wsum_hess->m + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int)); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + broadcast_expr *bcast = (broadcast_expr *) node; + expr *child = node->left; +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +expr *new_broadcast(expr *child, int target_d1, int target_d2) +{ + // 
--------------------------------------------------------------------------- + // determine broadcast type + // --------------------------------------------------------------------------- + broadcast_type type; + int m = target_d1; + int n = target_d2; + + if (child->d1 == 1 && child->d2 == n) + { + type = BROADCAST_ROW; + } + else if (child->d1 == m && child->d2 == 1) + { + type = BROADCAST_COL; + } + else if (child->d1 == 1 && child->d2 == 1) + { + type = BROADCAST_SCALAR; + } + else + { + assert(false); + } + + broadcast_expr *bcast = (broadcast_expr *) malloc(sizeof(broadcast_expr)); + expr *node = (expr *) bcast; + + // -------------------------------------------------------------------------- + // initialize the rest of the expression + // -------------------------------------------------------------------------- + init_expr(node, target_d1, target_d2, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, NULL); + node->left = child; + expr_retain(child); + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; + bcast->type = type; + bcast->m = m; + bcast->n = n; + + return node; +} diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index 4d198f2..84db029 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -12,6 +12,7 @@ void repeat(double *result, const double *a, int len, int repeats) } } +/* TODO: we can use memcpy here */ void tile(double *result, const double *a, int len, int tiles) { int idx = 0; From f32715e09d6400bbeaabe07951e5312acf3de0f1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 14 Jan 2026 15:27:38 -0800 Subject: [PATCH 2/2] progress on broadcast --- include/affine.h | 1 + include/utils/mini_numpy.h | 3 +- src/affine/broadcast.c | 74 +++++++-- src/affine/sum.c | 2 +- src/utils/mini_numpy.c | 14 +- tests/all_tests.c | 12 ++ tests/forward_pass/affine/test_broadcast.h | 89 +++++++++++ tests/jacobian_tests/test_broadcast.h | 135 ++++++++++++++++ tests/wsum_hess/test_broadcast.h | 
170 +++++++++++++++++++++ 9 files changed, 487 insertions(+), 13 deletions(-) create mode 100644 tests/forward_pass/affine/test_broadcast.h create mode 100644 tests/jacobian_tests/test_broadcast.h create mode 100644 tests/wsum_hess/test_broadcast.h diff --git a/include/affine.h b/include/affine.h index f76376a..ec99717 100644 --- a/include/affine.h +++ b/include/affine.h @@ -20,5 +20,6 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars); expr *new_index(expr *child, const int *indices, int n_idxs); expr *new_reshape(expr *child, int d1, int d2); +expr *new_broadcast(expr *child, int target_d1, int target_d2); #endif /* AFFINE_H */ diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h index c24ee29..1710978 100644 --- a/include/utils/mini_numpy.h +++ b/include/utils/mini_numpy.h @@ -11,7 +11,8 @@ void repeat(double *result, const double *a, int len, int repeats); * Example: a = [1, 2], len = 2, tiles = 3 * result = [1, 2, 1, 2, 1, 2] */ -void tile(double *result, const double *a, int len, int tiles); +void tile_double(double *result, const double *a, int len, int tiles); +void tile_int(int *result, const int *a, int len, int tiles); /* Fill array with 'size' copies of 'value' * Example: size = 5, value = 3.0 diff --git a/src/affine/broadcast.c b/src/affine/broadcast.c index c20df6d..369936e 100644 --- a/src/affine/broadcast.c +++ b/src/affine/broadcast.c @@ -1,6 +1,6 @@ #include "affine.h" -#include "mini_numpy.h" #include "subexpr.h" +#include "utils/mini_numpy.h" #include #include #include @@ -91,7 +91,7 @@ static void jacobian_init(expr *node) int nnz_in_row = Jx->p[i + 1] - Jx->p[i]; /* copy columns indices */ - tile(J->i + J->nnz, Jx->i + Jx->p[i], nnz_in_row, bcast->m); + tile_int(J->i + J->nnz, Jx->i + Jx->p[i], nnz_in_row, bcast->m); /* set row pointers */ for (int rep = 0; rep < bcast->m; rep++) @@ -105,7 +105,7 @@ static void jacobian_init(expr *node) { /* copy column indices */ - tile(J->i, Jx->i, Jx->nnz, bcast->n); + 
tile_int(J->i, Jx->i, Jx->nnz, bcast->n); /* set row pointers */ int offset = 0; @@ -117,11 +117,24 @@ static void jacobian_init(expr *node) offset += Jx->p[1] - Jx->p[0]; } } - assert(offset == total_nnz) J->p[node->size] = total_nnz; + assert(offset == total_nnz); + J->p[node->size] = total_nnz; } else { - assert(false); + /* copy column indices */ + tile_int(J->i, Jx->i, Jx->nnz, bcast->m * bcast->n); + + /* set row pointers */ + int offset = 0; + int nnz = Jx->p[1] - Jx->p[0]; + for (int i = 0; i < bcast->m * bcast->n; i++) + { + J->p[i] = offset; + offset += nnz; + } + assert(offset == total_nnz); + J->p[node->size] = total_nnz; } } @@ -139,17 +152,17 @@ static void eval_jacobian(expr *node) for (int i = 0; i < bcast->n; i++) { int nnz_in_row = Jx->p[i + 1] - Jx->p[i]; - tile(J->x + J->nnz, Jx->x + Jx->p[i], nnz_in_row, bcast->m); + tile_double(J->x + J->nnz, Jx->x + Jx->p[i], nnz_in_row, bcast->m); J->nnz += nnz_in_row * bcast->m; } } else if (bcast->type == BROADCAST_COL) { - tile(J->x, Jx->x, Jx->nnz, bcast->n); + tile_double(J->x, Jx->x, Jx->nnz, bcast->n); } else { - assert(false); + tile_double(J->x, Jx->x, Jx->nnz, bcast->m * bcast->n); } } @@ -162,12 +175,53 @@ static void wsum_hess_init(expr *node) node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz); memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->wsum_hess->m + 1) * sizeof(int)); memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int)); + + /* allocate space for weight vector */ + node->dwork = malloc(node->size * sizeof(double)); } static void eval_wsum_hess(expr *node, const double *w) { broadcast_expr *bcast = (broadcast_expr *) node; - expr *child = node->left; + expr *x = node->left; + + /* Zero out the work array first */ + memset(node->dwork, 0, x->size * sizeof(double)); + + if (bcast->type == BROADCAST_ROW) + { + /* (1, n) -> (m, n): each input element has m weights to sum */ + for (int j = 0; j < bcast->n; j++) + { + for (int i = 0; i < 
bcast->m; i++) + { + node->dwork[j] += w[i + j * bcast->m]; + } + } + } + else if (bcast->type == BROADCAST_COL) + { + /* (m, 1) -> (m, n): each input element has n weights to sum */ + for (int j = 0; j < bcast->n; j++) + { + for (int i = 0; i < bcast->m; i++) + { + node->dwork[i] += w[i + j * bcast->m]; + } + } + } + else + { + /* (1, 1) -> (m, n): scalar has m*n weights to sum */ + node->dwork[0] = 0.0; + for (int k = 0; k < bcast->m * bcast->n; k++) + { + node->dwork[0] += w[k]; + } + } + + x->eval_wsum_hess(x, node->dwork); + memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double)); } static bool is_affine(const expr *node) @@ -201,7 +255,7 @@ expr *new_broadcast(expr *child, int target_d1, int target_d2) assert(false); } - broadcast_expr *bcast = (broadcast_expr *) malloc(sizeof(broadcast_expr)); + broadcast_expr *bcast = (broadcast_expr *) calloc(1, sizeof(broadcast_expr)); expr *node = (expr *) bcast; // -------------------------------------------------------------------------- diff --git a/src/affine/sum.c b/src/affine/sum.c index c124796..9ab1f60 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -135,7 +135,7 @@ static void eval_wsum_hess(expr *node, const double *w) } else if (axis == 1) { - tile(node->dwork, w, x->d1, x->d2); + tile_double(node->dwork, w, x->d1, x->d2); } x->eval_wsum_hess(x, node->dwork); diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index 84db029..f020cdd 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -13,7 +13,19 @@ void repeat(double *result, const double *a, int len, int repeats) } /* TODO: we can use memcpy here */ -void tile(double *result, const double *a, int len, int tiles) +void tile_double(double *result, const double *a, int len, int tiles) +{ + int idx = 0; + for (int i = 0; i < tiles; i++) + { + for (int j = 0; j < len; j++) + { + result[idx++] = a[j]; + } + } +} + +void tile_int(int *result, const int *a, int len, int tiles) { int idx = 0; for (int i = 0; i 
< tiles; i++) diff --git a/tests/all_tests.c b/tests/all_tests.c index 9c7692d..0d0722d 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -4,6 +4,7 @@ /* Include all test headers */ #include "forward_pass/affine/test_add.h" +#include "forward_pass/affine/test_broadcast.h" #include "forward_pass/affine/test_hstack.h" #include "forward_pass/affine/test_linear_op.h" #include "forward_pass/affine/test_neg.h" @@ -13,6 +14,7 @@ #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" +#include "jacobian_tests/test_broadcast.h" #include "jacobian_tests/test_composite.h" #include "jacobian_tests/test_const_scalar_mult.h" #include "jacobian_tests/test_const_vector_mult.h" @@ -41,6 +43,7 @@ #include "wsum_hess/elementwise/test_power.h" #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" +#include "wsum_hess/test_broadcast.h" #include "wsum_hess/test_const_scalar_mult.h" #include "wsum_hess/test_const_vector_mult.h" #include "wsum_hess/test_hstack.h" @@ -76,6 +79,9 @@ int main(void) mu_run_test(test_sum_axis_1, tests_run); mu_run_test(test_hstack_forward_vectors, tests_run); mu_run_test(test_hstack_forward_matrix, tests_run); + mu_run_test(test_broadcast_row, tests_run); + mu_run_test(test_broadcast_col, tests_run); + mu_run_test(test_broadcast_matrix, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -121,6 +127,9 @@ int main(void) mu_run_test(test_sum_of_index, tests_run); mu_run_test(test_promote_scalar_jacobian, tests_run); mu_run_test(test_promote_scalar_to_matrix_jacobian, tests_run); + mu_run_test(test_broadcast_row_jacobian, tests_run); + mu_run_test(test_broadcast_col_jacobian, tests_run); + mu_run_test(test_broadcast_scalar_to_matrix_jacobian, tests_run); mu_run_test(test_wsum_hess_multiply_1, tests_run); mu_run_test(test_wsum_hess_multiply_2, tests_run); 
mu_run_test(test_jacobian_trace_variable, tests_run); @@ -177,6 +186,9 @@ int main(void) mu_run_test(test_wsum_hess_left_matmul_composite, tests_run); mu_run_test(test_wsum_hess_right_matmul, tests_run); mu_run_test(test_wsum_hess_right_matmul_vector, tests_run); + mu_run_test(test_wsum_hess_broadcast_row, tests_run); + mu_run_test(test_wsum_hess_broadcast_col, tests_run); + mu_run_test(test_wsum_hess_broadcast_scalar_to_matrix, tests_run); // This test leads to seg fault // mu_run_test(test_wsum_hess_trace_variable, tests_run); diff --git a/tests/forward_pass/affine/test_broadcast.h b/tests/forward_pass/affine/test_broadcast.h new file mode 100644 index 0000000..ca030b5 --- /dev/null +++ b/tests/forward_pass/affine/test_broadcast.h @@ -0,0 +1,89 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_broadcast_row() +{ + /* Test broadcast row: (1, 3) -> (2, 3) + * Input: [1.0, 2.0, 3.0] (row vector) + * Output: [[1.0, 2.0, 3.0], + * [1.0, 2.0, 3.0]] + * Vectorized columnwise: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0] + */ + double row_data[3] = {1.0, 2.0, 3.0}; + expr *row_var = new_variable(1, 3, 0, 3); + expr *bcast = new_broadcast(row_var, 2, 3); + + bcast->forward(bcast, row_data); + + /* Expected: columnwise vectorization [col1, col2, col3] */ + double expected[6] = {1.0, 1.0, 2.0, 2.0, 3.0, 3.0}; + mu_assert("Broadcast row test failed", + cmp_double_array(bcast->value, expected, 6)); + + free_expr(bcast); + return 0; +} + +const char *test_broadcast_col() +{ + /* Test broadcast column: (3, 1) -> (3, 2) + * Input: [[1.0], + * [2.0], + * [3.0]] (column vector) + * Output: [[1.0, 1.0], + * [2.0, 2.0], + * [3.0, 3.0]] + * Vectorized columnwise: [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] + */ + double col_data[3] = {1.0, 2.0, 3.0}; + expr *col_var = new_variable(3, 1, 0, 3); + expr *bcast = new_broadcast(col_var, 3, 2); + + bcast->forward(bcast, col_data); + + /* Expected: columnwise vectorization 
[col1, col2] */ + double expected[6] = {1.0, 2.0, 3.0, 1.0, 2.0, 3.0}; + mu_assert("Broadcast column test failed", + cmp_double_array(bcast->value, expected, 6)); + + free_expr(bcast); + return 0; +} + +const char *test_broadcast_matrix() +{ + /* Test no broadcast needed: (2, 3) -> (2, 3) + * This should work when child shape already matches target + * Actually, based on the implementation, broadcast is only for: + * - row: (1, n) -> (m, n) + * - col: (m, 1) -> (m, n) + * - scalar: (1, 1) -> (m, n) + * So let's test scalar broadcast instead. + */ + + /* Test scalar broadcast: (1, 1) -> (2, 3) + * Input: [5.0] (scalar) + * Output: [[5.0, 5.0, 5.0], + * [5.0, 5.0, 5.0]] + * Vectorized columnwise: [5.0, 5.0, 5.0, 5.0, 5.0, 5.0] + */ + double scalar_data[1] = {5.0}; + expr *scalar_var = new_variable(1, 1, 0, 1); + expr *bcast = new_broadcast(scalar_var, 2, 3); + + bcast->forward(bcast, scalar_data); + + /* Expected: all elements are 5.0 */ + double expected[6] = {5.0, 5.0, 5.0, 5.0, 5.0, 5.0}; + mu_assert("Broadcast scalar test failed", + cmp_double_array(bcast->value, expected, 6)); + + free_expr(bcast); + return 0; +} diff --git a/tests/jacobian_tests/test_broadcast.h b/tests/jacobian_tests/test_broadcast.h new file mode 100644 index 0000000..43872fe --- /dev/null +++ b/tests/jacobian_tests/test_broadcast.h @@ -0,0 +1,135 @@ +#include +#include +#include + +#include "affine.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_broadcast_row_jacobian() +{ + /* Test jacobian of broadcast row: (1, 3) -> (2, 3) + * Input variable: 3 elements (d1=1, d2=3) + * Output: 6 elements (d1=2, d2=3) stored columnwise + * + * For a row broadcast, each column of the child is repeated m times. 
+ * Child jacobian (3x3, identity sparsity): + * var[0] -> child[0] + * var[1] -> child[1] + * var[2] -> child[2] + * + * Broadcast row (1,3) -> (2,3) repeats the row 2 times: + * Broadcast jacobian (6x3 sparsity, columnwise): + * var[0] -> out[0], out[1] (column 0, rows 0,1) + * var[1] -> out[2], out[3] (column 1, rows 0,1) + * var[2] -> out[4], out[5] (column 2, rows 0,1) + */ + double u[3] = {1.0, 2.0, 3.0}; + expr *var = new_variable(1, 3, 0, 3); + expr *bcast = new_broadcast(var, 2, 3); + bcast->forward(bcast, u); + bcast->jacobian_init(bcast); + bcast->eval_jacobian(bcast); + + /* Each variable affects 2 elements (m times) */ + double expected_x[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + int expected_p[7] = {0, 1, 2, 3, 4, 5, 6}; + int expected_i[6] = {0, 0, 1, 1, 2, 2}; + + mu_assert("broadcast row jacobian vals fail", + cmp_double_array(bcast->jacobian->x, expected_x, 6)); + mu_assert("broadcast row jacobian rows fail", + cmp_int_array(bcast->jacobian->p, expected_p, 4)); + mu_assert("broadcast row jacobian cols fail", + cmp_int_array(bcast->jacobian->i, expected_i, 6)); + + free_expr(bcast); + return 0; +} + +const char *test_broadcast_col_jacobian(void) +{ + /* Test jacobian of broadcast column: (3, 1) -> (3, 2) + * Input variable: 3 elements (d1=3, d2=1) + * Output: 6 elements (d1=3, d2=2) stored columnwise + * + * For a column broadcast, the column is repeated n times. 
+ * Child jacobian (3x3, identity sparsity): + * var[0] -> child[0] + * var[1] -> child[1] + * var[2] -> child[2] + * + * Broadcast column (3,1) -> (3,2) repeats the column 2 times: + * Output in columnwise order: + * col 0: out[0]=child[0], out[1]=child[1], out[2]=child[2] + * col 1: out[3]=child[0], out[4]=child[1], out[5]=child[2] + * + * Broadcast jacobian (6x3 sparsity): + * var[0] -> out[0], out[3] + * var[1] -> out[1], out[4] + * var[2] -> out[2], out[5] + */ + double u[3] = {1.0, 2.0, 3.0}; + expr *var = new_variable(3, 1, 0, 3); + expr *bcast = new_broadcast(var, 3, 2); + bcast->forward(bcast, u); + bcast->jacobian_init(bcast); + bcast->eval_jacobian(bcast); + + /* Each variable affects 2 elements (n times) */ + double expected_x[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + int expected_p[7] = {0, 1, 2, 3, 4, 5, 6}; + int expected_i[6] = {0, 1, 2, 0, 1, 2}; + + mu_assert("broadcast col jacobian vals fail", + cmp_double_array(bcast->jacobian->x, expected_x, 6)); + mu_assert("broadcast col jacobian rows fail", + cmp_int_array(bcast->jacobian->p, expected_p, 7)); + mu_assert("broadcast col jacobian cols fail", + cmp_int_array(bcast->jacobian->i, expected_i, 6)); + + free_expr(bcast); + return 0; +} + +const char *test_broadcast_scalar_to_matrix_jacobian(void) +{ + /* Test jacobian of broadcast scalar: (1, 1) -> (2, 3) + * Input variable: 1 element (scalar) + * Output: 6 elements (2x3 matrix) stored columnwise + * + * Scalar broadcast replicates the single input value to all m*n outputs. 
+ * Child jacobian (1x1 sparsity): var[0] -> child[0] + * + * Broadcast scalar (1,1) -> (2,3) replicates in both dimensions: + * Output in columnwise order: + * col 0: out[0]=var[0], out[1]=var[0] + * col 1: out[2]=var[0], out[3]=var[0] + * col 2: out[4]=var[0], out[5]=var[0] + * + * Broadcast jacobian (6x1 sparsity): + * var[0] -> out[0], out[1], out[2], out[3], out[4], out[5] + */ + double u[1] = {5.0}; + expr *var = new_variable(1, 1, 0, 1); + expr *bcast = new_broadcast(var, 2, 3); + bcast->forward(bcast, u); + bcast->jacobian_init(bcast); + bcast->eval_jacobian(bcast); + + /* All 6 elements depend on the single input variable */ + double expected_x[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + int expected_p[7] = {0, 1, 2, 3, 4, 5, 6}; + int expected_i[6] = {0, 0, 0, 0, 0, 0}; + + mu_assert("broadcast scalar jacobian vals fail", + cmp_double_array(bcast->jacobian->x, expected_x, 6)); + mu_assert("broadcast scalar jacobian rows fail", + cmp_int_array(bcast->jacobian->p, expected_p, 7)); + mu_assert("broadcast scalar jacobian cols fail", + cmp_int_array(bcast->jacobian->i, expected_i, 6)); + + free_expr(bcast); + return 0; +} diff --git a/tests/wsum_hess/test_broadcast.h b/tests/wsum_hess/test_broadcast.h new file mode 100644 index 0000000..54ba36c --- /dev/null +++ b/tests/wsum_hess/test_broadcast.h @@ -0,0 +1,170 @@ +#include +#include +#include +#include + +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_wsum_hess_broadcast_row() +{ + /* Test: wsum_hess of broadcast_row(log(x)) where x is (1, 3) + * x = [1.0, 2.0, 4.0] (row vector) + * broadcast to (2, 3): + * [[1.0, 2.0, 4.0], + * [1.0, 2.0, 4.0]] + * + * log(x) is not affine, so broadcast(log(x)) picks up the Hessian + * of the element-wise log. + * + * The weights are for the output (2x3 matrix), stored columnwise. 
+ * Weights w = [w00, w10, w01, w11, w02, w12] where w_ij is for row i, col j + * The hessian should sum weights from replicated rows. + */ + double x[3] = {1.0, 2.0, 4.0}; + expr *x_node = new_variable(1, 3, 0, 3); + expr *log_node = new_log(x_node); + expr *bcast = new_broadcast(log_node, 2, 3); + + bcast->forward(bcast, x); + bcast->jacobian_init(bcast); + bcast->wsum_hess_init(bcast); + + /* Weights for the 2x3 output (columnwise): + * w = [1.0, 0.5, 2.0, 1.0, 0.25, 0.125] + * col0 col1 col2 + */ + double w[6] = {1.0, 0.5, 2.0, 1.0, 0.25, 0.125}; + bcast->eval_wsum_hess(bcast, w); + + /* For broadcast_row, weights are summed across the m replicas: + * Accumulated weights for log(x): + * w_acc[0] = w[0] + w[1] = 1.0 + 0.5 = 1.5 + * w_acc[1] = w[2] + w[3] = 2.0 + 1.0 = 3.0 + * w_acc[2] = w[4] + w[5] = 0.25 + 0.125 = 0.375 + * + * Hessian of log(x) is diagonal with -w_acc[i] / x[i]^2: + * H[0,0] = -1.5 / 1.0^2 = -1.5 + * H[1,1] = -3.0 / 2.0^2 = -0.75 + * H[2,2] = -0.375 / 4.0^2 = -0.0234375 + */ + double expected_x[3] = {-1.5, -0.75, -0.0234375}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("broadcast row wsum_hess: x values fail", + cmp_double_array(bcast->wsum_hess->x, expected_x, 3)); + mu_assert("broadcast row wsum_hess: row pointers fail", + cmp_int_array(bcast->wsum_hess->p, expected_p, 4)); + mu_assert("broadcast row wsum_hess: column indices fail", + cmp_int_array(bcast->wsum_hess->i, expected_i, 3)); + + free_expr(bcast); + return 0; +} + +const char *test_wsum_hess_broadcast_col() +{ + /* Test: wsum_hess of broadcast_col(log(x)) where x is (3, 1) + * x = [1.0, 2.0, 4.0]^T (column vector) + * broadcast to (3, 2): + * [[1.0, 1.0], + * [2.0, 2.0], + * [4.0, 4.0]] + * + * The weights are for the output (3x2 matrix), stored columnwise. + * Weights w = [w00, w10, w20, w01, w11, w21] where w_ij is for row i, col j + * The hessian should sum weights from replicated columns. 
+ */ + double x[3] = {1.0, 2.0, 4.0}; + expr *x_node = new_variable(3, 1, 0, 3); + expr *log_node = new_log(x_node); + expr *bcast = new_broadcast(log_node, 3, 2); + + bcast->forward(bcast, x); + bcast->jacobian_init(bcast); + bcast->wsum_hess_init(bcast); + + /* Weights for the 3x2 output (columnwise): + * w = [1.0, 0.5, 0.25, 2.0, 1.0, 0.5] + * col0 col1 + */ + double w[6] = {1.0, 0.5, 0.25, 2.0, 1.0, 0.5}; + bcast->eval_wsum_hess(bcast, w); + + /* For broadcast_col, weights are summed across the n replicas: + * Accumulated weights for log(x): + * w_acc[0] = w[0] + w[3] = 1.0 + 2.0 = 3.0 + * w_acc[1] = w[1] + w[4] = 0.5 + 1.0 = 1.5 + * w_acc[2] = w[2] + w[5] = 0.25 + 0.5 = 0.75 + * + * Hessian of log(x) is diagonal with -w_acc[i] / x[i]^2: + * H[0,0] = -3.0 / 1.0^2 = -3.0 + * H[1,1] = -1.5 / 2.0^2 = -0.375 + * H[2,2] = -0.75 / 4.0^2 = -0.046875 + */ + double expected_x[3] = {-3.0, -0.375, -0.046875}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("broadcast col wsum_hess: x values fail", + cmp_double_array(bcast->wsum_hess->x, expected_x, 3)); + mu_assert("broadcast col wsum_hess: row pointers fail", + cmp_int_array(bcast->wsum_hess->p, expected_p, 4)); + mu_assert("broadcast col wsum_hess: column indices fail", + cmp_int_array(bcast->wsum_hess->i, expected_i, 3)); + + free_expr(bcast); + return 0; +} + +const char *test_wsum_hess_broadcast_scalar_to_matrix() +{ + /* Test: wsum_hess of broadcast_scalar(log(x)) where x is scalar (1, 1) + * x = 5.0 + * broadcast to (2, 3): + * [[5.0, 5.0, 5.0], + * [5.0, 5.0, 5.0]] + * + * The weights are for the output (2x3 matrix), stored columnwise. + * Weights w has 6 elements corresponding to all positions. + * The hessian should sum all weights into the scalar weight. 
+ */ + double x[1] = {5.0}; + expr *x_node = new_variable(1, 1, 0, 1); + expr *log_node = new_log(x_node); + expr *bcast = new_broadcast(log_node, 2, 3); + + bcast->forward(bcast, x); + bcast->jacobian_init(bcast); + bcast->wsum_hess_init(bcast); + + /* Weights for the 2x3 output (columnwise): + * w = [1.0, 0.5, 2.0, 1.0, 0.25, 0.125] + */ + double w[6] = {1.0, 0.5, 2.0, 1.0, 0.25, 0.125}; + bcast->eval_wsum_hess(bcast, w); + + /* For broadcast_scalar, all weights are summed: + * w_acc[0] = sum(w) = 1.0 + 0.5 + 2.0 + 1.0 + 0.25 + 0.125 = 4.875 + * + * Hessian of log(scalar) is -w_acc[0] / x[0]^2: + * H[0,0] = -4.875 / 5.0^2 = -4.875 / 25.0 = -0.195 + */ + double expected_x[1] = {-0.195}; + int expected_p[2] = {0, 1}; + int expected_i[1] = {0}; + + mu_assert("broadcast scalar wsum_hess: x values fail", + cmp_double_array(bcast->wsum_hess->x, expected_x, 1)); + mu_assert("broadcast scalar wsum_hess: row pointers fail", + cmp_int_array(bcast->wsum_hess->p, expected_p, 2)); + mu_assert("broadcast scalar wsum_hess: column indices fail", + cmp_int_array(bcast->wsum_hess->i, expected_i, 1)); + + free_expr(bcast); + return 0; +}