POEM 106 - Expose JAX vectorization to users#213
Conversation
|
In order for this to be jax-tracable (for JIT and AD), this would require the entire chain of OpenMDAO calculations to be written in JAX. It would also break iteroperability through other non-jax things like OpenMDAO.jl. If you build your model entirely in Jax and interface to it through OpenMDAO's JaxExplicitComponant or JaxImplicitComponent, I think you can achieve what you're looking for.
Those shapes won't be set until you use Documentation on JaxExplicitComponent can be found here: https://openmdao.org/newdocs/versions/latest/features/experimental/jax_explicitcomp_api.html Dynamic shaping documentation is here: Let me know if you have any questions. I've used this technique very successfully with dymos where we need to do something like a Matrix-Vector product across N points in time simultaneously. |
|
Hi Rob, thanks for the response and suggestion. That's a good suggestion - I'll try it to see if it works for my use case. One tricky aspect to the approach is not knowing which arg indicies of your function are vampped a priori during an analysis sweep. There might be a way to dynamically figure out which variables are vmapped and then define the vmap in the Component, but from my experience there may be a performance hit. |
|
Merging this for now to mark as rejected. But keep the conversation going if there are issues or is anything else we can help with. |
Provide an interface so that OpenMDAO problems can be vectorized for analysis sweeps or used in other contexts where vectorization provides a speedup