Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@ DNLP-diff-engine is a C library with Python bindings that provides automatic dif
### Python Package (Recommended)

```bash
# Install in development mode (builds C library + Python bindings)
pip install -e .
# Install in development mode with uv (recommended)
uv pip install -e ".[test]"

# Run all Python tests
# Or with pip
pip install -e ".[test]"

# Run all Python tests (tests are in python/tests/)
pytest

# Run specific test file
pytest tests/python/test_unconstrained.py
pytest python/tests/test_unconstrained.py

# Run specific test
pytest tests/python/test_unconstrained.py::test_sum_log
pytest python/tests/test_unconstrained.py::test_sum_log

# Lint with ruff
ruff check src/

# Auto-fix lint issues
ruff check --fix src/
```

### Standalone C Library
Expand All @@ -35,16 +44,6 @@ cmake --build build
./build/all_tests
```

### Legacy Python Build (without pip)

```bash
# Build Python bindings manually (from project root)
cd python && cmake -B build -S . && cmake --build build && cd ..

# Run tests from python/ directory
cd python && python tests/test_problem_native.py
```

## Architecture

### Expression Tree System
Expand All @@ -59,10 +58,10 @@ The core abstraction is the `expr` struct (in `include/expr.h`) representing a n

Atoms are organized by mathematical properties in `src/`:

- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op`
- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, trigonometric, hyperbolic
- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr`
- **`other/`** - Special atoms not fitting above categories
- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op`, `index`
- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, `xexp`, trigonometric (`sin`, `cos`, `tan`), hyperbolic (`sinh`, `tanh`, `asinh`, `atanh`)
- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr`, `const_scalar_mult`, `const_vector_mult`, `left_matmul`, `right_matmul`
- **`other/`** - Special atoms: `quad_form`, `prod`

Each atom implements its own `forward`, `jacobian_init`, `eval_jacobian`, and `eval_wsum_hess` functions following a consistent pattern.

Expand All @@ -87,7 +86,8 @@ The Python package `dnlp_diff_engine` (in `src/dnlp_diff_engine/`) provides:
**High-level API** (`__init__.py`):
- `C_problem` class wraps the C problem struct
- `convert_problem()` builds expression trees from CVXPY Problem objects
- Atoms are mapped via `ATOM_CONVERTERS` dictionary
- Atoms are mapped via `ATOM_CONVERTERS` dictionary (maps CVXPY atom names → converter functions)
- Special converters handle: matrix multiplication (`_convert_matmul`), multiply with constants (`_convert_multiply`), indexing, reshape (Fortran order only)

**Low-level C extension** (`_core` module, built from `python/bindings.c`):
- Atom constructors: `make_variable`, `make_constant`, `make_log`, `make_exp`, `make_add`, etc.
Expand All @@ -108,14 +108,15 @@ Hessian computes weighted sum: `obj_w * H_obj + sum(lambda_i * H_constraint_i)`

## Key Directories

- `include/` - Header files defining public API
- `src/` - C implementation files
- `src/dnlp_diff_engine/` - Python package (installed via pip)
- `python/` - Python bindings C code and binding headers
- `include/` - Header files defining public API (`expr.h`, `problem.h`, atom headers)
- `src/` - C implementation files organized by atom category
- `src/dnlp_diff_engine/` - Python package with high-level API
- `python/` - Python bindings C code (`bindings.c`)
- `python/atoms/` - Python binding headers for each atom type
- `python/problem/` - Python binding headers for problem interface
- `python/tests/` - Python integration tests (run via pytest)
- `tests/` - C tests using minunit framework
- `tests/python/` - Python tests (run via pytest)
- `tests/forward_pass/` - Forward evaluation tests (C)
- `tests/jacobian_tests/` - Jacobian correctness tests (C)
- `tests/wsum_hess/` - Hessian correctness tests (C)

Expand Down
2 changes: 2 additions & 0 deletions include/affine.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ expr *new_trace(expr *child);
expr *new_constant(int d1, int d2, int n_vars, const double *values);
expr *new_variable(int d1, int d2, int var_id, int n_vars);

expr *new_index(expr *child, const int *indices, int n_idxs);

#endif /* AFFINE_H */
9 changes: 9 additions & 0 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,13 @@ typedef struct const_vector_mult_expr
double *a; /* length equals node->size */
} const_vector_mult_expr;

/* Index/slicing: y = child[indices] where indices is a list of flat positions */
typedef struct index_expr
{
expr base;
int *indices; /* Flattened indices to select (owned, copied) */
int n_idxs; /* Number of selected elements */
bool has_duplicates; /* True if indices have duplicates (affects Hessian path) */
} index_expr;

#endif /* SUBEXPR_H */
53 changes: 53 additions & 0 deletions python/atoms/index.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0

#ifndef ATOM_INDEX_H
#define ATOM_INDEX_H

#include "affine.h"
#include "common.h"

/* Index/slicing: y = child[indices] where indices is a list of flattened positions
*/
static PyObject *py_make_index(PyObject *self, PyObject *args)
{
PyObject *child_capsule;
PyObject *indices_obj;

if (!PyArg_ParseTuple(args, "OO", &child_capsule, &indices_obj))
{
return NULL;
}

expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
if (!child)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

/* Convert indices array to int32 */
PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF(
indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);

if (!indices_array)
{
return NULL;
}

int n_idxs = (int) PyArray_SIZE(indices_array);
int *indices_data = (int *) PyArray_DATA(indices_array);

expr *node = new_index(child, indices_data, n_idxs);

Py_DECREF(indices_array);

if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create index node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_INDEX_H */
36 changes: 36 additions & 0 deletions python/atoms/quad_over_lin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: Apache-2.0

#ifndef ATOM_QUAD_OVER_LIN_H
#define ATOM_QUAD_OVER_LIN_H

#include "bivariate.h"
#include "common.h"

/* quad_over_lin: y = sum(x^2) / z where x is left, z is right (scalar) */
static PyObject *py_make_quad_over_lin(PyObject *self, PyObject *args)
{
(void)self;
PyObject *left_capsule, *right_capsule;
if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule))
{
return NULL;
}
expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME);
expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME);
if (!left || !right)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

expr *node = new_quad_over_lin(left, right);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create quad_over_lin node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_QUAD_OVER_LIN_H */
36 changes: 36 additions & 0 deletions python/atoms/rel_entr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: Apache-2.0

#ifndef ATOM_REL_ENTR_H
#define ATOM_REL_ENTR_H

#include "bivariate.h"
#include "common.h"

/* rel_entr: rel_entr(x, y) = x * log(x/y) elementwise */
static PyObject *py_make_rel_entr(PyObject *self, PyObject *args)
{
(void)self;
PyObject *left_capsule, *right_capsule;
if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule))
{
return NULL;
}
expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME);
expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME);
if (!left || !right)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

expr *node = new_rel_entr_vector_args(left, right);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create rel_entr node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_REL_ENTR_H */
8 changes: 8 additions & 0 deletions python/bindings.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "atoms/cos.h"
#include "atoms/entr.h"
#include "atoms/exp.h"
#include "atoms/index.h"
#include "atoms/left_matmul.h"
#include "atoms/linear.h"
#include "atoms/log.h"
Expand All @@ -21,6 +22,8 @@
#include "atoms/power.h"
#include "atoms/promote.h"
#include "atoms/quad_form.h"
#include "atoms/quad_over_lin.h"
#include "atoms/rel_entr.h"
#include "atoms/right_matmul.h"
#include "atoms/sin.h"
#include "atoms/sinh.h"
Expand Down Expand Up @@ -55,6 +58,7 @@ static PyMethodDef DNLPMethods[] = {
{"make_linear", py_make_linear, METH_VARARGS, "Create linear op node"},
{"make_log", py_make_log, METH_VARARGS, "Create log node"},
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
{"make_index", py_make_index, METH_VARARGS, "Create index node"},
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
{"make_neg", py_make_neg, METH_VARARGS, "Create neg node"},
Expand Down Expand Up @@ -82,6 +86,10 @@ static PyMethodDef DNLPMethods[] = {
"Create right matmul node (f(x) @ A)"},
{"make_quad_form", py_make_quad_form, METH_VARARGS,
"Create quadratic form node (x' * Q * x)"},
{"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS,
"Create quad_over_lin node (sum(x^2) / y)"},
{"make_rel_entr", py_make_rel_entr, METH_VARARGS,
"Create rel_entr node: x * log(x/y) elementwise"},
{"make_problem", py_make_problem, METH_VARARGS,
"Create problem from objective and constraints"},
{"problem_init_derivatives", py_problem_init_derivatives, METH_VARARGS,
Expand Down
Loading
Loading