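JAX now supports offloading saved activations (the residuals kept for the backward pass) to host memory through its rematerialization (`jax.checkpoint`) policies. These policies should be expressible in praxis, and therefore in downstream libraries such as paxml.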
MaxText already has such an integration:
https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L231
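For reference, here is a minimal sketch of such a policy in plain JAX, which is roughly what a praxis-level option would need to wrap. `make_offload_policy`, the `"layer_input"` tag, and the toy `layer` function are hypothetical names for illustration, not praxis or paxml APIs. The sketch assumes a JAX version that provides `jax.checkpoint_policies.save_and_offload_only_these_names` and a GPU or TPU backend with a `pinned_host` memory space. On other backends it falls back to keeping the tagged activation in device memory.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def make_offload_policy(names_to_offload, names_to_save=(), offload=True):
    """Hypothetical helper: build a remat policy for named activations.

    Activations tagged with a name in `names_to_offload` are offloaded to
    host memory; names in `names_to_save` stay on device; everything else
    is recomputed in the backward pass. With offload=False (e.g. on CPU)
    the offload names are simply saved on device instead.
    """
    if offload:
        return jax.checkpoint_policies.save_and_offload_only_these_names(
            names_which_can_be_saved=list(names_to_save),
            names_which_can_be_offloaded=list(names_to_offload),
            offload_src="device",       # where the residual is produced
            offload_dst="pinned_host",  # where it is parked until backward
        )
    return jax.checkpoint_policies.save_only_these_names(
        *names_to_save, *names_to_offload)


def layer(w, x):
    # Tag the layer input so a name-based policy can select it.
    x = checkpoint_name(x, "layer_input")
    h = jnp.tanh(x @ w)
    return jnp.sum(h ** 2)


# Only attempt host offloading on accelerators (assumption: CPU lacks pinned_host).
offload = jax.default_backend() in ("gpu", "tpu")
policy = make_offload_policy(["layer_input"], offload=offload)
remat_layer = jax.checkpoint(layer, policy=policy)

w = jnp.ones((4, 4)) * 0.1
x = jnp.arange(8.0).reshape(2, 4)
grads = jax.jit(jax.grad(remat_layer))(w, x)
```

Offloading only changes where residuals live between the forward and backward passes, so the gradients should match those of the original function. A praxis option could map a config value onto a helper like this one.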