Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
.. autoclass:: MPINumpyArrayContext
.. class:: MPIPytatoArrayContext
.. autoclass:: MPIEagerJAXArrayContext
.. autoclass:: MPIPytatoJAXArrayContext
.. autofunction:: get_reasonable_array_context_class
"""

Expand Down Expand Up @@ -76,12 +77,19 @@
_HAVE_FUSION_ACTX = False


from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext
from arraycontext import (
ArrayContext,
EagerJAXArrayContext,
NumpyArrayContext,
PytatoJAXArrayContext,
)
from arraycontext.container import ArrayContainer

Check warning on line 86 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

"ArrayContainer" is not exported from module "arraycontext.container"   Import from "arraycontext.typing" instead (reportPrivateImportUsage)
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
from arraycontext.pytest import (
_PytestEagerJaxArrayContextFactory,
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoJaxArrayContextFactory,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
)
Expand All @@ -99,6 +107,7 @@
import pyopencl
import pyopencl.array as cl_array
from arraycontext.container import ArrayContainer

from pytools.tag import Tag


Expand Down Expand Up @@ -447,24 +456,44 @@

# {{{ distributed + eager jax

class MPIEagerJAXArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):

Check failure on line 459 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Base classes for class "MPIEagerJAXArrayContext" define method "einsum" in incompatible way   Return type mismatch: base method returns type "Array", override returns type "Array"     "Array" is incompatible with protocol "Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "__getitem__" is an incompatible type         Type "(key: Unknown) -> Array" is not assignable to type "(index: Any) -> Array"           Parameter name mismatch: "index" versus "key" ... (reportIncompatibleMethodOverride)

Check failure on line 459 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Multiple inheritance is not allowed because the following base classes contain `__init__` or `__new__` methods that may not get called: ArrayContext (reportUnsafeMultipleInheritance)
"""An array context for using distributed computation with :mod:`jax`
eager evaluation.

.. autofunction:: __init__
"""

def __init__(self, mpi_communicator) -> None:

Check warning on line 466 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation is missing for parameter "mpi_communicator" (reportMissingParameterType)

Check warning on line 466 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "mpi_communicator" is unknown (reportUnknownParameterType)
super().__init__()

self.mpi_communicator = mpi_communicator

Check warning on line 469 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation for attribute `mpi_communicator` is required because this class is not decorated with `@final` (reportUnannotatedClassAttribute)

def clone(self) -> Self:

Check warning on line 471 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Method "clone" is not marked as override but is overriding a method in class "EagerJAXArrayContext" (reportImplicitOverride)
return type(self)(self.mpi_communicator)

# }}}


# {{{ distributed + lazy jax

class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext):

Check failure on line 479 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Multiple inheritance is not allowed because the following base classes contain `__init__` or `__new__` methods that may not get called: ArrayContext (reportUnsafeMultipleInheritance)
"""An array context for using distributed computation with :mod:`jax`
lazy evaluation.

.. autofunction:: __init__
"""

def __init__(self, mpi_communicator) -> None:

Check warning on line 486 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation is missing for parameter "mpi_communicator" (reportMissingParameterType)

Check warning on line 486 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "mpi_communicator" is unknown (reportUnknownParameterType)
super().__init__()

self.mpi_communicator = mpi_communicator

Check warning on line 489 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation for attribute `mpi_communicator` is required because this class is not decorated with `@final` (reportUnannotatedClassAttribute)

def clone(self) -> Self:

Check warning on line 491 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Method "clone" is not marked as override but is overriding a method in class "PytatoJAXArrayContext" (reportImplicitOverride)
return type(self)(self.mpi_communicator)

# }}}


# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
Expand Down Expand Up @@ -565,7 +594,7 @@


class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory):
actx_class = EagerJAXArrayContext

Check warning on line 597 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation for attribute `actx_class` is required because this class is not decorated with `@final` (reportUnannotatedClassAttribute)

def __call__(self):
import jax
Expand All @@ -573,6 +602,15 @@
return self.actx_class()


class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory):
actx_class = PytatoJAXArrayContext

def __call__(self):
import jax
jax.config.update("jax_enable_x64", True)
return self.actx_class()


register_pytest_array_context_factory("grudge.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.pytato-pyopencl",
Expand All @@ -581,6 +619,8 @@
PytestNumpyArrayContextFactory)
register_pytest_array_context_factory("grudge.eager-jax",
PytestEagerJAXArrayContextFactory)
register_pytest_array_context_factory("grudge.lazy-jax",
PytestPytatoJAXArrayContextFactory)

# }}}

Expand Down
4 changes: 3 additions & 1 deletion test/test_dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoJAXArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)

Expand All @@ -42,7 +43,8 @@
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])
PytestEagerJAXArrayContextFactory,
PytestPytatoJAXArrayContextFactory])

import logging

Expand Down
4 changes: 3 additions & 1 deletion test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoJAXArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection
Expand All @@ -49,7 +50,8 @@
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])
PytestEagerJAXArrayContextFactory,
PytestPytatoJAXArrayContextFactory])


# {{{ inverse metric
Expand Down
Loading