Skip to content

JAX compatibility#12

Draft
TTMoursy wants to merge 33 commits intoNihanPol:mainfrom
TTMoursy:main
Draft

JAX compatibility#12
TTMoursy wants to merge 33 commits intoNihanPol:mainfrom
TTMoursy:main

Conversation

@TTMoursy
Copy link
Copy Markdown

@TTMoursy TTMoursy commented Mar 8, 2026

I added JAX functionality through the anis_pta_jax.py file and an example notebook. Everything is documented. I tried installing my fork in a fresh conda environment (with enterprise) and the example notebook ran.

This is not the most efficient way to do JAX-based anisotropy analyses, but it was difficult to make something both general and highly-optimized, so I went for generality in this module. I also did not use object-oriented programming because functional programming is easier for me.

pip installing this package will not give GPU compatibility by default, but will if the user installs the CUDA version of JAX.

Still need to implement SNR method in class
JAX is a dependency of Optimistix already
@TTMoursy TTMoursy marked this pull request as draft April 11, 2026 18:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants