From 668c623e5a811a4b8a1cfe6d13b11f09339aa313 Mon Sep 17 00:00:00 2001 From: daubners Date: Thu, 18 Sep 2025 16:42:57 +0100 Subject: [PATCH] change to step(t,y) convention --- evoxels/inversion.py | 2 +- evoxels/problem_definition.py | 50 +++++++++++++++++------------------ evoxels/solvers.py | 3 ++- evoxels/timesteppers.py | 12 ++++----- evoxels/utils.py | 10 +++---- tests/test_solvers.py | 2 +- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/evoxels/inversion.py b/evoxels/inversion.py index 2e010cf..dc2bc09 100644 --- a/evoxels/inversion.py +++ b/evoxels/inversion.py @@ -72,7 +72,7 @@ def solve(self, parameters, y0, saveat, adjoint=dfx.ForwardMode(), dt0=0.1): solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol) solution = dfx.diffeqsolve( - dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)), + dfx.ODETerm(lambda t, y, args: problem.rhs(t, y)), solver, t0=saveat.subs.ts[0], t1=saveat.subs.ts[-1], diff --git a/evoxels/problem_definition.py b/evoxels/problem_definition.py index 133c170..59a4f02 100644 --- a/evoxels/problem_definition.py +++ b/evoxels/problem_definition.py @@ -18,12 +18,12 @@ def order(self): pass @abstractmethod - def rhs_analytic(self, u, t): + def rhs_analytic(self, t, u): """Sympy expression of the problem right-hand side. Args: - u : Sympy function of current state. t (float): Current time. + u : Sympy function of current state. Returns: Sympy function of problem right-hand side. @@ -31,12 +31,12 @@ def rhs_analytic(self, u, t): pass @abstractmethod - def rhs(self, u, t): + def rhs(self, t, u): """Numerical right-hand side of the ODE system. Args: - u (array): Current state. t (float): Current time. + u (array): Current state. Returns: Same type as ``u`` containing the time derivative. @@ -122,7 +122,7 @@ def __post_init__(self): k_squared = self.vg.fft_k_squared_nonperiodic() self._fourier_symbol = -self.D * self.A * k_squared - + @property def order(self): return 2 @@ -130,14 +130,14 @@ def order(self): @property def fourier_symbol(self): return self._fourier_symbol - - def _eval_f(self, c, t, lib): + + def _eval_f(self, t, c, lib): """Evaluate source/forcing term using ``self.f``.""" try: - return self.f(c, t, lib) + return self.f(t, c, lib) except TypeError: - return self.f(c, t) - + return self.f(t, c) + @property def bc_type(self): return self.BC_type @@ -145,12 +145,12 @@ def bc_type(self): def pad_bc(self, u): return self.pad_boundary(u, self.bcs[0], self.bcs[1]) - def rhs_analytic(self, u, t): - return self.D*spv.laplacian(u) + self._eval_f(u, t, sp) - - def rhs(self, u, t): + def rhs_analytic(self, t, u): + return self.D*spv.laplacian(u) + self._eval_f(t, u, sp) + + def rhs(self, t, u): laplace = self.vg.laplace(self.pad_bc(u)) - update = self.D * laplace + self._eval_f(u, t, self.vg.lib) + update = self.D * laplace + self._eval_f(t, u, self.vg.lib) return update @dataclass @@ -185,15 +185,15 @@ def __post_init__(self): def pad_bc(self, u): return self.pad_boundary(u, self.bcs[0], self.bcs[1]) - def rhs_analytic(self, mask, u, t): + def rhs_analytic(self, t, u, mask): grad_m = spv.gradient(mask) norm_grad_m = sp.sqrt(grad_m.dot(grad_m)) divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_m)) - du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(u/mask, t, sp) + du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(t, u/mask, sp) return du - def rhs(self, u, t): + def rhs(self, t, u): z = self.pad_bc(u) divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -\ self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask) @@ -207,7 +207,7 @@ def rhs(self, u, t): update = self.D * divergence + \ self.norm*self.bc_flux + \ - self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1], t, self.vg.lib) + self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(t, u/self.mask[:,1:-1,1:-1,1:-1], self.vg.lib) return update @@ -249,13 +249,13 @@ def _eval_mu(self, c, lib): except TypeError: return self.mu_hom(c) - def rhs_analytic(self, c, t): + def rhs_analytic(self, t, c): mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c) fluxes = self.D*c*(1-c)*spv.gradient(mu) rhs = spv.divergence(fluxes) return rhs - def rhs(self, c, t): + def rhs(self, t, c): r"""Evaluate :math:`\partial c / \partial t` for the CH equation. Numerical computation of @@ -341,7 +341,7 @@ def _eval_potential(self, phi, lib): except TypeError: return self.potential(phi) - def rhs_analytic(self, phi, t): + def rhs_analytic(self, t, phi): grad = spv.gradient(phi) laplace = spv.laplacian(phi) norm_grad = sp.sqrt(grad.dot(grad)) @@ -354,7 +354,7 @@ def rhs_analytic(self, phi, t): + 3/self.eps * phi * (1-phi) * self.force return self.M * df_dphi - def rhs(self, phi, t): + def rhs(self, t, phi): r"""Two-phase Allen-Cahn equation Microstructural evolution of the order parameter ``\phi`` @@ -422,13 +422,13 @@ def _eval_interaction(self, u, lib): except TypeError: return self.interaction(u) - def rhs_analytic(self, u, t): + def rhs_analytic(self, t, u): interaction = self._eval_interaction(u, sp) dc_A = self.D_A*spv.laplacian(u[0]) - interaction + self.feed * (1-u[0]) dc_B = self.D_B*spv.laplacian(u[1]) + interaction - self.kill * u[1] return (dc_A, dc_B) - def rhs(self, u, t): + def rhs(self, t, u): r"""Two-component reaction-diffusion system Use batch channels for multiple species: diff --git a/evoxels/solvers.py b/evoxels/solvers.py index feb7f21..11b6890 100644 --- a/evoxels/solvers.py +++ b/evoxels/solvers.py @@ -99,7 +99,7 @@ def solve( self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap) frame += 1 - u = step(u, time) + u = step(time, u) end = timer() time = max_iters * time_increment @@ -129,6 +129,7 @@ def _handle_outputs(self, u, frame, time, slice_idx, vtk_out, verbose, plot_boun filename = self.problem_cls.__name__ + "_" +\ self.fieldnames[0] + f"_{frame:03d}.vtk" self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames) + if verbose == 'plot': clear_output(wait=True) self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds) diff --git a/evoxels/timesteppers.py b/evoxels/timesteppers.py index f921628..11e4807 100644 --- a/evoxels/timesteppers.py +++ b/evoxels/timesteppers.py @@ -16,13 +16,13 @@ def order(self) -> int: pass @abstractmethod - def step(self, u: State, t: float) -> State: + def step(self, t: float, u: State) -> State: """ Take one timestep from t to (t+dt). Args: - u : Current state t : Current time + u : Current state Returns: Updated state at t + dt. """ @@ -39,8 +39,8 @@ class ForwardEuler(TimeStepper): def order(self) -> int: return 1 - def step(self, u: State, t: float) -> State: - return u + self.dt * self.problem.rhs(u, t) + def step(self, t: float, u: State) -> State: + return u + self.dt * self.problem.rhs(t, u) @dataclass @@ -68,8 +68,8 @@ def __post_init__(self): def order(self) -> int: return 1 - def step(self, u: State, t: float) -> State: - dc = self.pad(self.problem.rhs(u, t)) + def step(self, t: float, u: State) -> State: + dc = self.pad(self.problem.rhs(t, u)) dc_fft = self._fft_prefac * self.problem.vg.rfftn(dc, dc.shape) update = self.problem.vg.irfftn(dc_fft, dc.shape)[:,:u.shape[1]] return u + update diff --git a/evoxels/utils.py b/evoxels/utils.py index bf6e89b..930e30a 100644 --- a/evoxels/utils.py +++ b/evoxels/utils.py @@ -88,16 +88,16 @@ def rhs_convergence_test( problem_kwargs["mask"] = mask(*grid) ODE = ODE_class(vg, **problem_kwargs) - rhs_numeric = ODE.rhs(u, 0) + rhs_numeric = ODE.rhs(0, u) if n_funcs > 1 and mask is not None: - rhs_analytic = ODE.rhs_analytic(mask_function, test_functions, 0) + rhs_analytic = ODE.rhs_analytic(0, test_functions, mask_function) elif n_funcs > 1 and mask is None: - rhs_analytic = ODE.rhs_analytic(test_functions, 0) + rhs_analytic = ODE.rhs_analytic(0, test_functions) elif n_funcs == 1 and mask is not None: - rhs_analytic = [ODE.rhs_analytic(mask_function, test_functions[0], 0)] + rhs_analytic = [ODE.rhs_analytic(0, test_functions[0], mask_function)] else: - rhs_analytic = [ODE.rhs_analytic(test_functions[0], 0)] + rhs_analytic = [ODE.rhs_analytic(0, test_functions[0])] # Compute solutions for j, func in enumerate(test_functions): diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 36622f4..13c7f05 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -14,7 +14,7 @@ def test_time_solver_multiple_fields(): vf.add_field("a", np.ones(vf.shape)) vf.add_field("b", np.zeros(vf.shape)) - def step(u, t): + def step(t, u): return u + 1 solver = TimeDependentSolver(vf, ["a", "b"], backend="torch", step_fn=step, device="cpu")