diff --git a/pyproject.toml b/pyproject.toml index 535ccdb..6e89dda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ build-dir = "build/{wheel_tag}" CMAKE_BUILD_TYPE = "Release" [tool.pytest.ini_options] -testpaths = ["tests/python"] +testpaths = ["python/tests"] python_files = ["test_*.py"] python_functions = ["test_*"] diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt deleted file mode 100644 index 6a71c2d..0000000 --- a/python/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -cmake_minimum_required(VERSION 3.20) -project(DNLP_diff_engine LANGUAGES C VERSION 0.1.0 - DESCRIPTION "Python bindings for DNLP diff engine") - -if(DEFINED ENV{VIRTUAL_ENV}) - set(Python3_EXECUTABLE "$ENV{VIRTUAL_ENV}/bin/python3") -endif() - -# Find Python3 and NumPy -find_package(Python3 REQUIRED COMPONENTS Interpreter Development NumPy) -message(STATUS "Python3 executable: ${Python3_EXECUTABLE}") -message(STATUS "Python3 version: ${Python3_VERSION}") -message(STATUS "NumPy include dirs: ${Python3_NumPy_INCLUDE_DIRS}") - -# Verify NumPy version -execute_process( - COMMAND ${Python3_EXECUTABLE} -c "import numpy; print(numpy.__version__)" - OUTPUT_VARIABLE NUMPY_VERSION - OUTPUT_STRIP_TRAILING_WHITESPACE -) -message(STATUS "NumPy version: ${NUMPY_VERSION}") - -# Build core project from source -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/.. ${CMAKE_BINARY_DIR}/core_build) - -# create Python module -add_library(bindings MODULE ${CMAKE_CURRENT_LIST_DIR}/bindings.c) -target_include_directories(bindings - PRIVATE - ${Python3_INCLUDE_DIRS} - ${Python3_NumPy_INCLUDE_DIRS} - ${CMAKE_CURRENT_LIST_DIR}/../include - ${CMAKE_CURRENT_LIST_DIR}/../src - ${CMAKE_CURRENT_LIST_DIR}/../src/utils -) - -# Link against core and Python libraries -target_link_libraries(bindings - PRIVATE - dnlp_diff - Python3::Python - Python3::NumPy -) - -# Place the module in the python source folder alongside example.py -set(PY_MODULE_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}) -file(MAKE_DIRECTORY ${PY_MODULE_OUTPUT_DIR}) - -set_target_properties(bindings PROPERTIES - PREFIX "" - OUTPUT_NAME "DNLP_diff_engine" - SUFFIX ".${Python3_SOABI}.so" - - # All configs → same folder - LIBRARY_OUTPUT_DIRECTORY ${PY_MODULE_OUTPUT_DIR} - RUNTIME_OUTPUT_DIRECTORY ${PY_MODULE_OUTPUT_DIR} # Critical for Windows .pyd -) diff --git a/python/atoms/asinh.h b/python/atoms/asinh.h new file mode 100644 index 0000000..a09934e --- /dev/null +++ b/python/atoms/asinh.h @@ -0,0 +1,31 @@ +#ifndef ATOM_ASINH_H +#define ATOM_ASINH_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_asinh(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_asinh(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create asinh node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_ASINH_H */ diff --git a/python/atoms/atanh.h b/python/atoms/atanh.h new file mode 100644 index 0000000..fd0568d --- /dev/null +++ b/python/atoms/atanh.h @@ -0,0 +1,31 @@ +#ifndef ATOM_ATANH_H +#define ATOM_ATANH_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_atanh(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_atanh(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create atanh node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_ATANH_H */ diff --git a/python/atoms/cos.h b/python/atoms/cos.h new file mode 100644 index 0000000..3f798be --- /dev/null +++ b/python/atoms/cos.h @@ -0,0 +1,31 @@ +#ifndef ATOM_COS_H +#define ATOM_COS_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_cos(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_cos(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create cos node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_COS_H */ diff --git a/python/atoms/entr.h b/python/atoms/entr.h new file mode 100644 index 0000000..9a37668 --- /dev/null +++ b/python/atoms/entr.h @@ -0,0 +1,31 @@ +#ifndef ATOM_ENTR_H +#define ATOM_ENTR_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_entr(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_entr(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create entr node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_ENTR_H */ diff --git a/python/atoms/left_matmul.h b/python/atoms/left_matmul.h new file mode 100644 index 0000000..0155c49 --- /dev/null +++ b/python/atoms/left_matmul.h @@ -0,0 +1,63 @@ +#ifndef ATOM_LEFT_MATMUL_H +#define ATOM_LEFT_MATMUL_H + +#include "common.h" +#include "bivariate.h" + +/* Left matrix multiplication: A @ f(x) where A is a constant matrix */ +static PyObject *py_make_left_matmul(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + PyObject *data_obj, *indices_obj, *indptr_obj; + int m, n; + if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj, + &indptr_obj, &m, &n)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + PyArrayObject *data_array = + (PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF( + indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + PyArrayObject *indptr_array = (PyArrayObject *) PyArray_FROM_OTF( + indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + + if (!data_array || !indices_array || !indptr_array) + { + Py_XDECREF(data_array); + Py_XDECREF(indices_array); + Py_XDECREF(indptr_array); + return NULL; + } + + int nnz = (int) PyArray_SIZE(data_array); + CSR_Matrix *A = new_csr_matrix(m, n, nnz); + memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double)); + memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int)); + memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int)); + + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + + expr *node = new_left_matmul(child, A); + free_csr_matrix(A); + + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create left_matmul node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_LEFT_MATMUL_H */ diff --git a/python/atoms/logistic.h b/python/atoms/logistic.h new file mode 100644 index 0000000..e903661 --- /dev/null +++ b/python/atoms/logistic.h @@ -0,0 +1,31 @@ +#ifndef ATOM_LOGISTIC_H +#define ATOM_LOGISTIC_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_logistic(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_logistic(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create logistic node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_LOGISTIC_H */ diff --git a/python/atoms/multiply.h b/python/atoms/multiply.h new file mode 100644 index 0000000..a5a5ccd --- /dev/null +++ b/python/atoms/multiply.h @@ -0,0 +1,32 @@ +#ifndef ATOM_MULTIPLY_H +#define ATOM_MULTIPLY_H + +#include "common.h" +#include "bivariate.h" + +static PyObject *py_make_multiply(PyObject *self, PyObject *args) +{ + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *) PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *) PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_elementwise_mult(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create multiply node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_MULTIPLY_H */ diff --git a/python/atoms/power.h b/python/atoms/power.h new file mode 100644 index 0000000..da9b7bf --- /dev/null +++ b/python/atoms/power.h @@ -0,0 +1,32 @@ +#ifndef ATOM_POWER_H +#define ATOM_POWER_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_power(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + double p; + if (!PyArg_ParseTuple(args, "Od", &child_capsule, &p)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_power(child, p); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create power node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_POWER_H */ diff --git a/python/atoms/right_matmul.h b/python/atoms/right_matmul.h new file mode 100644 index 0000000..b3307ac --- /dev/null +++ b/python/atoms/right_matmul.h @@ -0,0 +1,63 @@ +#ifndef ATOM_RIGHT_MATMUL_H +#define ATOM_RIGHT_MATMUL_H + +#include "common.h" +#include "bivariate.h" + +/* Right matrix multiplication: f(x) @ A where A is a constant matrix */ +static PyObject *py_make_right_matmul(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + PyObject *data_obj, *indices_obj, *indptr_obj; + int m, n; + if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj, + &indptr_obj, &m, &n)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + PyArrayObject *data_array = + (PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF( + indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + PyArrayObject *indptr_array = (PyArrayObject *) PyArray_FROM_OTF( + indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + + if (!data_array || !indices_array || !indptr_array) + { + Py_XDECREF(data_array); + Py_XDECREF(indices_array); + Py_XDECREF(indptr_array); + return NULL; + } + + int nnz = (int) PyArray_SIZE(data_array); + CSR_Matrix *A = new_csr_matrix(m, n, nnz); + memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double)); + memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int)); + memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int)); + + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + + expr *node = new_right_matmul(child, A); + free_csr_matrix(A); + + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create right_matmul node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_RIGHT_MATMUL_H */ diff --git a/python/atoms/sin.h b/python/atoms/sin.h new file mode 100644 index 0000000..426d302 --- /dev/null +++ b/python/atoms/sin.h @@ -0,0 +1,31 @@ +#ifndef ATOM_SIN_H +#define ATOM_SIN_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_sin(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_sin(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create sin node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_SIN_H */ diff --git a/python/atoms/sinh.h b/python/atoms/sinh.h new file mode 100644 index 0000000..be86cf5 --- /dev/null +++ b/python/atoms/sinh.h @@ -0,0 +1,31 @@ +#ifndef ATOM_SINH_H +#define ATOM_SINH_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_sinh(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_sinh(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create sinh node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_SINH_H */ diff --git a/python/atoms/tan.h b/python/atoms/tan.h new file mode 100644 index 0000000..3ec83b7 --- /dev/null +++ b/python/atoms/tan.h @@ -0,0 +1,31 @@ +#ifndef ATOM_TAN_H +#define ATOM_TAN_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_tan(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_tan(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create tan node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_TAN_H */ diff --git a/python/atoms/tanh.h b/python/atoms/tanh.h new file mode 100644 index 0000000..e60d420 --- /dev/null +++ b/python/atoms/tanh.h @@ -0,0 +1,31 @@ +#ifndef ATOM_TANH_H +#define ATOM_TANH_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_tanh(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_tanh(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create tanh node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_TANH_H */ diff --git a/python/atoms/xexp.h b/python/atoms/xexp.h new file mode 100644 index 0000000..1dda312 --- /dev/null +++ b/python/atoms/xexp.h @@ -0,0 +1,31 @@ +#ifndef ATOM_XEXP_H +#define ATOM_XEXP_H + +#include "common.h" +#include "elementwise_univariate.h" + +static PyObject *py_make_xexp(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_xexp(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create xexp node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_XEXP_H */ diff --git a/python/bindings.c b/python/bindings.c index ff68c0e..4fb952b 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -8,8 +8,22 @@ #include "atoms/exp.h" #include "atoms/linear.h" #include "atoms/log.h" +#include "atoms/multiply.h" #include "atoms/neg.h" +#include "atoms/power.h" #include "atoms/promote.h" +#include "atoms/sin.h" +#include "atoms/cos.h" +#include "atoms/tan.h" +#include "atoms/sinh.h" +#include "atoms/tanh.h" +#include "atoms/asinh.h" +#include "atoms/atanh.h" +#include "atoms/entr.h" +#include "atoms/logistic.h" +#include "atoms/xexp.h" +#include "atoms/left_matmul.h" +#include "atoms/right_matmul.h" #include "atoms/sum.h" #include "atoms/variable.h" @@ -42,6 +56,20 @@ static PyMethodDef DNLPMethods[] = { {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, {"make_neg", py_make_neg, METH_VARARGS, "Create neg node"}, {"make_promote", py_make_promote, METH_VARARGS, "Create promote node"}, + {"make_multiply", py_make_multiply, METH_VARARGS, "Create elementwise multiply node"}, + {"make_power", py_make_power, METH_VARARGS, "Create power node"}, + {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, + {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, + {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, + {"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"}, + {"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"}, + {"make_asinh", py_make_asinh, METH_VARARGS, "Create asinh node"}, + {"make_atanh", py_make_atanh, METH_VARARGS, "Create atanh node"}, + {"make_entr", py_make_entr, METH_VARARGS, "Create entr node"}, + {"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"}, + {"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"}, + {"make_left_matmul", py_make_left_matmul, METH_VARARGS, "Create left matmul node (A @ f(x))"}, + {"make_right_matmul", py_make_right_matmul, METH_VARARGS, "Create right matmul node (f(x) @ A)"}, {"make_problem", py_make_problem, METH_VARARGS, "Create problem from objective and constraints"}, {"problem_init_derivatives", py_problem_init_derivatives, METH_VARARGS, diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index e9eee0e..596f831 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -23,6 +23,46 @@ def _chain_add(children): return result +def _convert_matmul(expr, children): + """Convert matrix multiplication A @ f(x) or f(x) @ A.""" + # MulExpression has args: [left, right] + # One of them should be a Constant, the other a variable expression + left_arg, right_arg = expr.args + + if isinstance(left_arg, cp.Constant): + # A @ f(x) -> left_matmul + A = np.asarray(left_arg.value, dtype=np.float64) + if A.ndim == 1: + A = A.reshape(1, -1) # Convert 1D to row vector + A_csr = sparse.csr_matrix(A) + m, n = A_csr.shape + return _diffengine.make_left_matmul( + children[1], # right child is the variable expression + A_csr.data.astype(np.float64), + A_csr.indices.astype(np.int32), + A_csr.indptr.astype(np.int32), + m, + n, + ) + elif isinstance(right_arg, cp.Constant): + # f(x) @ A -> right_matmul + A = np.asarray(right_arg.value, dtype=np.float64) + if A.ndim == 1: + A = A.reshape(-1, 1) # Convert 1D to column vector + A_csr = sparse.csr_matrix(A) + m, n = A_csr.shape + return _diffengine.make_right_matmul( + children[0], # left child is the variable expression + A_csr.data.astype(np.float64), + A_csr.indices.astype(np.int32), + A_csr.indptr.astype(np.int32), + m, + n, + ) + else: + raise NotImplementedError("MulExpression with two non-constant args not supported") + + # Mapping from CVXPY atom names to C diff engine functions # Converters receive (expr, children) where expr is the CVXPY expression ATOM_CONVERTERS = { @@ -43,6 +83,31 @@ def _chain_add(children): # Reductions "Sum": lambda _expr, children: _diffengine.make_sum(children[0], -1), + + # Bivariate + "multiply": lambda _expr, children: _diffengine.make_multiply(children[0], children[1]), + + # Matrix multiplication + "MulExpression": _convert_matmul, + + # Elementwise univariate with parameter + "power": lambda expr, children: _diffengine.make_power(children[0], float(expr.p.value)), + + # Trigonometric + "sin": lambda _expr, children: _diffengine.make_sin(children[0]), + "cos": lambda _expr, children: _diffengine.make_cos(children[0]), + "tan": lambda _expr, children: _diffengine.make_tan(children[0]), + + # Hyperbolic + "sinh": lambda _expr, children: _diffengine.make_sinh(children[0]), + "tanh": lambda _expr, children: _diffengine.make_tanh(children[0]), + "asinh": lambda _expr, children: _diffengine.make_asinh(children[0]), + "atanh": lambda _expr, children: _diffengine.make_atanh(children[0]), + + # Other elementwise + "entr": lambda _expr, children: _diffengine.make_entr(children[0]), + "logistic": lambda _expr, children: _diffengine.make_logistic(children[0]), + "xexp": lambda _expr, children: _diffengine.make_xexp(children[0]), }