Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions firedrake/ml/jax/fem_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def to_jax(x: Union[Function, Constant], gather: Optional[bool] = False, batched
x_P = jnp.array(np.ravel(x.dat.global_data), **kwargs)
else:
# Use local data
x_P = jnp.array(np.ravel(x.dat.data_ro), **kwargs)
with x.dat.vec_ro as vec:
x_P = jnp.array(np.ravel(vec.buffer_r), **kwargs)
if batched:
# Default behaviour: add batch dimension after converting to JAX
return x_P[None, :]
Expand Down Expand Up @@ -222,5 +223,7 @@ def from_jax(x: "jax.Array", V: Optional[WithGeometry] = None) -> Union[Function
val = val[0]
return Constant(val)
else:
x_F = Function(V, val=np.asarray(x))
x_F = Function(V)
with x_F.dat.vec_wo as vec:
vec.array_w = np.asarray(x)
return x_F
7 changes: 5 additions & 2 deletions firedrake/ml/pytorch/fem_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def to_torch(x, gather=False, batched=True, **kwargs):
x_P = torch.tensor(np.ravel(x.dat.global_data), **kwargs)
else:
# Use local data
x_P = torch.tensor(np.ravel(x.dat.data_ro), **kwargs)
with x.dat.vec_ro as vec:
x_P = torch.tensor(np.ravel(vec.buffer_r), **kwargs)
if batched:
# Default behaviour: add batch dimension after converting to PyTorch
return x_P[None, :]
Expand Down Expand Up @@ -218,5 +219,7 @@ def from_torch(x, V=None):
val = val[0]
return Constant(val)
else:
x_F = Function(V, val=x.detach().numpy())
x_F = Function(V)
with x_F.dat.vec_wo as vec:
vec.array_w = x.detach().numpy()
return x_F
34 changes: 34 additions & 0 deletions tests/firedrake/external_operators/test_jax_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,40 @@ def test_forward(u, nn):
assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
@pytest.mark.skipjax # Skip if JAX is not installed
def test_forward_mixed(V, nn):

W = V * V
u = Function(W)
u1, u2 = u.subfunctions
x, y = SpatialCoordinate(V.mesh())
u1.interpolate(sin(pi * x) * sin(pi * y))
u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))

# Set JaxOperator
n = W.dim()
model = Linear(n, n)

N = ml_operator(model, function_space=W)(u)
# Get model
model = N.model

# Assemble NeuralNet
assembled_N = assemble(N)
assert isinstance(assembled_N, Function)

# Convert from Firedrake to JAX
x_P = to_jax(u)
# Forward pass
y_P = model(x_P)
# Convert from JAX to Firedrake
y_F = from_jax(y_P, u.function_space())

# Check
assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
@pytest.mark.skipjax # Skip if JAX is not installed
def test_jvp(u, nn, rg):
Expand Down
35 changes: 35 additions & 0 deletions tests/firedrake/external_operators/test_pytorch_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,41 @@ def test_forward(u, nn):
assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
@pytest.mark.skiptorch # Skip if PyTorch is not installed
def test_forward_mixed(V, nn):

W = V * V
u = Function(W)
u1, u2 = u.subfunctions
x, y = SpatialCoordinate(V.mesh())
u1.interpolate(sin(pi * x) * sin(pi * y))
u2.interpolate(sin(2 * pi * x) * sin(2 * pi * y))

# Set PytorchOperator
n = W.dim()
model = Linear(n, n)

N = ml_operator(model, function_space=W)(u)
# Get model
model = N.model

# Assemble NeuralNet
assembled_N = assemble(N)

assert isinstance(assembled_N, Function)

# Convert from Firedrake to PyTorch
x_P = to_torch(u)
# Forward pass
y_P = model(x_P)
# Convert from PyTorch to Firedrake
y_F = from_torch(y_P, u.function_space())

# Check
assert np.allclose(y_F.dat.data_ro, assembled_N.dat.data_ro)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
@pytest.mark.skiptorch # Skip if PyTorch is not installed
def test_jvp(u, nn, rg):
Expand Down
Loading