diff --git a/firedrake/ml/jax/fem_operator.py b/firedrake/ml/jax/fem_operator.py
index 94f2875202..cee0d23a86 100644
--- a/firedrake/ml/jax/fem_operator.py
+++ b/firedrake/ml/jax/fem_operator.py
@@ -173,7 +173,8 @@ def to_jax(x: Union[Function, Constant], gather: Optional[bool] = False, batched
             x_P = jnp.array(np.ravel(x.dat.global_data), **kwargs)
         else:
             # Use local data
-            x_P = jnp.array(np.ravel(x.dat.data_ro), **kwargs)
+            with x.dat.vec_ro as vec:
+                x_P = jnp.array(np.ravel(vec.buffer_r), **kwargs)
         if batched:
             # Default behaviour: add batch dimension after converting to JAX
             return x_P[None, :]
@@ -222,5 +223,7 @@ def from_jax(x: "jax.Array", V: Optional[WithGeometry] = None) -> Union[Function
             val = val[0]
         return Constant(val)
     else:
-        x_F = Function(V, val=np.asarray(x))
+        x_F = Function(V)
+        with x_F.dat.vec_wo as vec:
+            vec.array_w = np.asarray(x)
         return x_F
diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py
index 868c4cbb34..22b57501c5 100644
--- a/firedrake/ml/pytorch/fem_operator.py
+++ b/firedrake/ml/pytorch/fem_operator.py
@@ -178,7 +178,8 @@ def to_torch(x, gather=False, batched=True, **kwargs):
             x_P = torch.tensor(np.ravel(x.dat.global_data), **kwargs)
         else:
             # Use local data
-            x_P = torch.tensor(np.ravel(x.dat.data_ro), **kwargs)
+            with x.dat.vec_ro as vec:
+                x_P = torch.tensor(np.ravel(vec.buffer_r), **kwargs)
         if batched:
             # Default behaviour: add batch dimension after converting to PyTorch
             return x_P[None, :]
@@ -218,5 +219,7 @@ def from_torch(x, V=None):
             val = val[0]
         return Constant(val)
     else:
-        x_F = Function(V, val=x.detach().numpy())
+        x_F = Function(V)
+        with x_F.dat.vec_wo as vec:
+            vec.array_w = x.detach().numpy()
         return x_F
diff --git a/tests/firedrake/external_operators/test_jax_operator.py b/tests/firedrake/external_operators/test_jax_operator.py
index 55d2e752d4..4ee66ebb42 100644
--- a/tests/firedrake/external_operators/test_jax_operator.py
+++ b/tests/firedrake/external_operators/test_jax_operator.py
@@ -117,6 +117,40 @@ def test_forward(u, nn):
     assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
 
 
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.skipjax  # Skip if JAX is not installed
+def test_forward_mixed(V, nn):
+
+    W = V * V
+    u = Function(W)
+    u1, u2 = u.subfunctions
+    x, y = SpatialCoordinate(V.mesh())
+    u1.interpolate(sin(pi * x) * sin(pi * y))
+    u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))
+
+    # Set JaxOperator
+    n = W.dim()
+    model = Linear(n, n)
+
+    N = ml_operator(model, function_space=W)(u)
+    # Get model
+    model = N.model
+
+    # Assemble NeuralNet
+    assembled_N = assemble(N)
+    assert isinstance(assembled_N, Function)
+
+    # Convert from Firedrake to JAX
+    x_P = to_jax(u)
+    # Forward pass
+    y_P = model(x_P)
+    # Convert from JAX to Firedrake
+    y_F = from_jax(y_P, u.function_space())
+
+    # Check
+    assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
+
+
 @pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
 @pytest.mark.skipjax  # Skip if JAX is not installed
 def test_jvp(u, nn, rg):
diff --git a/tests/firedrake/external_operators/test_pytorch_operator.py b/tests/firedrake/external_operators/test_pytorch_operator.py
index d9782d1644..7e0ffef7ba 100644
--- a/tests/firedrake/external_operators/test_pytorch_operator.py
+++ b/tests/firedrake/external_operators/test_pytorch_operator.py
@@ -109,6 +109,41 @@ def test_forward(u, nn):
     assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
 
 
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.skiptorch  # Skip if PyTorch is not installed
+def test_forward_mixed(V, nn):
+
+    W = V * V
+    u = Function(W)
+    u1, u2 = u.subfunctions
+    x, y = SpatialCoordinate(V.mesh())
+    u1.interpolate(sin(pi * x) * sin(pi * y))
+    u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))
+
+    # Set PytorchOperator
+    n = W.dim()
+    model = Linear(n, n)
+
+    N = ml_operator(model, function_space=W)(u)
+    # Get model
+    model = N.model
+
+    # Assemble NeuralNet
+    assembled_N = assemble(N)
+
+    assert isinstance(assembled_N, Function)
+
+    # Convert from Firedrake to PyTorch
+    x_P = to_torch(u)
+    # Forward pass
+    y_P = model(x_P)
+    # Convert from PyTorch to Firedrake
+    y_F = from_torch(y_P, u.function_space())
+
+    # Check
+    assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
+
+
 @pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
 @pytest.mark.skiptorch  # Skip if PyTorch is not installed
 def test_jvp(u, nn, rg):
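
Context for reviewers: on a mixed function space, `x.dat` is a PyOP2 `MixedDat` whose `data_ro` is a tuple of per-subspace arrays rather than a single array, so the old `np.ravel(x.dat.data_ro)` could not produce the flat vector the ML model expects. Reading through the PETSc Vec (`vec.buffer_r`) yields one contiguous buffer over all components, and `vec.array_w` writes one back. Below is a minimal sketch of the round trip the new `test_forward_mixed` tests exercise; it is not part of the patch, assumes a serial run with Firedrake's PyTorch support installed, and the mesh/space names are illustrative.

```python
import numpy as np
from firedrake import *
from firedrake.ml.pytorch.fem_operator import to_torch, from_torch

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
W = V * V                          # mixed function space
u = Function(W)
u.subfunctions[0].assign(1.0)
u.subfunctions[1].assign(2.0)

# On a MixedDat, data_ro is a tuple of per-subspace arrays, so
# np.ravel(u.dat.data_ro) cannot yield a flat vector of dofs.
assert isinstance(u.dat.data_ro, tuple)

# The PETSc Vec exposes one contiguous buffer over all components,
# which is what to_torch/to_jax now read from.
with u.dat.vec_ro as vec:
    # copy: np.ravel may return a view into the PETSc buffer,
    # which is only valid inside this context
    flat = np.ravel(vec.buffer_r).copy()
assert flat.size == W.dim()        # local size == global size in serial

# Round trip through the patched converters.
x_P = to_torch(u)                  # shape (1, W.dim()) with default batching
u_back = from_torch(x_P[0], W)     # writes back through vec.array_w
with u_back.dat.vec_ro as vec:
    assert np.allclose(np.ravel(vec.buffer_r), flat)
```

Note that `vec.buffer_r` only exposes the process-local part of the data; collecting global data in parallel remains the job of the `gather=True` branch, which goes through `x.dat.global_data`.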