diff --git a/grudge/array_context.py b/grudge/array_context.py index 0202333d..d170502e 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -6,6 +6,7 @@ .. autoclass:: MPINumpyArrayContext .. class:: MPIPytatoArrayContext .. autoclass:: MPIEagerJAXArrayContext +.. autoclass:: MPIPytatoJAXArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -76,12 +77,19 @@ _HAVE_FUSION_ACTX = False -from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext +from arraycontext import ( + ArrayContext, + EagerJAXArrayContext, + NumpyArrayContext, + PytatoJAXArrayContext, +) +from arraycontext.container import ArrayContainer from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, ) @@ -99,6 +107,7 @@ import pyopencl import pyopencl.array as cl_array from arraycontext.container import ArrayContainer + from pytools.tag import Tag @@ -465,6 +474,26 @@ def clone(self) -> Self: # }}} +# {{{ distributed + lazy jax + +class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`jax` + lazy evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator) -> None: + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self) -> Self: + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pytato array context subclasses class MPIBasePytatoPyOpenCLArrayContext( @@ -573,6 +602,15 @@ def __call__(self): return self.actx_class() +class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory): + actx_class = PytatoJAXArrayContext + + def __call__(self): + import jax + jax.config.update("jax_enable_x64", True) + return self.actx_class() + + register_pytest_array_context_factory("grudge.pyopencl", PytestPyOpenCLArrayContextFactory) register_pytest_array_context_factory("grudge.pytato-pyopencl", @@ -581,6 +619,8 @@ def __call__(self): PytestNumpyArrayContextFactory) register_pytest_array_context_factory("grudge.eager-jax", PytestEagerJAXArrayContextFactory) +register_pytest_array_context_factory("grudge.lazy-jax", + PytestPytatoJAXArrayContextFactory) # }}} diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 41cc6806..ee5f0d7b 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -34,6 +34,7 @@ PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) @@ -42,7 +43,8 @@ [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, - PytestEagerJAXArrayContextFactory]) + PytestEagerJAXArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) import logging diff --git a/test/test_metrics.py b/test/test_metrics.py index cf3c3035..261b8ef4 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -39,6 +39,7 @@ PytestEagerJAXArrayContextFactory, PytestNumpyArrayContextFactory, PytestPyOpenCLArrayContextFactory, + PytestPytatoJAXArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, ) from grudge.discretization import make_discretization_collection @@ -49,7 +50,8 @@ [PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory, PytestNumpyArrayContextFactory, - PytestEagerJAXArrayContextFactory]) + PytestEagerJAXArrayContextFactory, + PytestPytatoJAXArrayContextFactory]) # {{{ inverse metric