Conversation

@alexallmont
Contributor

No description provided.

alexallmont and others added 30 commits September 2, 2025 17:03
- JAX requirements based on their FFI example code using nanobind and cuda 12.
- Experimental: trying docker/podman run args to use CUDA through container.
- Also required clamping pybind11<3 version as in pyproject.toml.
- Removed explicit cmake in devcontainer as VSCode seems to be happier without it (after getting popups that cmake was not found).
- Build issues found in devcontainer on linux host. Firstly, the compiler could not find PyLong_AsInt and suggested switch to PyLong_AsLong, which then compiles OK.
- Secondly, the `char*` conversion using PyArg_ParseTupleAndKeywords was causing build to fail. This seems to be a common problem, perhaps here caused by my using a newer Python (3.12)? For now I'm committing with `-fpermissive` (an overkill suggested workaround from online discussion) with FIXME to discuss in review.
- Reworking rms_norm example in JAX FFI example along with python registration from lncde project.
- CMake options for ROUGHPY_JAX (default on) and ROUGHPY_JAX_CUDA (default off). C++ files are stubs only at the moment.
- No C++ registrations yet, this is primarily testing degree_begin is correct in TensorBasis and adding DenseFreeTensor which is based on it. These are pytree objects for interop with JAX.
- Note TensorBasis is pure `lax.scan` equivalent of tensor_basis_init in tensor_basis.c.
- Rename _roughpy_jax_cpu to _rpy_jax_internals, to be more in line with the roughpy compute module. For now removing the cpu/cuda split from __init__.py, as this is also more in line with roughpy compute; the consolidation of CMakeLists.txt reflects this, using one file that references files in subdirs directly.
- _rpy_jax_internals so file copied into roughpy_jax for idiomatic roughpy testing.
- For now, adding RmsNorm example from FFI examples whilst working out linkage wrinkles. FIXME to replace with dense_ft_fma.
- Separate include for XLA c_api and ffi because these generate a lot of compiler warnings. FIXME note to review/replace later to best roughpy practices.
- Note usage of XLA_FFI_DECLARE_HANDLER_SYMBOL to expose symbol to py_module. Again need to check if this is idiomatic.
- Note this is not currently working, but parking work whilst registration is building and loading correctly in __init__.py.
- Arguments to cpu_dense_ft_fma are incorrect; rather than taking all bases, this needs to be changed to work as roughpy compute.
- Note: needed to set CXX_VISIBILITY_PRESET for symbols to get picked up, unlike in the JAX FFI example; may need investigation.
- test_dense_ft_fma is just stub/printf code for now whilst getting up and running.
- This is a quick fix so TensorBasis (in roughpy_compute/common/basis.hpp) can be directly constructed from XLA buffer. Many FIXME comments regarding JAX_ENABLE_X64 whilst clarifying requirements.
- Note to run these tests, one must `export JAX_ENABLE_X64=True`.
- Raises the question which floating point types should be supported too. Types currently toggleable with XlaIndexType and XlaFloatType.
- Also has side effect of cpu_dense_ft_fma_impl widths being long ints, needs further investigations.
- Note that 64 bit makes the lax.scan creation of degree_begin rather ugly, as jnp.array is now needed as the initial value to force int64; it may be better as a C function instead.
- Reimplement out_depth, lhs_depth and rhs_depth and basis from original compute code, with some FIXME comments for review where defaults are used.
- Note this is still buggy as python throws a malloc error when test_ft_fma is run. I have checked that out[idx] is not overrunning here which looks fine but most likely I'm passing different args than compute.
- Docstrings in for base roughpy_jax.
- Reverting indexing to S32 type so users are not forced into enabling JAX 64 bit features by default. For now this is implemented by copying values into a temp 64 bit array.
- Added TensorBasis size() method as in compute.
- This incorrect size was generating the wrong size result array for XLA, overrunning buffer in ft_fma causing segfault.
- Beforehand, using a python int for these XLA Attrs was generating a size cast warning. This can be avoided using numpy's int32 constructor.
- Also adding unit test for test_dense_ft_fma, which strictly should have been with previous commit.
- Also added some comments to pick up on in TensorBasis before merge to main.
- Addressing a few points picked up in review: copy_n should be used instead of memcpy, and we did discuss using out_max_degree but this is commented out for now as it would require changing calling code (where result array is alloc-ed).
- Replaced basic::v1 namespace with just basic, as this is an inline namespace and the v1 qualifier is not necessary.
- Also replaced use of back inserter with regular copy destination for degree_begin_i64.
- Review feedback, JAX is optional in build so shouldn't be forced for all users.
- Just a note to pick this up before merge to main as it's duplicating cmake code in roughpy folder.
- Review feedback, makes the end-user experience a little nicer.
- Copied from roughpy equivalent cmake; also renamed project to mirror source.
- Rather than nanobind to capsule function and module, use python ctypes to load directly from .so file.
- The .so is no longer a python module so doesn't have the nice cpython versioning/safety rails for various platforms. However, this is based on the starter example in the JAX FFI docs so good starting point.
- PyModule_Add was added in python 3.13, so it will not build on current ubuntu 24.04 (python 3.12).
- Remove nanobind.
- Avoid localisation prompts (DEBIAN_FRONTEND)
- Don't install jax cuda by default.
- Not needed with symbols being loaded directly from .so in `__init__.py`.
- Common fma/exp functions moved into `xla_common.hpp`.
- fma moved update_algebra_params call into separate out/lhs/rhs degree methods.
- Note commit message for prev prev (939fae3) was incorrect: this new code is dense_ft_exp not dense_ft_fma
- After design discussion we are not planning to make the minimum settable in the API for now, so it is safe to assume the minimum is 0. It can still be set in the C++ API, but is not needed in python and JAX.
- Also tidied xla_common methods into cpp file rather than inlining.
- Reducing code duplication for more JAX methods to be added soon.
- Bug found whilst adding more comprehensive tests; old python wrappers were just returning data so chained operations were failing on type (array rather than DenseFreeTensor class) being passed into next call.
- Removed old placeholder exp/log tests and refactored tests from tests/compute instead.
- Tests using `ft_fmexp` and `ft_mul` still in progress until methods added.
- This change mirrors the naming convention of related compute code; the underlying dense/shuffle dispatch will be done through python.
- For discussion: now uses `c` (the equivalent of `rhs` in compute API) in output basis. Previously this was using a mix of `a`, `b` and `c` carried over from the original, which is confusing. The jax shape struct not using `a` may be a problem, but currently we are considering all bases to be the same size.
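
For reference, the `degree_begin` scan mentioned in the TensorBasis commits can be pictured with a minimal pure-Python sketch. This is not the code in this PR: it assumes entry `d` of `degree_begin` is the number of basis words of length less than `d`, and the function name and signature are hypothetical. `itertools.accumulate` plays the role of the `lax.scan` carry here; Python ints are arbitrary precision, so the int32/int64 question raised above does not show up in this stand-in.

```python
from itertools import accumulate


def degree_begin(width: int, depth: int) -> list[int]:
    """Offsets of the first coefficient of each degree in a dense layout.

    Assumption (not taken from the PR): entry d counts the basis words of
    length < d, so the list has depth + 2 entries and the last entry is
    the total number of dense coefficients.
    """
    # Level sizes are width**0, width**1, ..., width**depth.
    level_sizes = (width**d for d in range(depth + 1))
    # Running sum seeded with 0: the same carry pattern a scan would use.
    return list(accumulate(level_sizes, initial=0))


offsets = degree_begin(width=2, depth=3)
total_size = offsets[-1]  # number of coefficients up to and including `depth`
```

Under the same assumption, the last entry is what a `size()` method on the basis would report.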
jackleland and others added 30 commits November 27, 2025 15:19
- Integrating foundational reorg from 220-roughpy-jax-streams into current work, preserving previous in-progress lie tensor work.
- ffi.py renamed to ops.py and loaded in __init__.py.
- Merge over remainder of Sam's changes into algebra.py.
- lie tensor tests disabled temporarily to get this code in as soon as possible.
- Should have been removed in d5111d1
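
As a rough illustration of the ctypes route described in the earlier commits (loading symbols directly from the `.so` rather than going through a nanobind module), here is a minimal sketch of wrapping an exported symbol in a PyCapsule using only the standard library. The helper name `symbol_capsule` is hypothetical, and `ctypes.pythonapi` with `Py_GetVersion` is only a stand-in so the snippet runs without the built library; the real code would load the project's `.so` and pass the capsule to JAX's FFI registration instead.

```python
import ctypes


def symbol_capsule(lib, symbol: str):
    """Wrap an exported C symbol from `lib` in an unnamed PyCapsule."""
    destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
    new_capsule = ctypes.pythonapi.PyCapsule_New
    new_capsule.restype = ctypes.py_object
    new_capsule.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor)
    # Look the symbol up by name and take its address as a void pointer.
    func_ptr = ctypes.cast(getattr(lib, symbol), ctypes.c_void_p)
    # NULL capsule name and NULL destructor.
    return new_capsule(func_ptr, None, destructor(0))


# Stand-in library and symbol (assumption, for illustration only).
# In the real package this would be something like
# ctypes.cdll.LoadLibrary(<path to the built .so>) and a handler symbol.
lib = ctypes.pythonapi
capsule = symbol_capsule(lib, "Py_GetVersion")
```

The trade-off noted above still applies: a plain shared library loaded this way has none of the CPython extension-module versioning checks, so the loader has to locate the right file itself.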
Functioning benchmarks and JAX build/packaging fixes
* feat: l2t and t2l (#220)

- Use native JAX for CSC matrix vector mul computation instead of compute dense_lie_to_tensor and dense_tensor_to_lie.
- All l2t and t2l tests updated to test with batched variants.
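
For readers unfamiliar with the CSC layout used for l2t and t2l, the following is a minimal pure-Python sketch of a CSC matrix-vector product with a batched variant. It is a stand-in to show the data layout, not the JAX implementation in this PR; the function names and the small example matrix are made up for illustration. In JAX the inner accumulation would typically become a scatter-add over the row indices.

```python
def csc_matvec(data, indices, indptr, n_rows, x):
    """Compute y = A @ x for A stored in compressed sparse column form."""
    y = [0.0] * n_rows
    for col, x_col in enumerate(x):
        # Non-zeros of column `col` live in data[indptr[col]:indptr[col + 1]],
        # with their row numbers in the matching slice of `indices`.
        for k in range(indptr[col], indptr[col + 1]):
            y[indices[k]] += data[k] * x_col
    return y


def csc_matvec_batched(data, indices, indptr, n_rows, xs):
    """Apply the same sparse matrix to every vector in the batch `xs`."""
    return [csc_matvec(data, indices, indptr, n_rows, x) for x in xs]


# Example matrix (illustrative only):
# A = [[1, 0, 2],
#      [0, 3, 0]]
data = [1.0, 3.0, 2.0]
indices = [0, 1, 0]
indptr = [0, 1, 2, 3]
ys = csc_matvec_batched(data, indices, indptr, 2, [[1.0, 1.0, 1.0], [1.0, 2.0, 3.0]])
```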
Development

Successfully merging this pull request may close these issues.

Add RoughPy JAX support

3 participants