From 4ba651ec33670bfae4bdc9118b1d261b97926ca4 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 18 Jan 2026 13:04:35 -0800 Subject: [PATCH 1/6] started on matmul --- BLAS_README.md | 109 ++++++++++ CMakeLists.txt | 28 ++- include/bivariate.h | 3 + include/blas.h | 40 ++++ include/utils/blas_wrappers.h | 39 ++++ src/bivariate/matmul.c | 336 +++++++++++++++++++++++++++++ src/utils/blas_wrappers.c | 86 ++++++++ tests/all_tests.c | 11 + tests/forward_pass/test_matmul.h | 62 ++++++ tests/jacobian_tests/test_matmul.h | 89 ++++++++ tests/utils/test_blas.h | 64 ++++++ tests/wsum_hess/test_matmul.h | 217 +++++++++++++++++++ 12 files changed, 1083 insertions(+), 1 deletion(-) create mode 100644 BLAS_README.md create mode 100644 include/blas.h create mode 100644 include/utils/blas_wrappers.h create mode 100644 src/bivariate/matmul.c create mode 100644 src/utils/blas_wrappers.c create mode 100644 tests/forward_pass/test_matmul.h create mode 100644 tests/jacobian_tests/test_matmul.h create mode 100644 tests/utils/test_blas.h create mode 100644 tests/wsum_hess/test_matmul.h diff --git a/BLAS_README.md b/BLAS_README.md new file mode 100644 index 0000000..6307d00 --- /dev/null +++ b/BLAS_README.md @@ -0,0 +1,109 @@ +# BLAS Integration + +This project **requires** BLAS/LAPACK libraries for optimized linear algebra operations. + +## Building + +BLAS and LAPACK are **mandatory dependencies**. CMake will search for them and fail if not found. + +### Standard Build + +```bash +mkdir -p build +cd build +cmake .. +make +``` + +## Installing BLAS/LAPACK (Required) + +You must install BLAS and LAPACK before building: + +**Ubuntu/Debian:** +```bash +sudo apt-get install libblas-dev liblapack-dev +``` + +**macOS:** +```bash +brew install openblas lapack +``` + +**Alternative Libraries:** +- **OpenBLAS** (recommended): Fast, open-source BLAS +- **Intel MKL**: Optimized for Intel processors +- **Apple Accelerate**: Built-in on macOS +- **ATLAS**: Automatically tuned + +## Configuration Options + +```bash +# Use 64-bit integers for BLAS (if your BLAS library uses 64-bit ints) +cmake -DBLAS_64=ON .. + +# If BLAS functions have no underscore suffix +cmake -DNO_BLAS_SUFFIX=ON .. +``` + +## Using BLAS in Code + +Include the wrapper header: + +```c +#include "utils/blas_wrappers.h" + +// These use optimized BLAS routines +double norm = vec_norm2(x, n); +double dot_product = vec_dot(x, y, n); +vec_axpy(2.0, x, y, n); // y += 2.0 * x +vec_scale(0.5, x, n); // x *= 0.5 +mat_vec_mult(A, x, y, m, n); // y = A*x +``` + +## Direct BLAS Usage + +For advanced usage, include the BLAS header directly: + +```c +#include "blas.h" + +// Declare the BLAS function you need +extern double BLAS(nrm2)(blas_int *n, const double *x, blas_int *incx); + +void my_function(double *x, int n) { + blas_int bn = (blas_int)n; + blas_int inc = 1; + double norm = BLAS(nrm2)(&bn, x, &inc); +} +``` + +The `BLAS(name)` macro handles function name mangling (e.g., `BLAS(nrm2)` → `dnrm2_`). + +## Available BLAS Functions + +**Level 1 (vector-vector):** +- `BLAS(nrm2)` - L2 norm +- `BLAS(dot)` - dot product +- `BLAS(axpy)` - y = alpha*x + y +- `BLAS(scal)` - x = alpha*x +- `BLASI(amax)` - index of max absolute value + +**Level 2 (matrix-vector):** +- `BLAS(gemv)` - matrix-vector multiply + +**Level 3 (matrix-matrix):** +- `BLAS(gemm)` - matrix-matrix multiply + +**LAPACK:** +- `BLAS(gesv)` - solve linear system +- `BLAS(syevr)` - symmetric eigenvalue decomposition +- Many more... + +## Performance Notes + +BLAS provides significant speedups for: +- Large vector operations (n > 1000) +- Matrix-vector products +- Matrix-matrix products + +For small vectors (n < 100), the overhead may outweigh the benefits. diff --git a/CMakeLists.txt b/CMakeLists.txt index d604fe0..90d3289 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,32 @@ else() message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") endif() +# ============================================================================= +# BLAS/LAPACK Configuration (REQUIRED) +# ============================================================================= +find_package(BLAS REQUIRED) +find_package(LAPACK REQUIRED) + +message(STATUS "Found BLAS: ${BLAS_LIBRARIES}") +message(STATUS "Found LAPACK: ${LAPACK_LIBRARIES}") + +set(BLAS_LIBRARIES ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) +add_compile_definitions(USE_BLAS) + +# Optional BLAS configuration flags +option(BLAS_64 "Use 64-bit integers for BLAS" OFF) +option(NO_BLAS_SUFFIX "BLAS functions have no suffix" OFF) + +if(BLAS_64) + add_compile_definitions(BLAS_64) + message(STATUS "Using 64-bit BLAS integers") +endif() + +if(NO_BLAS_SUFFIX) + add_compile_definitions(NO_BLAS_SUFFIX) + message(STATUS "Using BLAS with no function suffix") +endif() + # Warning flags (always enabled) add_compile_options( -Wall # Enable most warnings @@ -34,7 +60,7 @@ file(GLOB_RECURSE SOURCES "src/*.c") # Create core library add_library(dnlp_diff ${SOURCES}) -target_link_libraries(dnlp_diff m) +target_link_libraries(dnlp_diff m ${BLAS_LIBRARIES}) # Config-specific compile options target_compile_options(dnlp_diff PRIVATE diff --git a/include/bivariate.h b/include/bivariate.h index 582a4ac..8a28e8e 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -10,6 +10,9 @@ expr *new_quad_over_lin(expr *left, expr *right); expr *new_rel_entr_first_arg_scalar(expr *left, expr *right); expr *new_rel_entr_second_arg_scalar(expr *left, expr *right); +/* Matrix multiplication: Z = X @ Y */ +expr *new_matmul(expr *x, expr *y); + /* Left matrix multiplication: A @ f(x) where A is a constant matrix */ expr *new_left_matmul(expr *u, const CSR_Matrix *A); diff --git a/include/blas.h b/include/blas.h new file mode 100644 index 0000000..11d78be --- /dev/null +++ b/include/blas.h @@ -0,0 +1,40 @@ +#ifndef DNLP_BLAS_H_GUARD +#define DNLP_BLAS_H_GUARD + +#ifdef __cplusplus +extern "C" +{ +#endif + +/* Default to underscore suffix for BLAS/LAPACK function names */ +#ifndef BLAS_SUFFIX +#define BLAS_SUFFIX _ +#endif + +/* Handle BLAS function name mangling */ +#if defined(NO_BLAS_SUFFIX) && NO_BLAS_SUFFIX > 0 +/* No suffix version */ +#define BLAS(x) d##x +#define BLASI(x) id##x +#else +/* Macro indirection needed for BLAS_SUFFIX to work correctly */ +#define _stitch(pre, x, post) pre##x##post +#define _stitch2(pre, x, post) _stitch(pre, x, post) +/* Add suffix (default: underscore) */ +#define BLAS(x) _stitch2(d, x, BLAS_SUFFIX) +#define BLASI(x) _stitch2(id, x, BLAS_SUFFIX) +#endif + +/* BLAS integer type */ +#ifdef BLAS_64 +#include + typedef int64_t blas_int; +#else +typedef int blas_int; +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* DNLP_BLAS_H_GUARD */ diff --git a/include/utils/blas_wrappers.h b/include/utils/blas_wrappers.h new file mode 100644 index 0000000..6abc5c9 --- /dev/null +++ b/include/utils/blas_wrappers.h @@ -0,0 +1,39 @@ +#ifndef BLAS_WRAPPERS_H_GUARD +#define BLAS_WRAPPERS_H_GUARD + +#ifdef __cplusplus +extern "C" +{ +#endif + +#include "blas.h" +#include +#include + + /* + * BLAS wrapper functions for common linear algebra operations. + * + * When compiled with USE_BLAS, these use optimized BLAS routines. + * Otherwise, they fall back to simple C implementations. + */ + + /* Vector operations */ + double vec_norm2(const double *x, int n); + double vec_dot(const double *x, const double *y, int n); + void vec_axpy(double alpha, const double *x, double *y, + int n); /* y += alpha*x */ + void vec_scale(double alpha, double *x, int n); /* x *= alpha */ + + /* Matrix-vector operations */ + void mat_vec_mult(const double *A, const double *x, double *y, int m, + int n); /* y = A*x, A is m x n */ + + /* Matrix-matrix operations */ + void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, + int n); /* Z = X*Y, X is m x k, Y is k x n, Z is m x n */ + +#ifdef __cplusplus +} +#endif + +#endif /* BLAS_WRAPPERS_H_GUARD */ diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c new file mode 100644 index 0000000..2a38127 --- /dev/null +++ b/src/bivariate/matmul.c @@ -0,0 +1,336 @@ +#include "bivariate.h" +#include "subexpr.h" +#include "utils/blas_wrappers.h" +#include "utils/mini_numpy.h" +#include +#include +#include +#include + +// ------------------------------------------------------------------------------ +// Implementation of matrix multiplication: Z = X @ Y +// where X is m x k and Y is k x n, producing Z which is m x n +// All matrices are stored in column-major order +// ------------------------------------------------------------------------------ + +static void forward(expr *node, const double *u) +{ + expr *x = node->left; + expr *y = node->right; + + /* children's forward passes */ + x->forward(x, u); + y->forward(y, u); + + /* local forward pass: Z = X @ Y + * X is m x k (d1 x d2) + * Y is k x n (d1 x d2) + * Z is m x n (d1 x d2) + */ + int m = x->d1; + int k = x->d2; + int n = y->d2; + + mat_mat_mult(x->value, y->value, node->value, m, k, n); +} + +static bool is_affine(const expr *node) +{ + (void) node; + return false; +} + +static void jacobian_init(expr *node) +{ + expr *x = node->left; + expr *y = node->right; + + /* dimensions: X is m x k, Y is k x n, Z is m x n */ + int m = x->d1; + int k = x->d2; + int n = y->d2; + int nnz = m * n * 2 * k; + node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz); + + /* fill sparsity pattern */ + int nnz_idx = 0; + for (int i = 0; i < node->size; i++) + { + /* Convert flat index to (row, col) in Z */ + int row = i % m; + int col = i / m; + + node->jacobian->p[i] = nnz_idx; + + /* X has lower var_id */ + if (x->var_id < y->var_id) + { + /* sparsity pattern of kron(Y^T, I) for this row */ + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; + } + + /* sparsity pattern of kron(I, X) for this row */ + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } + } + else /* Y has lower var_id */ + { + /* sparsity pattern of kron(I, X) for this row */ + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; + } + + /* sparsity pattern of kron(Y^T, I) for this row */ + for (int j = 0; j < k; j++) + { + node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; + } + } + } + node->jacobian->p[node->size] = nnz_idx; + assert(nnz_idx == nnz); +} + +static void eval_jacobian(expr *node) +{ + expr *x = node->left; + expr *y = node->right; + + /* dimensions: X is m x k, Y is k x n, Z is m x n */ + int m = x->d1; + int k = x->d2; + double *Jx = node->jacobian->x; + + /* fill values row-by-row */ + for (int i = 0; i < node->size; i++) + { + int row = i % m; /* row in Z */ + int col = i / m; /* col in Z */ + int pos = node->jacobian->p[i]; + + if (x->var_id < y->var_id) + { + /* Y^T contribution: contiguous column of Y at 'col' */ + memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); + + /* X row contribution: stride m across columns */ + for (int j = 0; j < k; j++) + { + Jx[pos + k + j] = x->value[row + j * m]; + } + } + else + { + /* X row contribution: stride m across columns */ + for (int j = 0; j < k; j++) + { + Jx[pos + j] = x->value[row + j * m]; + } + + /* Y^T contribution: contiguous column of Y at 'col' */ + memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); + } + } +} + +static void wsum_hess_init(expr *node) +{ + expr *x = node->left; + expr *y = node->right; + + /* dimensions: X is m x k, Y is k x n, Z is m x n */ + int m = x->d1; + int k = x->d2; + int n = y->d2; + int total_nnz = 2 * m * k * n; + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, total_nnz); + int nnz = 0; + int *Hi = node->wsum_hess->i; + int *Hp = node->wsum_hess->p; + int start, i; + + if (x->var_id < y->var_id) + { + /* fill rows corresponding to x */ + for (i = 0; i < x->size; i++) + { + Hp[x->var_id + i] = nnz; + start = y->var_id + (i / m); + for (int col = 0; col < n; col++) + { + Hi[nnz++] = start + col * k; + } + } + + /* fill rows between x and y */ + for (i = x->var_id + x->size; i < y->var_id; i++) + { + Hp[i] = nnz; + } + + /* fill rows corresponding to y */ + for (i = 0; i < y->size; i++) + { + Hp[y->var_id + i] = nnz; + start = x->var_id + (i % k) * m; + for (int row = 0; row < m; row++) + { + Hi[nnz++] = start + row; + } + } + + /* fill rows after y */ + for (i = y->var_id + y->size; i <= node->n_vars; i++) + { + Hp[i] = nnz; + } + } + else + { + /* Y has lower var_id than X */ + /* fill rows corresponding to y */ + for (i = 0; i < y->size; i++) + { + Hp[y->var_id + i] = nnz; + start = x->var_id + (i % k) * m; + for (int row = 0; row < m; row++) + { + Hi[nnz++] = start + row; + } + } + + /* fill rows between y and x */ + for (i = y->var_id + y->size; i < x->var_id; i++) + { + Hp[i] = nnz; + } + + /* fill rows corresponding to x */ + for (i = 0; i < x->size; i++) + { + Hp[x->var_id + i] = nnz; + start = y->var_id + (i / m); + for (int col = 0; col < n; col++) + { + Hi[nnz++] = start + col * k; + } + } + + /* fill rows after x */ + for (i = x->var_id + x->size; i <= node->n_vars; i++) + { + Hp[i] = nnz; + } + } + + Hp[node->n_vars] = nnz; + assert(nnz == total_nnz); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *x = node->left; + expr *y = node->right; + + int m = x->d1; + int k = x->d2; + int n = y->d2; + int offset = 0; + + double *Hx = node->wsum_hess->x; + + if (x->var_id < y->var_id) + { + /* X variable rows: For X[row, k_idx], couples with Y[k_idx, col] for all col + */ + for (int k_idx = 0; k_idx < k; k_idx++) + { + for (int row = 0; row < m; row++) + { + for (int col = 0; col < n; col++) + { + Hx[offset++] = w[row + col * m]; + } + } + } + + /* Y variable rows: For Y[k_idx, col], couples with X[row, k_idx] for all row + */ + for (int col = 0; col < n; col++) + { + for (int k_idx = 0; k_idx < k; k_idx++) + { + memcpy(Hx + offset, w + col * m, m * sizeof(double)); + offset += m; + } + } + } + else + { + /* Y variable rows come first: For Y[k_idx, col], couples with X[row, k_idx] + */ + for (int col = 0; col < n; col++) + { + for (int k_idx = 0; k_idx < k; k_idx++) + { + memcpy(Hx + offset, w + col * m, m * sizeof(double)); + offset += m; + } + } + + /* X variable rows: For X[row, k_idx], couples with Y[k_idx, col] for all col + */ + for (int k_idx = 0; k_idx < k; k_idx++) + { + for (int row = 0; row < m; row++) + { + for (int col = 0; col < n; col++) + { + Hx[offset++] = w[row + col * m]; + } + } + } + } +} + +expr *new_matmul(expr *x, expr *y) +{ + /* Verify dimensions: x->d2 must equal y->d1 */ + if (x->d2 != y->d1) + { + fprintf(stderr, + "Error in new_matmul: dimension mismatch. " + "X is %d x %d, Y is %d x %d. X.d2 (%d) must equal Y.d1 (%d)\n", + x->d1, x->d2, y->d1, y->d2, x->d2, y->d1); + exit(1); + } + + /* verify both are variables and not the same variable */ + if (x->var_id == NOT_A_VARIABLE || y->var_id == NOT_A_VARIABLE || + x->var_id == y->var_id) + { + fprintf(stderr, "Error in new_matmul: operands must be variables and not " + "the same variable\n"); + exit(1); + } + + /* Allocate the expression node */ + expr *node = (expr *) calloc(1, sizeof(expr)); + + /* Initialize with d1 = x->d1, d2 = y->d2 (result is m x n) */ + init_expr(node, x->d1, y->d2, x->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, wsum_hess_init, eval_wsum_hess, NULL); + + /* Set children */ + node->left = x; + node->right = y; + expr_retain(x); + expr_retain(y); + + return node; +} diff --git a/src/utils/blas_wrappers.c b/src/utils/blas_wrappers.c new file mode 100644 index 0000000..8608fc4 --- /dev/null +++ b/src/utils/blas_wrappers.c @@ -0,0 +1,86 @@ +#include "utils/blas_wrappers.h" + +/* BLAS function prototypes */ +#ifdef __cplusplus +extern "C" +{ +#endif + + double BLAS(nrm2)(blas_int *n, const double *x, blas_int *incx); + double BLAS(dot)(const blas_int *n, const double *x, const blas_int *incx, + const double *y, const blas_int *incy); + void BLAS(axpy)(blas_int *n, const double *alpha, const double *x, + blas_int *incx, double *y, blas_int *incy); + void BLAS(scal)(const blas_int *n, const double *alpha, double *x, + const blas_int *incx); + void BLAS(gemv)(const char *trans, const blas_int *m, const blas_int *n, + const double *alpha, const double *a, const blas_int *lda, + const double *x, const blas_int *incx, const double *beta, + double *y, const blas_int *incy); + void BLAS(gemm)(const char *transa, const char *transb, const blas_int *m, + const blas_int *n, const blas_int *k, const double *alpha, + const double *a, const blas_int *lda, const double *b, + const blas_int *ldb, const double *beta, double *c, + const blas_int *ldc); + +#ifdef __cplusplus +} +#endif + +/* BLAS implementations */ + +double vec_norm2(const double *x, int n) +{ + blas_int bn = (blas_int) n; + blas_int inc = 1; + return BLAS(nrm2)(&bn, x, &inc); +} + +double vec_dot(const double *x, const double *y, int n) +{ + blas_int bn = (blas_int) n; + blas_int inc = 1; + return BLAS(dot)(&bn, x, &inc, y, &inc); +} + +void vec_axpy(double alpha, const double *x, double *y, int n) +{ + blas_int bn = (blas_int) n; + blas_int inc = 1; + BLAS(axpy)(&bn, &alpha, x, &inc, y, &inc); +} + +void vec_scale(double alpha, double *x, int n) +{ + blas_int bn = (blas_int) n; + blas_int inc = 1; + BLAS(scal)(&bn, &alpha, x, &inc); +} + +void mat_vec_mult(const double *A, const double *x, double *y, int m, int n) +{ + /* y = A*x where A is m x n (column-major) */ + blas_int bm = (blas_int) m; + blas_int bn = (blas_int) n; + blas_int inc = 1; + double alpha = 1.0; + double beta = 0.0; + char trans = 'N'; + + BLAS(gemv)(&trans, &bm, &bn, &alpha, A, &bm, x, &inc, &beta, y, &inc); +} + +void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n) +{ + /* Z = X*Y where X is m x k, Y is k x n, Z is m x n (all column-major) */ + blas_int bm = (blas_int) m; + blas_int bk = (blas_int) k; + blas_int bn = (blas_int) n; + double alpha = 1.0; + double beta = 0.0; + char transX = 'N'; + char transY = 'N'; + + BLAS(gemm)(&transX, &transY, &bm, &bn, &bk, &alpha, X, &bm, Y, &bk, &beta, Z, + &bm); +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 5de1b00..5d72ceb 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -14,6 +14,7 @@ #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" +#include "forward_pass/test_matmul.h" #include "forward_pass/test_prod_axis_one.h" #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" @@ -25,6 +26,7 @@ #include "jacobian_tests/test_index.h" #include "jacobian_tests/test_left_matmul.h" #include "jacobian_tests/test_log.h" +#include "jacobian_tests/test_matmul.h" #include "jacobian_tests/test_neg.h" #include "jacobian_tests/test_prod.h" #include "jacobian_tests/test_prod_axis_one.h" @@ -39,6 +41,7 @@ #include "jacobian_tests/test_sum.h" #include "jacobian_tests/test_trace.h" #include "problem/test_problem.h" +#include "utils/test_blas.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" #include "wsum_hess/elementwise/test_entr.h" @@ -55,6 +58,7 @@ #include "wsum_hess/test_hstack.h" #include "wsum_hess/test_index.h" #include "wsum_hess/test_left_matmul.h" +#include "wsum_hess/test_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" #include "wsum_hess/test_prod_axis_one.h" @@ -94,6 +98,7 @@ int main(void) mu_run_test(test_broadcast_matrix, tests_run); mu_run_test(test_forward_prod_axis_zero, tests_run); mu_run_test(test_forward_prod_axis_one, tests_run); + mu_run_test(test_matmul, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -156,6 +161,7 @@ int main(void) mu_run_test(test_jacobian_left_matmul_log_composite, tests_run); mu_run_test(test_jacobian_right_matmul_log, tests_run); mu_run_test(test_jacobian_right_matmul_log_vector, tests_run); + mu_run_test(test_jacobian_matmul, tests_run); printf("\n--- Weighted Sum of Hessian Tests ---\n"); mu_run_test(test_wsum_hess_log, tests_run); @@ -210,6 +216,8 @@ int main(void) mu_run_test(test_wsum_hess_left_matmul, tests_run); mu_run_test(test_wsum_hess_left_matmul_matrix, tests_run); mu_run_test(test_wsum_hess_left_matmul_composite, tests_run); + mu_run_test(test_wsum_hess_matmul, tests_run); + mu_run_test(test_wsum_hess_matmul_yx, tests_run); mu_run_test(test_wsum_hess_right_matmul, tests_run); mu_run_test(test_wsum_hess_right_matmul_vector, tests_run); mu_run_test(test_wsum_hess_broadcast_row, tests_run); @@ -250,6 +258,9 @@ int main(void) mu_run_test(test_problem_constraint_forward, tests_run); mu_run_test(test_problem_hessian, tests_run); + printf("\n--- BLAS Tests ---\n"); + mu_run_test(test_matrix_matrix_mult, tests_run); + printf("\n=== All %d tests passed ===\n", tests_run); return 0; diff --git a/tests/forward_pass/test_matmul.h b/tests/forward_pass/test_matmul.h new file mode 100644 index 0000000..b2731d1 --- /dev/null +++ b/tests/forward_pass/test_matmul.h @@ -0,0 +1,62 @@ +#include +#include +#include + +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_matmul() +{ + /* Test: Z = X @ Y where + * X is 3x2: [1 4] Y is 2x4: [7 9 11 13] + * [2 5] [8 10 12 14] + * [3 6] + * + * Expected Z (3x4): [39 49 59 69] + * [54 68 82 96] + * [69 87 105 123] + */ + + /* Create X variable (3 x 2) - stored in column-major order */ + double u_x[6] = {1.0, 2.0, 3.0, /* first column */ + 4.0, 5.0, 6.0}; /* second column */ + expr *X = new_variable(3, 2, 0, 20); + + /* Create Y variable (2 x 4) - stored in column-major order */ + double u_y[8] = {7.0, 8.0, /* first column */ + 9.0, 10.0, /* second column */ + 11.0, 12.0, /* third column */ + 13.0, 14.0}; /* fourth column */ + expr *Y = new_variable(2, 4, 6, 20); + + /* Create matmul expression */ + expr *Z = new_matmul(X, Y); + + /* Concatenate parameter vectors */ + double u[14]; + for (int i = 0; i < 6; i++) u[i] = u_x[i]; + for (int i = 0; i < 8; i++) u[6 + i] = u_y[i]; + + /* Evaluate forward pass */ + Z->forward(Z, u); + + /* Expected result (3 x 4) in column-major order */ + double expected[12] = {39.0, 54.0, 69.0, /* first column */ + 49.0, 68.0, 87.0, /* second column */ + 59.0, 82.0, 105.0, /* third column */ + 69.0, 96.0, 123.0}; /* fourth column */ + + /* Verify dimensions */ + mu_assert("matmul result should have d1=3", Z->d1 == 3); + mu_assert("matmul result should have d2=4", Z->d2 == 4); + mu_assert("matmul result should have size=12", Z->size == 12); + + /* Verify values */ + mu_assert("Matmul forward pass test failed", + cmp_double_array(Z->value, expected, 12)); + + free_expr(Z); + return 0; +} diff --git a/tests/jacobian_tests/test_matmul.h b/tests/jacobian_tests/test_matmul.h new file mode 100644 index 0000000..16a23fc --- /dev/null +++ b/tests/jacobian_tests/test_matmul.h @@ -0,0 +1,89 @@ +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_jacobian_matmul() +{ + /* Test: Z = X @ Y where X is 2x3, Y is 3x4 + * var = (X, Y) where X starts at index 0 (size 6), Y starts at index 6 (size 12) + * Total n_vars = 18 + * Z is 2x4 (size 8) + * + * Each row of Jacobian has 6 nonzeros (3 from X, 3 from Y) + * Total nnz = 8 * 6 = 48 + */ + + int m = 2, k = 3, n = 4; + int x_size = m * k; // 6 + int y_size = k * n; // 12 + int z_size = m * n; // 8 + int n_vars = x_size + y_size; // 18 + + /* X values (2x3) in column-major order */ + double x_vals[6] = {1.0, 2.0, /* col 0 */ + 3.0, 4.0, /* col 1 */ + 5.0, 6.0}; /* col 2 */ + + /* Y values (3x4) in column-major order */ + double y_vals[12] = {7.0, 8.0, 9.0, /* col 0 */ + 10.0, 11.0, 12.0, /* col 1 */ + 13.0, 14.0, 15.0, /* col 2 */ + 16.0, 17.0, 18.0}; /* col 3 */ + + /* Combined parameter vector */ + double u_vals[18]; + for (int i = 0; i < 6; i++) u_vals[i] = x_vals[i]; + for (int i = 0; i < 12; i++) u_vals[6 + i] = y_vals[i]; + + /* Create variables */ + expr *X = new_variable(m, k, 0, n_vars); + expr *Y = new_variable(k, n, x_size, n_vars); + expr *Z = new_matmul(X, Y); + + /* Forward pass and jacobian initialization */ + Z->forward(Z, u_vals); + Z->jacobian_init(Z); + Z->eval_jacobian(Z); + + /* Verify sparsity pattern */ + mu_assert("Jacobian should have 8 rows", Z->jacobian->m == z_size); + mu_assert("Jacobian should have 18 columns", Z->jacobian->n == n_vars); + mu_assert("Jacobian should have 48 nonzeros", Z->jacobian->nnz == 48); + + /* Check row pointers: each row should have 6 entries */ + int expected_p[9] = {0, 6, 12, 18, 24, 30, 36, 42, 48}; + mu_assert("Row pointers incorrect", + cmp_int_array(Z->jacobian->p, expected_p, 9)); + + int expected_i[48] = {0, 2, 4, 6, 7, 8, /* row 0 */ + 1, 3, 5, 6, 7, 8, /* row 1 */ + 0, 2, 4, 9, 10, 11, /* row 2 */ + 1, 3, 5, 9, 10, 11, /* row 3 */ + 0, 2, 4, 12, 13, 14, /* row 4 */ + 1, 3, 5, 12, 13, 14, /* row 5 */ + 0, 2, 4, 15, 16, 17, /* row 6 */ + 1, 3, 5, 15, 16, 17}; /* row 7 */ + mu_assert("Column indices incorrect", + cmp_int_array(Z->jacobian->i, expected_i, 48)); + + /* Verify Jacobian values row-wise: for each row, values are + [Y^T row for the column, X row values] since X has lower var_id */ + double expected_x[48] = { + /* row 0 (col 0) */ 7.0, 8.0, 9.0, 1.0, 3.0, 5.0, + /* row 1 (col 0) */ 7.0, 8.0, 9.0, 2.0, 4.0, 6.0, + /* row 2 (col 1) */ 10.0, 11.0, 12.0, 1.0, 3.0, 5.0, + /* row 3 (col 1) */ 10.0, 11.0, 12.0, 2.0, 4.0, 6.0, + /* row 4 (col 2) */ 13.0, 14.0, 15.0, 1.0, 3.0, 5.0, + /* row 5 (col 2) */ 13.0, 14.0, 15.0, 2.0, 4.0, 6.0, + /* row 6 (col 3) */ 16.0, 17.0, 18.0, 1.0, 3.0, 5.0, + /* row 7 (col 3) */ 16.0, 17.0, 18.0, 2.0, 4.0, 6.0}; + + mu_assert("Jacobian values incorrect", + cmp_double_array(Z->jacobian->x, expected_x, 48)); + + free_expr(Z); + return 0; +} diff --git a/tests/utils/test_blas.h b/tests/utils/test_blas.h new file mode 100644 index 0000000..52e07e7 --- /dev/null +++ b/tests/utils/test_blas.h @@ -0,0 +1,64 @@ +#ifndef TEST_BLAS_H +#define TEST_BLAS_H + +#include "../minunit.h" +#include "blas.h" +#include "utils/blas_wrappers.h" +#include +#include + +/* BLAS dgemm prototype: C = alpha*A*B + beta*C */ +extern void BLAS(gemm)(const char *transa, const char *transb, blas_int *m, + blas_int *n, blas_int *k, const double *alpha, + const double *a, blas_int *lda, const double *b, + blas_int *ldb, const double *beta, double *c, blas_int *ldc); + +static const char *test_matrix_matrix_mult(void) +{ + /* Test: C = A * B where + * A is 3x2: [1 4] B is 2x3: [7 10 13] + * [2 5] [8 11 14] + * [3 6] + * + * Expected C (3x3): [39 54 69] + * [54 75 96] + * [69 96 123] + */ + + /* Matrices stored in column-major order */ + double A[6] = {1.0, 2.0, 3.0, /* first column */ + 4.0, 5.0, 6.0}; /* second column */ + + double B[6] = {7.0, 8.0, /* first column */ + 10.0, 11.0, /* second column */ + 13.0, 14.0}; /* third column */ + + double C[9] = {0.0, 0.0, 0.0, /* initialize to zero */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + /* Expected result (column-major) */ + double expected[9] = {39.0, 54.0, 69.0, /* first column */ + 54.0, 75.0, 96.0, /* second column */ + 69.0, 96.0, 123.0}; /* third column */ + + /* Compute C = A * B using wrapper */ + /* A is 3x2, B is 2x3, C is 3x3 */ + mat_mat_mult(A, B, C, 3, 2, 3); + + /* Verify result */ + double tol = 1e-10; + for (int i = 0; i < 9; i++) + { + double error = fabs(C[i] - expected[i]); + if (error > tol) + { + printf("Error at index %d: got %f, expected %f (diff: %e)\n", i, C[i], + expected[i], error); + return "Matrix multiplication result incorrect"; + } + } + + return NULL; /* Test passed */ +} + +#endif /* TEST_BLAS_H */ diff --git a/tests/wsum_hess/test_matmul.h b/tests/wsum_hess/test_matmul.h new file mode 100644 index 0000000..a6b74df --- /dev/null +++ b/tests/wsum_hess/test_matmul.h @@ -0,0 +1,217 @@ +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" +#include +#include + +const char *test_wsum_hess_matmul() +{ + /* Test: Z = X @ Y where X is 2x3, Y is 3x4 + * var = (X, Y) where X starts at index 0 (size 6), Y starts at index 6 (size 12) + * Total n_vars = 18 + * Z is 2x4 (size 8) + */ + + int m = 2, k = 3, n = 4; + int x_size = m * k; // 6 + int y_size = k * n; // 12 + int n_vars = x_size + y_size; // 18 + + /* X values (2x3) in column-major order */ + double x_vals[6] = {1.0, 2.0, /* col 0 */ + 3.0, 4.0, /* col 1 */ + 5.0, 6.0}; /* col 2 */ + + /* Y values (3x4) in column-major order */ + double y_vals[12] = {7.0, 8.0, 9.0, /* col 0 */ + 10.0, 11.0, 12.0, /* col 1 */ + 13.0, 14.0, 15.0, /* col 2 */ + 16.0, 17.0, 18.0}; /* col 3 */ + + /* Combined parameter vector */ + double u_vals[18]; + for (int i = 0; i < 6; i++) u_vals[i] = x_vals[i]; + for (int i = 0; i < 12; i++) u_vals[6 + i] = y_vals[i]; + + /* Weights for the 8 outputs */ + double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + + /* Create variables */ + expr *X = new_variable(m, k, 0, n_vars); + expr *Y = new_variable(k, n, x_size, n_vars); + expr *Z = new_matmul(X, Y); + + /* Forward pass and Hessian initialization */ + Z->forward(Z, u_vals); + Z->wsum_hess_init(Z); + Z->eval_wsum_hess(Z, w); + + /* Verify Hessian dimensions and sparsity */ + mu_assert("Hessian should be 18x18", Z->wsum_hess->m == n_vars); + mu_assert("Hessian should be 18x18", Z->wsum_hess->n == n_vars); + mu_assert("Hessian should have 48 nonzeros", Z->wsum_hess->nnz == 48); + + int expected_p[19] = {0, 4, 8, 12, 16, 20, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48}; + + mu_assert("Row pointers incorrect", + cmp_int_array(Z->wsum_hess->p, expected_p, 19)); + + int expected_i[48] = {6, 9, 12, 15, /* row 0 */ + 6, 9, 12, 15, /* row 1 */ + 7, 10, 13, 16, /* row 2 */ + 7, 10, 13, 16, /* row 3 */ + 8, 11, 14, 17, /* row 4 */ + 8, 11, 14, 17, /* row 5 */ + 0, 1, /* row 6*/ + 2, 3, /* row 7*/ + 4, 5, /* row 8*/ + 0, 1, /* row 9*/ + 2, 3, /* row 10*/ + 4, 5, /* row 11*/ + 0, 1, /* row 12*/ + 2, 3, /* row 13*/ + 4, 5, /* row 14*/ + 0, 1, /* row 15*/ + 2, 3, /* row 16*/ + 4, 5}; + + mu_assert("Column indices incorrect", + cmp_int_array(Z->wsum_hess->i, expected_i, 48)); + + double expected_x[48] = {1.0, 3.0, 5.0, 7.0, /* row 0 */ + 2.0, 4.0, 6.0, 8.0, /* row 1 */ + 1.0, 3.0, 5.0, 7.0, /* row 2 */ + 2.0, 4.0, 6.0, 8.0, /* row 3 */ + 1.0, 3.0, 5.0, 7.0, /* row 4 */ + 2.0, 4.0, 6.0, 8.0, /* row 5 */ + 1.0, 2.0, /* row 6 */ + 1.0, 2.0, /* row 7 */ + 1.0, 2.0, /* row 8 */ + 3.0, 4.0, /* row 9 */ + 3.0, 4.0, /* row 10 */ + 3.0, 4.0, /* row 11 */ + 5.0, 6.0, /* row 12 */ + 5.0, 6.0, /* row 13 */ + 5.0, 6.0, /* row 14 */ + 7.0, 8.0, /* row 15 */ + 7.0, 8.0, /* row 16 */ + 7.0, 8.0}; /* row 17 */ + + mu_assert("Hessian values incorrect", + cmp_double_array(Z->wsum_hess->x, expected_x, 48)); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_matmul_yx() +{ + /* Test: Z = X @ Y where X is 2x3, Y is 3x4 + * var = (Y, X) where Y starts at index 0 (size 12), X starts at index 12 (size + * 6) Total n_vars = 18 Z is 2x4 (size 8) + */ + + int m = 2, k = 3, n = 4; + int y_size = k * n; // 12 + int x_size = m * k; // 6 + int n_vars = x_size + y_size; // 18 + + /* X values (2x3) in column-major order */ + double x_vals[6] = {1.0, 2.0, /* col 0 */ + 3.0, 4.0, /* col 1 */ + 5.0, 6.0}; /* col 2 */ + + /* Y values (3x4) in column-major order */ + double y_vals[12] = {7.0, 8.0, 9.0, /* col 0 */ + 10.0, 11.0, 12.0, /* col 1 */ + 13.0, 14.0, 15.0, /* col 2 */ + 16.0, 17.0, 18.0}; /* col 3 */ + + /* Combined parameter vector: Y first, then X */ + double u_vals[18]; + for (int i = 0; i < 12; i++) u_vals[i] = y_vals[i]; + for (int i = 0; i < 6; i++) u_vals[12 + i] = x_vals[i]; + + /* Weights for the 8 outputs */ + double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + + /* Create variables with Y first */ + expr *Y = new_variable(k, n, 0, n_vars); + expr *X = new_variable(m, k, y_size, n_vars); + expr *Z = new_matmul(X, Y); + + /* Forward pass and Hessian initialization */ + Z->forward(Z, u_vals); + Z->wsum_hess_init(Z); + Z->eval_wsum_hess(Z, w); + + /* Verify Hessian dimensions and sparsity */ + mu_assert("Hessian should be 18x18", Z->wsum_hess->m == n_vars); + mu_assert("Hessian should be 18x18", Z->wsum_hess->n == n_vars); + mu_assert("Hessian should have 48 nonzeros", Z->wsum_hess->nnz == 48); + + /* Row pointers when Y < X: + * Rows 0-11 (Y variables): each couples with m=2 X variables + * Rows 12-17 (X variables): each couples with n=4 Y variables + */ + int expected_p[19] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, + 20, 22, 24, 28, 32, 36, 40, 44, 48}; + + mu_assert("Row pointers incorrect", + cmp_int_array(Z->wsum_hess->p, expected_p, 19)); + + /* Column indices when Y < X: + * Y[k_idx, col] couples with X[row, k_idx] for all row + * X variable index = 12 + row + k_idx*m + * X[row, k_idx] couples with Y[k_idx, col] for all col + * Y variable index = 0 + k_idx + col*k + */ + int expected_i[48] = {12, 13, /* row 0: Y[0,0] */ + 14, 15, /* row 1: Y[1,0] */ + 16, 17, /* row 2: Y[2,0] */ + 12, 13, /* row 3: Y[0,1] */ + 14, 15, /* row 4: Y[1,1] */ + 16, 17, /* row 5: Y[2,1] */ + 12, 13, /* row 6: Y[0,2] */ + 14, 15, /* row 7: Y[1,2] */ + 16, 17, /* row 8: Y[2,2] */ + 12, 13, /* row 9: Y[0,3] */ + 14, 15, /* row 10: Y[1,3] */ + 16, 17, /* row 11: Y[2,3] */ + 0, 3, 6, 9, /* row 12: X[0,0] */ + 0, 3, 6, 9, /* row 13: X[1,0] */ + 1, 4, 7, 10, /* row 14: X[0,1] */ + 1, 4, 7, 10, /* row 15: X[1,1] */ + 2, 5, 8, 11, /* row 16: X[0,2] */ + 2, 5, 8, 11}; /* row 17: X[1,2] */ + + mu_assert("Column indices incorrect", + cmp_int_array(Z->wsum_hess->i, expected_i, 48)); + + double expected_x[48] = {1.0, 2.0, /* row 0 */ + 1.0, 2.0, /* row 1 */ + 1.0, 2.0, /* row 2 */ + 3.0, 4.0, /* row 3 */ + 3.0, 4.0, /* row 4 */ + 3.0, 4.0, /* row 5 */ + 5.0, 6.0, /* row 6 */ + 5.0, 6.0, /* row 7 */ + 5.0, 6.0, /* row 8 */ + 7.0, 8.0, /* row 9 */ + 7.0, 8.0, /* row 10 */ + 7.0, 8.0, /* row 11 */ + 1.0, 3.0, 5.0, 7.0, /* row 12 */ + 2.0, 4.0, 6.0, 8.0, /* row 13 */ + 1.0, 3.0, 5.0, 7.0, /* row 14 */ + 2.0, 4.0, 6.0, 8.0, /* row 15 */ + 1.0, 3.0, 5.0, 7.0, /* row 16 */ + 2.0, 4.0, 6.0, 8.0}; /* row 17 */ + + mu_assert("Hessian values incorrect", + cmp_double_array(Z->wsum_hess->x, expected_x, 48)); + + free_expr(Z); + return 0; +} From 244e7464d1e5985b527393c763c34036d1236ee2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 18 Jan 2026 14:40:26 -0800 Subject: [PATCH 2/6] removed blas for now --- BLAS_README.md | 109 ---------------------------------- CMakeLists.txt | 61 +++++++++++-------- include/utils/blas_wrappers.h | 39 ------------ include/utils/mini_numpy.h | 3 + pyproject.toml | 4 +- python/atoms/matmul.h | 33 ++++++++++ python/bindings.c | 3 + src/bivariate/matmul.c | 13 ++-- src/utils/blas_wrappers.c | 86 --------------------------- src/utils/mini_numpy.c | 15 +++++ tests/all_tests.c | 6 +- tests/utils/test_blas.h | 64 -------------------- 12 files changed, 98 insertions(+), 338 deletions(-) delete mode 100644 BLAS_README.md delete mode 100644 include/utils/blas_wrappers.h create mode 100644 python/atoms/matmul.h delete mode 100644 src/utils/blas_wrappers.c delete mode 100644 tests/utils/test_blas.h diff --git a/BLAS_README.md b/BLAS_README.md deleted file mode 100644 index 6307d00..0000000 --- a/BLAS_README.md +++ /dev/null @@ -1,109 +0,0 @@ -# BLAS Integration - -This project **requires** BLAS/LAPACK libraries for optimized linear algebra operations. - -## Building - -BLAS and LAPACK are **mandatory dependencies**. CMake will search for them and fail if not found. - -### Standard Build - -```bash -mkdir -p build -cd build -cmake .. -make -``` - -## Installing BLAS/LAPACK (Required) - -You must install BLAS and LAPACK before building: - -**Ubuntu/Debian:** -```bash -sudo apt-get install libblas-dev liblapack-dev -``` - -**macOS:** -```bash -brew install openblas lapack -``` - -**Alternative Libraries:** -- **OpenBLAS** (recommended): Fast, open-source BLAS -- **Intel MKL**: Optimized for Intel processors -- **Apple Accelerate**: Built-in on macOS -- **ATLAS**: Automatically tuned - -## Configuration Options - -```bash -# Use 64-bit integers for BLAS (if your BLAS library uses 64-bit ints) -cmake -DBLAS_64=ON .. - -# If BLAS functions have no underscore suffix -cmake -DNO_BLAS_SUFFIX=ON .. -``` - -## Using BLAS in Code - -Include the wrapper header: - -```c -#include "utils/blas_wrappers.h" - -// These use optimized BLAS routines -double norm = vec_norm2(x, n); -double dot_product = vec_dot(x, y, n); -vec_axpy(2.0, x, y, n); // y += 2.0 * x -vec_scale(0.5, x, n); // x *= 0.5 -mat_vec_mult(A, x, y, m, n); // y = A*x -``` - -## Direct BLAS Usage - -For advanced usage, include the BLAS header directly: - -```c -#include "blas.h" - -// Declare the BLAS function you need -extern double BLAS(nrm2)(blas_int *n, const double *x, blas_int *incx); - -void my_function(double *x, int n) { - blas_int bn = (blas_int)n; - blas_int inc = 1; - double norm = BLAS(nrm2)(&bn, x, &inc); -} -``` - -The `BLAS(name)` macro handles function name mangling (e.g., `BLAS(nrm2)` → `dnrm2_`). - -## Available BLAS Functions - -**Level 1 (vector-vector):** -- `BLAS(nrm2)` - L2 norm -- `BLAS(dot)` - dot product -- `BLAS(axpy)` - y = alpha*x + y -- `BLAS(scal)` - x = alpha*x -- `BLASI(amax)` - index of max absolute value - -**Level 2 (matrix-vector):** -- `BLAS(gemv)` - matrix-vector multiply - -**Level 3 (matrix-matrix):** -- `BLAS(gemm)` - matrix-matrix multiply - -**LAPACK:** -- `BLAS(gesv)` - solve linear system -- `BLAS(syevr)` - symmetric eigenvalue decomposition -- Many more... - -## Performance Notes - -BLAS provides significant speedups for: -- Large vector operations (n > 1000) -- Matrix-vector products -- Matrix-matrix products - -For small vectors (n < 100), the overhead may outweigh the benefits. diff --git a/CMakeLists.txt b/CMakeLists.txt index 90d3289..d8503ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,16 @@ + cmake_minimum_required(VERSION 3.15) project(DNLP_Diff_Engine C) + +# Prefer OpenBLAS unless user overrides BLA_VENDOR +# if(NOT DEFINED BLA_VENDOR) +# set(BLA_VENDOR "OpenBLAS" CACHE STRING "Preferred BLAS vendor" FORCE) +# message(STATUS "BLA_VENDOR not set, defaulting to OpenBLAS.") +# else() +# message(STATUS "BLA_VENDOR set to ${BLA_VENDOR}") +# endif() + set(CMAKE_C_STANDARD 99) #Set default build type to Release if not specified @@ -12,31 +22,32 @@ else() message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") endif() + # ============================================================================= -# BLAS/LAPACK Configuration (REQUIRED) +# BLAS/LAPACK Configuration (DISABLED) # ============================================================================= -find_package(BLAS REQUIRED) -find_package(LAPACK REQUIRED) - -message(STATUS "Found BLAS: ${BLAS_LIBRARIES}") -message(STATUS "Found LAPACK: ${LAPACK_LIBRARIES}") - -set(BLAS_LIBRARIES ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) -add_compile_definitions(USE_BLAS) - -# Optional BLAS configuration flags -option(BLAS_64 "Use 64-bit integers for BLAS" OFF) -option(NO_BLAS_SUFFIX "BLAS functions have no suffix" OFF) - -if(BLAS_64) - add_compile_definitions(BLAS_64) - message(STATUS "Using 64-bit BLAS integers") -endif() - -if(NO_BLAS_SUFFIX) - add_compile_definitions(NO_BLAS_SUFFIX) - message(STATUS "Using BLAS with no function suffix") -endif() +# find_package(BLAS REQUIRED) +# find_package(LAPACK REQUIRED) +# +# message(STATUS "Found BLAS: ${BLAS_LIBRARIES}") +# message(STATUS "Found LAPACK: ${LAPACK_LIBRARIES}") +# +# set(BLAS_LIBRARIES ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) +# add_compile_definitions(USE_BLAS) +# +# # Optional BLAS configuration flags +# option(BLAS_64 "Use 64-bit integers for BLAS" OFF) +# option(NO_BLAS_SUFFIX "BLAS functions have no suffix" OFF) +# +# if(BLAS_64) +# add_compile_definitions(BLAS_64) +# message(STATUS "Using 64-bit BLAS integers") +# endif() +# +# if(NO_BLAS_SUFFIX) +# add_compile_definitions(NO_BLAS_SUFFIX) +# message(STATUS "Using BLAS with no function suffix") +# endif() # Warning flags (always enabled) add_compile_options( @@ -58,9 +69,11 @@ include_directories(${PROJECT_SOURCE_DIR}/include) # Source files - automatically gather all .c files from src/ file(GLOB_RECURSE SOURCES "src/*.c") + # Create core library add_library(dnlp_diff ${SOURCES}) -target_link_libraries(dnlp_diff m ${BLAS_LIBRARIES}) +# target_link_libraries(dnlp_diff m ${BLAS_LIBRARIES}) +target_link_libraries(dnlp_diff m) # Config-specific compile options target_compile_options(dnlp_diff PRIVATE diff --git a/include/utils/blas_wrappers.h b/include/utils/blas_wrappers.h deleted file mode 100644 index 6abc5c9..0000000 --- a/include/utils/blas_wrappers.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef BLAS_WRAPPERS_H_GUARD -#define BLAS_WRAPPERS_H_GUARD - -#ifdef __cplusplus -extern "C" -{ -#endif - -#include "blas.h" -#include -#include - - /* - * BLAS wrapper functions for common linear algebra operations. - * - * When compiled with USE_BLAS, these use optimized BLAS routines. - * Otherwise, they fall back to simple C implementations. - */ - - /* Vector operations */ - double vec_norm2(const double *x, int n); - double vec_dot(const double *x, const double *y, int n); - void vec_axpy(double alpha, const double *x, double *y, - int n); /* y += alpha*x */ - void vec_scale(double alpha, double *x, int n); /* x *= alpha */ - - /* Matrix-vector operations */ - void mat_vec_mult(const double *A, const double *x, double *y, int m, - int n); /* y = A*x, A is m x n */ - - /* Matrix-matrix operations */ - void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, - int n); /* Z = X*Y, X is m x k, Y is k x n, Z is m x n */ - -#ifdef __cplusplus -} -#endif - -#endif /* BLAS_WRAPPERS_H_GUARD */ diff --git a/include/utils/mini_numpy.h b/include/utils/mini_numpy.h index 1710978..b29c673 100644 --- a/include/utils/mini_numpy.h +++ b/include/utils/mini_numpy.h @@ -20,4 +20,7 @@ void tile_int(int *result, const int *a, int len, int tiles); */ void scaled_ones(double *result, int size, double value); +/* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */ +void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n); + #endif /* MINI_NUMPY_H */ diff --git a/pyproject.toml b/pyproject.toml index 37c735f..e798591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "Low-level C autodiff engine for nonlinear optimization" readme = "README.md" requires-python = ">=3.10" dependencies = [ - "numpy", + "numpy" ] [project.optional-dependencies] @@ -29,7 +29,7 @@ wheel.packages = ["src/dnlp_diff_engine"] build-dir = "build/{wheel_tag}" [tool.scikit-build.cmake.define] -CMAKE_BUILD_TYPE = "Debug" +CMAKE_BUILD_TYPE = "Release" [tool.pytest.ini_options] testpaths = ["python/tests"] diff --git a/python/atoms/matmul.h b/python/atoms/matmul.h new file mode 100644 index 0000000..9bfe859 --- /dev/null +++ b/python/atoms/matmul.h @@ -0,0 +1,33 @@ +#ifndef ATOM_MATMUL_H +#define ATOM_MATMUL_H + +#include "bivariate.h" +#include "common.h" + +/* Matrix multiplication: Z = X @ Y */ +static PyObject *py_make_matmul(PyObject *self, PyObject *args) +{ + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *) PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *) PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_matmul(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create matmul node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_MATMUL_H */ diff --git a/python/bindings.c b/python/bindings.c index 0cd703e..3f46eae 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -19,6 +19,7 @@ #include "atoms/linear.h" #include "atoms/log.h" #include "atoms/logistic.h" +#include "atoms/matmul.h" #include "atoms/multiply.h" #include "atoms/neg.h" #include "atoms/power.h" @@ -73,6 +74,8 @@ static PyMethodDef DNLPMethods[] = { {"make_promote", py_make_promote, METH_VARARGS, "Create promote node"}, {"make_multiply", py_make_multiply, METH_VARARGS, "Create elementwise multiply node"}, + {"make_matmul", py_make_matmul, METH_VARARGS, + "Create matrix multiplication node (Z = X @ Y)"}, {"make_const_scalar_mult", py_make_const_scalar_mult, METH_VARARGS, "Create constant scalar multiplication node (a * f(x))"}, {"make_const_vector_mult", py_make_const_vector_mult, METH_VARARGS, diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c index 2a38127..0b946cb 100644 --- a/src/bivariate/matmul.c +++ b/src/bivariate/matmul.c @@ -1,6 +1,5 @@ #include "bivariate.h" #include "subexpr.h" -#include "utils/blas_wrappers.h" #include "utils/mini_numpy.h" #include #include @@ -246,8 +245,7 @@ static void eval_wsum_hess(expr *node, const double *w) if (x->var_id < y->var_id) { - /* X variable rows: For X[row, k_idx], couples with Y[k_idx, col] for all col - */ + /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) @@ -259,8 +257,7 @@ static void eval_wsum_hess(expr *node, const double *w) } } - /* Y variable rows: For Y[k_idx, col], couples with X[row, k_idx] for all row - */ + /* rows corresponding to y */ for (int col = 0; col < n; col++) { for (int k_idx = 0; k_idx < k; k_idx++) @@ -272,8 +269,7 @@ static void eval_wsum_hess(expr *node, const double *w) } else { - /* Y variable rows come first: For Y[k_idx, col], couples with X[row, k_idx] - */ + /* rows corresponding to y */ for (int col = 0; col < n; col++) { for (int k_idx = 0; k_idx < k; k_idx++) @@ -283,8 +279,7 @@ static void eval_wsum_hess(expr *node, const double *w) } } - /* X variable rows: For X[row, k_idx], couples with Y[k_idx, col] for all col - */ + /* rows corresponding to x */ for (int k_idx = 0; k_idx < k; k_idx++) { for (int row = 0; row < m; row++) diff --git a/src/utils/blas_wrappers.c b/src/utils/blas_wrappers.c deleted file mode 100644 index 8608fc4..0000000 --- a/src/utils/blas_wrappers.c +++ /dev/null @@ -1,86 +0,0 @@ -#include "utils/blas_wrappers.h" - -/* BLAS function prototypes */ -#ifdef __cplusplus -extern "C" -{ -#endif - - double BLAS(nrm2)(blas_int *n, const double *x, blas_int *incx); - double BLAS(dot)(const blas_int *n, const double *x, const blas_int *incx, - const double *y, const blas_int *incy); - void BLAS(axpy)(blas_int *n, const double *alpha, const double *x, - blas_int *incx, double *y, blas_int *incy); - void BLAS(scal)(const blas_int *n, const double *alpha, double *x, - const blas_int *incx); - void BLAS(gemv)(const char *trans, const blas_int *m, const blas_int *n, - const double *alpha, const double *a, const blas_int *lda, - const double *x, const blas_int *incx, const double *beta, - double *y, const blas_int *incy); - void BLAS(gemm)(const char *transa, const char *transb, const blas_int *m, - const blas_int *n, const blas_int *k, const double *alpha, - const double *a, const blas_int *lda, const double *b, - const blas_int *ldb, const double *beta, double *c, - const blas_int *ldc); - -#ifdef __cplusplus -} -#endif - -/* BLAS implementations */ - -double vec_norm2(const double *x, int n) -{ - blas_int bn = (blas_int) n; - blas_int inc = 1; - return BLAS(nrm2)(&bn, x, &inc); -} - -double vec_dot(const double *x, const double *y, int n) -{ - blas_int bn = (blas_int) n; - blas_int inc = 1; - return BLAS(dot)(&bn, x, &inc, y, &inc); -} - -void vec_axpy(double alpha, const double *x, double *y, int n) -{ - blas_int bn = (blas_int) n; - blas_int inc = 1; - BLAS(axpy)(&bn, &alpha, x, &inc, y, &inc); -} - -void vec_scale(double alpha, double *x, int n) -{ - blas_int bn = (blas_int) n; - blas_int inc = 1; - BLAS(scal)(&bn, &alpha, x, &inc); -} - -void mat_vec_mult(const double *A, const double *x, double *y, int m, int n) -{ - /* y = A*x where A is m x n (column-major) */ - blas_int bm = (blas_int) m; - blas_int bn = (blas_int) n; - blas_int inc = 1; - double alpha = 1.0; - double beta = 0.0; - char trans = 'N'; - - BLAS(gemv)(&trans, &bm, &bn, &alpha, A, &bm, x, &inc, &beta, y, &inc); -} - -void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n) -{ - /* Z = X*Y where X is m x k, Y is k x n, Z is m x n (all column-major) */ - blas_int bm = (blas_int) m; - blas_int bk = (blas_int) k; - blas_int bn = (blas_int) n; - double alpha = 1.0; - double beta = 0.0; - char transX = 'N'; - char transY = 'N'; - - BLAS(gemm)(&transX, &transY, &bm, &bn, &bk, &alpha, X, &bm, Y, &bk, &beta, Z, - &bm); -} diff --git a/src/utils/mini_numpy.c b/src/utils/mini_numpy.c index 7dbb02f..fa78b5d 100644 --- a/src/utils/mini_numpy.c +++ b/src/utils/mini_numpy.c @@ -36,3 +36,18 @@ void scaled_ones(double *result, int size, double value) result[i] = value; } } + +void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n) +{ + for (int j = 0; j < n; ++j) + { + for (int i = 0; i < m; ++i) + { + Z[i + j * m] = 0.0; + for (int l = 0; l < k; ++l) + { + Z[i + j * m] += X[i + l * m] * Y[l + j * k]; + } + } + } +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 5d72ceb..8756b13 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -41,7 +41,6 @@ #include "jacobian_tests/test_sum.h" #include "jacobian_tests/test_trace.h" #include "problem/test_problem.h" -#include "utils/test_blas.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" #include "wsum_hess/elementwise/test_entr.h" @@ -258,10 +257,7 @@ int main(void) mu_run_test(test_problem_constraint_forward, tests_run); mu_run_test(test_problem_hessian, tests_run); - printf("\n--- BLAS Tests ---\n"); - mu_run_test(test_matrix_matrix_mult, tests_run); - - printf("\n=== All %d tests passed ===\n", tests_run); + printf("\n=== All %d tests passed ===\n", tests_run); return 0; } diff --git a/tests/utils/test_blas.h b/tests/utils/test_blas.h deleted file mode 100644 index 52e07e7..0000000 --- a/tests/utils/test_blas.h +++ /dev/null @@ -1,64 +0,0 @@ -#ifndef TEST_BLAS_H -#define TEST_BLAS_H - -#include "../minunit.h" -#include "blas.h" -#include "utils/blas_wrappers.h" -#include -#include - -/* BLAS dgemm prototype: C = alpha*A*B + beta*C */ -extern void BLAS(gemm)(const char *transa, const char *transb, blas_int *m, - blas_int *n, blas_int *k, const double *alpha, - const double *a, blas_int *lda, const double *b, - blas_int *ldb, const double *beta, double *c, blas_int *ldc); - -static const char *test_matrix_matrix_mult(void) -{ - /* Test: C = A * B where - * A is 3x2: [1 4] B is 2x3: [7 10 13] - * [2 5] [8 11 14] - * [3 6] - * - * Expected C (3x3): [39 54 69] - * [54 75 96] - * [69 96 123] - */ - - /* Matrices stored in column-major order */ - double A[6] = {1.0, 2.0, 3.0, /* first column */ - 4.0, 5.0, 6.0}; /* second column */ - - double B[6] = {7.0, 8.0, /* first column */ - 10.0, 11.0, /* second column */ - 13.0, 14.0}; /* third column */ - - double C[9] = {0.0, 0.0, 0.0, /* initialize to zero */ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; - - /* Expected result (column-major) */ - double expected[9] = {39.0, 54.0, 69.0, /* first column */ - 54.0, 75.0, 96.0, /* second column */ - 69.0, 96.0, 123.0}; /* third column */ - - /* Compute C = A * B using wrapper */ - /* A is 3x2, B is 2x3, C is 3x3 */ - mat_mat_mult(A, B, C, 3, 2, 3); - - /* Verify result */ - double tol = 1e-10; - for (int i = 0; i < 9; i++) - { - double error = fabs(C[i] - expected[i]); - if (error > tol) - { - printf("Error at index %d: got %f, expected %f (diff: %e)\n", i, C[i], - expected[i], error); - return "Matrix multiplication result incorrect"; - } - } - - return NULL; /* Test passed */ -} - -#endif /* TEST_BLAS_H */ From 66ebb8608a368963881cb9c5ee404e1cb5c8e2d5 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 19 Jan 2026 05:23:50 -0800 Subject: [PATCH 3/6] minor to matmul --- pyproject.toml | 2 +- src/bivariate/matmul.c | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e798591..65d92c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ wheel.packages = ["src/dnlp_diff_engine"] build-dir = "build/{wheel_tag}" [tool.scikit-build.cmake.define] -CMAKE_BUILD_TYPE = "Release" +CMAKE_BUILD_TYPE = "Debug" [tool.pytest.ini_options] testpaths = ["python/tests"] diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c index 0b946cb..f1ad432 100644 --- a/src/bivariate/matmul.c +++ b/src/bivariate/matmul.c @@ -114,10 +114,10 @@ static void eval_jacobian(expr *node) if (x->var_id < y->var_id) { - /* Y^T contribution: contiguous column of Y at 'col' */ + /* contribution to this row from Y^T */ memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); - /* X row contribution: stride m across columns */ + /* contribution to this row from X */ for (int j = 0; j < k; j++) { Jx[pos + k + j] = x->value[row + j * m]; @@ -125,13 +125,13 @@ static void eval_jacobian(expr *node) } else { - /* X row contribution: stride m across columns */ + /* contribution to this row from X */ for (int j = 0; j < k; j++) { Jx[pos + j] = x->value[row + j * m]; } - /* Y^T contribution: contiguous column of Y at 'col' */ + /* contribution to this row from Y^T */ memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } } @@ -242,6 +242,7 @@ static void eval_wsum_hess(expr *node, const double *w) int offset = 0; double *Hx = node->wsum_hess->x; + const double *w_temp; if (x->var_id < y->var_id) { @@ -260,9 +261,10 @@ static void eval_wsum_hess(expr *node, const double *w) /* rows corresponding to y */ for (int col = 0; col < n; col++) { + w_temp = w + col * m; for (int k_idx = 0; k_idx < k; k_idx++) { - memcpy(Hx + offset, w + col * m, m * sizeof(double)); + memcpy(Hx + offset, w_temp, m * sizeof(double)); offset += m; } } @@ -272,9 +274,10 @@ static void eval_wsum_hess(expr *node, const double *w) /* rows corresponding to y */ for (int col = 0; col < n; col++) { + w_temp = w + col * m; for (int k_idx = 0; k_idx < k; k_idx++) { - memcpy(Hx + offset, w + col * m, m * sizeof(double)); + memcpy(Hx + offset, w_temp, m * sizeof(double)); offset += m; } } From 26ba23cffec359f4c5be4569d7d75b9013f27bfe Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 19 Jan 2026 05:27:29 -0800 Subject: [PATCH 4/6] removed blas file --- CMakeLists.txt | 28 ---------------------------- include/blas.h | 40 ---------------------------------------- pyproject.toml | 2 +- src/bivariate/matmul.c | 8 ++++---- 4 files changed, 5 insertions(+), 73 deletions(-) delete mode 100644 include/blas.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d8503ff..d729ab0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,33 +22,6 @@ else() message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") endif() - -# ============================================================================= -# BLAS/LAPACK Configuration (DISABLED) -# ============================================================================= -# find_package(BLAS REQUIRED) -# find_package(LAPACK REQUIRED) -# -# message(STATUS "Found BLAS: ${BLAS_LIBRARIES}") -# message(STATUS "Found LAPACK: ${LAPACK_LIBRARIES}") -# -# set(BLAS_LIBRARIES ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) -# add_compile_definitions(USE_BLAS) -# -# # Optional BLAS configuration flags -# option(BLAS_64 "Use 64-bit integers for BLAS" OFF) -# option(NO_BLAS_SUFFIX "BLAS functions have no suffix" OFF) -# -# if(BLAS_64) -# add_compile_definitions(BLAS_64) -# message(STATUS "Using 64-bit BLAS integers") -# endif() -# -# if(NO_BLAS_SUFFIX) -# add_compile_definitions(NO_BLAS_SUFFIX) -# message(STATUS "Using BLAS with no function suffix") -# endif() - # Warning flags (always enabled) add_compile_options( -Wall # Enable most warnings @@ -72,7 +45,6 @@ file(GLOB_RECURSE SOURCES "src/*.c") # Create core library add_library(dnlp_diff ${SOURCES}) -# target_link_libraries(dnlp_diff m ${BLAS_LIBRARIES}) target_link_libraries(dnlp_diff m) # Config-specific compile options diff --git a/include/blas.h b/include/blas.h deleted file mode 100644 index 11d78be..0000000 --- a/include/blas.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef DNLP_BLAS_H_GUARD -#define DNLP_BLAS_H_GUARD - -#ifdef __cplusplus -extern "C" -{ -#endif - -/* Default to underscore suffix for BLAS/LAPACK function names */ -#ifndef BLAS_SUFFIX -#define BLAS_SUFFIX _ -#endif - -/* Handle BLAS function name mangling */ -#if defined(NO_BLAS_SUFFIX) && NO_BLAS_SUFFIX > 0 -/* No suffix version */ -#define BLAS(x) d##x -#define BLASI(x) id##x -#else -/* Macro indirection needed for BLAS_SUFFIX to work correctly */ -#define _stitch(pre, x, post) pre##x##post -#define _stitch2(pre, x, post) _stitch(pre, x, post) -/* Add suffix (default: underscore) */ -#define BLAS(x) _stitch2(d, x, BLAS_SUFFIX) -#define BLASI(x) _stitch2(id, x, BLAS_SUFFIX) -#endif - -/* BLAS integer type */ -#ifdef BLAS_64 -#include - typedef int64_t blas_int; -#else -typedef int blas_int; -#endif - -#ifdef __cplusplus -} -#endif - -#endif /* DNLP_BLAS_H_GUARD */ diff --git a/pyproject.toml b/pyproject.toml index 65d92c3..37c735f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "Low-level C autodiff engine for nonlinear optimization" readme = "README.md" requires-python = ">=3.10" dependencies = [ - "numpy" + "numpy", ] [project.optional-dependencies] diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c index f1ad432..2168ca6 100644 --- a/src/bivariate/matmul.c +++ b/src/bivariate/matmul.c @@ -64,7 +64,7 @@ static void jacobian_init(expr *node) /* X has lower var_id */ if (x->var_id < y->var_id) { - /* sparsity pattern of kron(Y^T, I) for this row */ + /* sparsity pattern of kron(YT, I) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; @@ -84,7 +84,7 @@ static void jacobian_init(expr *node) node->jacobian->i[nnz_idx++] = y->var_id + col * k + j; } - /* sparsity pattern of kron(Y^T, I) for this row */ + /* sparsity pattern of kron(YT, I) for this row */ for (int j = 0; j < k; j++) { node->jacobian->i[nnz_idx++] = x->var_id + row + j * m; @@ -114,7 +114,7 @@ static void eval_jacobian(expr *node) if (x->var_id < y->var_id) { - /* contribution to this row from Y^T */ + /* contribution to this row from YT */ memcpy(Jx + pos, y->value + col * k, k * sizeof(double)); /* contribution to this row from X */ @@ -131,7 +131,7 @@ static void eval_jacobian(expr *node) Jx[pos + j] = x->value[row + j * m]; } - /* contribution to this row from Y^T */ + /* contribution to this row from YT */ memcpy(Jx + pos + k, y->value + col * k, k * sizeof(double)); } } From cbb8f687a218b728ccc2cae72a075c04724806f1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 19 Jan 2026 05:28:00 -0800 Subject: [PATCH 5/6] clean up blas --- CMakeLists.txt | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d729ab0..26c2cdd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,16 +1,6 @@ cmake_minimum_required(VERSION 3.15) project(DNLP_Diff_Engine C) - - -# Prefer OpenBLAS unless user overrides BLA_VENDOR -# if(NOT DEFINED BLA_VENDOR) -# set(BLA_VENDOR "OpenBLAS" CACHE STRING "Preferred BLAS vendor" FORCE) -# message(STATUS "BLA_VENDOR not set, defaulting to OpenBLAS.") -# else() -# message(STATUS "BLA_VENDOR set to ${BLA_VENDOR}") -# endif() - set(CMAKE_C_STANDARD 99) #Set default build type to Release if not specified From cfc7db070959e31985ee0e28796cda4cf180ab19 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 19 Jan 2026 05:29:27 -0800 Subject: [PATCH 6/6] minor --- CMakeLists.txt | 1 - src/bivariate/matmul.c | 12 ++---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 26c2cdd..f00ac25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,3 @@ - cmake_minimum_required(VERSION 3.15) project(DNLP_Diff_Engine C) set(CMAKE_C_STANDARD 99) diff --git a/src/bivariate/matmul.c b/src/bivariate/matmul.c index 2168ca6..617e539 100644 --- a/src/bivariate/matmul.c +++ b/src/bivariate/matmul.c @@ -21,16 +21,8 @@ static void forward(expr *node, const double *u) x->forward(x, u); y->forward(y, u); - /* local forward pass: Z = X @ Y - * X is m x k (d1 x d2) - * Y is k x n (d1 x d2) - * Z is m x n (d1 x d2) - */ - int m = x->d1; - int k = x->d2; - int n = y->d2; - - mat_mat_mult(x->value, y->value, node->value, m, k, n); + /* local forward pass */ + mat_mat_mult(x->value, y->value, node->value, x->d1, x->d2, y->d2); } static bool is_affine(const expr *node)