From a62324d54c62c40427d2fb3e85b41142a77c65d4 Mon Sep 17 00:00:00 2001
From: jshipton
Date: Fri, 19 Dec 2025 18:39:08 +0000
Subject: [PATCH 1/4] fix from_torch and to_torch to work with mixed function
 spaces

---
 firedrake/ml/pytorch/fem_operator.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py
index 868c4cbb34..22b57501c5 100644
--- a/firedrake/ml/pytorch/fem_operator.py
+++ b/firedrake/ml/pytorch/fem_operator.py
@@ -178,7 +178,8 @@ def to_torch(x, gather=False, batched=True, **kwargs):
         x_P = torch.tensor(np.ravel(x.dat.global_data), **kwargs)
     else:
         # Use local data
-        x_P = torch.tensor(np.ravel(x.dat.data_ro), **kwargs)
+        with x.dat.vec_ro as vec:
+            x_P = torch.tensor(np.ravel(vec.buffer_r), **kwargs)
     if batched:
         # Default behaviour: add batch dimension after converting to PyTorch
         return x_P[None, :]
@@ -218,5 +219,7 @@ def from_torch(x, V=None):
             val = val[0]
         return Constant(val)
     else:
-        x_F = Function(V, val=x.detach().numpy())
+        x_F = Function(V)
+        with x_F.dat.vec_wo as vec:
+            vec.array_w = x.detach().numpy()
         return x_F

From 6da886ca95e2b92bdbd16e482041178eeb06e5c9 Mon Sep 17 00:00:00 2001
From: jshipton
Date: Sun, 21 Dec 2025 18:52:44 +0000
Subject: [PATCH 2/4] add mixed function space test for the PyTorch operator

---
 .../test_pytorch_operator.py | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/firedrake/external_operators/test_pytorch_operator.py b/tests/firedrake/external_operators/test_pytorch_operator.py
index d9782d1644..7e0ffef7ba 100644
--- a/tests/firedrake/external_operators/test_pytorch_operator.py
+++ b/tests/firedrake/external_operators/test_pytorch_operator.py
@@ -109,6 +109,41 @@ def test_forward(u, nn):
     assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
 
 
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.skiptorch  # Skip if PyTorch is not installed
+def test_forward_mixed(V, nn):
+
+    W = V * V
+    u = Function(W)
+    u1, u2 = u.subfunctions
+    x, y = SpatialCoordinate(V.mesh())
+    u1.interpolate(sin(pi * x) * sin(pi * y))
+    u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))
+
+    # Set PytorchOperator
+    n = W.dim()
+    model = Linear(n, n)
+
+    N = ml_operator(model, function_space=W)(u)
+    # Get model
+    model = N.model
+
+    # Assemble NeuralNet
+    assembled_N = assemble(N)
+
+    assert isinstance(assembled_N, Function)
+
+    # Convert from Firedrake to PyTorch
+    x_P = to_torch(u)
+    # Forward pass
+    y_P = model(x_P)
+    # Convert from PyTorch to Firedrake
+    y_F = from_torch(y_P, u.function_space())
+
+    # Check
+    assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
+
+
 @pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
 @pytest.mark.skiptorch  # Skip if PyTorch is not installed
 def test_jvp(u, nn, rg):
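Why the Vec context managers are needed (the JAX patches below mirror the same change): for a function on a mixed function space, x.dat.data_ro is a tuple containing one array per subspace rather than a single flat array, so the old np.ravel(x.dat.data_ro) path fails outright when the subspaces have different shapes and, in parallel, need not match the ordering of the function's PETSc vector. The vec_ro/vec_wo context managers instead expose one contiguous PETSc Vec over all components. A minimal sketch of the difference (the mesh and function space here are illustrative placeholders, not taken from the patches):

    import numpy as np
    from firedrake import *

    mesh = UnitSquareMesh(4, 4)
    V = FunctionSpace(mesh, "CG", 1)
    W = V * V          # mixed function space
    f = Function(W)

    # On a mixed space, data_ro is a tuple of arrays, one per subspace,
    # so it is not a single flat dof vector.
    print(type(f.dat.data_ro))          # tuple

    # The PETSc Vec context managers give one contiguous buffer over all
    # components; this is what the patched to_torch/from_torch (and
    # to_jax/from_jax) read and write.
    with f.dat.vec_ro as vec:
        flat = np.ravel(vec.buffer_r)   # read-only flat view of all dofs
    with f.dat.vec_wo as vec:
        vec.array_w = np.zeros(vec.local_size)  # write a flat array back
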
From 4eb682890dd31549cdd80cbc9a64784c59154bc8 Mon Sep 17 00:00:00 2001
From: jshipton
Date: Sun, 21 Dec 2025 18:57:16 +0000
Subject: [PATCH 3/4] fix from_jax and to_jax for mixed functions

---
 firedrake/ml/jax/fem_operator.py            |  7 ++--
 .../external_operators/test_jax_operator.py | 34 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/firedrake/ml/jax/fem_operator.py b/firedrake/ml/jax/fem_operator.py
index 94f2875202..cee0d23a86 100644
--- a/firedrake/ml/jax/fem_operator.py
+++ b/firedrake/ml/jax/fem_operator.py
@@ -173,7 +173,8 @@ def to_jax(x: Union[Function, Constant], gather: Optional[bool] = False, batched
         x_P = jnp.array(np.ravel(x.dat.global_data), **kwargs)
     else:
         # Use local data
-        x_P = jnp.array(np.ravel(x.dat.data_ro), **kwargs)
+        with x.dat.vec_ro as vec:
+            x_P = jnp.array(np.ravel(vec.buffer_r), **kwargs)
     if batched:
         # Default behaviour: add batch dimension after converting to JAX
         return x_P[None, :]
@@ -222,5 +223,7 @@ def from_jax(x: "jax.Array", V: Optional[WithGeometry] = None) -> Union[Function
             val = val[0]
         return Constant(val)
     else:
-        x_F = Function(V, val=np.asarray(x))
+        x_F = Function(V)
+        with x_F.dat.vec_wo as vec:
+            vec.array_w = np.asarray(x)
         return x_F
diff --git a/tests/firedrake/external_operators/test_jax_operator.py b/tests/firedrake/external_operators/test_jax_operator.py
index 55d2e752d4..51ca0ced85 100644
--- a/tests/firedrake/external_operators/test_jax_operator.py
+++ b/tests/firedrake/external_operators/test_jax_operator.py
@@ -117,6 +117,40 @@ def test_forward(u, nn):
     assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
 
 
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.skiptorch  # Skip if PyTorch is not installed
+def test_forward_mixed(V, nn):
+
+    W = V * V
+    u = Function(W)
+    u1, u2 = u.subfunctions
+    x, y = SpatialCoordinate(V.mesh())
+    u1.interpolate(sin(pi * x) * sin(pi * y))
+    u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))
+
+    # Set JaxOperator
+    n = W.dim()
+    model = Linear(n, n)
+
+    N = ml_operator(model, function_space=W)(u)
+    # Get model
+    model = N.model
+
+    # Assemble NeuralNet
+    assembled_N = assemble(N)
+    assert isinstance(assembled_N, Function)
+
+    # Convert from Firedrake to JAX
+    x_P = to_jax(u)
+    # Forward pass
+    y_P = model(x_P)
+    # Convert from JAX to Firedrake
+    y_F = from_jax(y_P, u.function_space())
+
+    # Check
+    assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)
+
+
 @pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
 @pytest.mark.skipjax  # Skip if JAX is not installed
 def test_jvp(u, nn, rg):

From 82d9a6c5a64ef41486b6d26bfadf6a36ca8bc8b3 Mon Sep 17 00:00:00 2001
From: Dr Jemma Shipton
Date: Tue, 6 Jan 2026 17:09:44 +0000
Subject: [PATCH 4/4] Update tests/firedrake/external_operators/test_jax_operator.py

Co-authored-by: David A. Ham
---
 tests/firedrake/external_operators/test_jax_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/firedrake/external_operators/test_jax_operator.py b/tests/firedrake/external_operators/test_jax_operator.py
index 51ca0ced85..4ee66ebb42 100644
--- a/tests/firedrake/external_operators/test_jax_operator.py
+++ b/tests/firedrake/external_operators/test_jax_operator.py
@@ -118,7 +118,7 @@ def test_forward(u, nn):
 
 
 @pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
-@pytest.mark.skiptorch  # Skip if PyTorch is not installed
+@pytest.mark.skipjax  # Skip if JAX is not installed
 def test_forward_mixed(V, nn):
 
     W = V * V
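With the series applied, the conversion round-trips for mixed functions in the way the new tests exercise. A usage sketch along those lines (assuming to_torch and from_torch are importable from firedrake.ml.pytorch, as in the test modules; the mesh and spaces are again illustrative placeholders):

    import numpy as np
    from firedrake import *
    from firedrake.ml.pytorch import to_torch, from_torch

    mesh = UnitSquareMesh(4, 4)
    V = FunctionSpace(mesh, "CG", 1)
    W = V * V
    u = Function(W)
    u1, u2 = u.subfunctions
    x, y = SpatialCoordinate(mesh)
    u1.interpolate(sin(pi * x) * sin(pi * y))
    u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))

    # Firedrake -> PyTorch: one flat tensor over both components
    # (shape (1, W.dim()) in serial, with the default batched=True).
    t = to_torch(u)

    # PyTorch -> Firedrake: both subfunctions are restored.
    v = from_torch(t, W)
    assert np.allclose(v.dat.data_ro[0], u.dat.data_ro[0])
    assert np.allclose(v.dat.data_ro[1], u.dat.data_ro[1])

The same round trip holds for to_jax/from_jax in firedrake.ml.jax, since both backends now share the vec_ro/vec_wo-based conversion.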