Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion evoxels/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def solve(self, parameters, y0, saveat, adjoint=dfx.ForwardMode(), dt0=0.1):
solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)

solution = dfx.diffeqsolve(
dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)),
dfx.ODETerm(lambda t, y, args: problem.rhs(t, y)),
solver,
t0=saveat.subs.ts[0],
t1=saveat.subs.ts[-1],
Expand Down
50 changes: 25 additions & 25 deletions evoxels/problem_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@ def order(self):
pass

@abstractmethod
def rhs_analytic(self, u, t):
def rhs_analytic(self, t, u):
"""Sympy expression of the problem right-hand side.

Args:
u : Sympy function of current state.
t (float): Current time.
u : Sympy function of current state.

Returns:
Sympy function of problem right-hand side.
"""
pass

@abstractmethod
def rhs(self, u, t):
def rhs(self, t, u):
"""Numerical right-hand side of the ODE system.

Args:
u (array): Current state.
t (float): Current time.
u (array): Current state.

Returns:
Same type as ``u`` containing the time derivative.
Expand Down Expand Up @@ -122,35 +122,35 @@ def __post_init__(self):
k_squared = self.vg.fft_k_squared_nonperiodic()

self._fourier_symbol = -self.D * self.A * k_squared

@property
def order(self):
return 2

@property
def fourier_symbol(self):
return self._fourier_symbol
def _eval_f(self, c, t, lib):

def _eval_f(self, t, c, lib):
"""Evaluate source/forcing term using ``self.f``."""
try:
return self.f(c, t, lib)
return self.f(t, c, lib)
except TypeError:
return self.f(c, t)
return self.f(t, c)

@property
def bc_type(self):
return self.BC_type

def pad_bc(self, u):
return self.pad_boundary(u, self.bcs[0], self.bcs[1])

def rhs_analytic(self, u, t):
return self.D*spv.laplacian(u) + self._eval_f(u, t, sp)
def rhs(self, u, t):
def rhs_analytic(self, t, u):
return self.D*spv.laplacian(u) + self._eval_f(t, u, sp)

def rhs(self, t, u):
laplace = self.vg.laplace(self.pad_bc(u))
update = self.D * laplace + self._eval_f(u, t, self.vg.lib)
update = self.D * laplace + self._eval_f(t, u, self.vg.lib)
return update

@dataclass
Expand Down Expand Up @@ -185,15 +185,15 @@ def __post_init__(self):
def pad_bc(self, u):
return self.pad_boundary(u, self.bcs[0], self.bcs[1])

def rhs_analytic(self, mask, u, t):
def rhs_analytic(self, t, u, mask):
grad_m = spv.gradient(mask)
norm_grad_m = sp.sqrt(grad_m.dot(grad_m))

divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_m))
du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(u/mask, t, sp)
du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(t, u/mask, sp)
return du

def rhs(self, u, t):
def rhs(self, t, u):
z = self.pad_bc(u)
divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -\
self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask)
Expand All @@ -207,7 +207,7 @@ def rhs(self, u, t):

update = self.D * divergence + \
self.norm*self.bc_flux + \
self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1], t, self.vg.lib)
self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(t, u/self.mask[:,1:-1,1:-1,1:-1], self.vg.lib)
return update


Expand Down Expand Up @@ -249,13 +249,13 @@ def _eval_mu(self, c, lib):
except TypeError:
return self.mu_hom(c)

def rhs_analytic(self, c, t):
def rhs_analytic(self, t, c):
mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
fluxes = self.D*c*(1-c)*spv.gradient(mu)
rhs = spv.divergence(fluxes)
return rhs

def rhs(self, c, t):
def rhs(self, t, c):
r"""Evaluate :math:`\partial c / \partial t` for the CH equation.

Numerical computation of
Expand Down Expand Up @@ -341,7 +341,7 @@ def _eval_potential(self, phi, lib):
except TypeError:
return self.potential(phi)

def rhs_analytic(self, phi, t):
def rhs_analytic(self, t, phi):
grad = spv.gradient(phi)
laplace = spv.laplacian(phi)
norm_grad = sp.sqrt(grad.dot(grad))
Expand All @@ -354,7 +354,7 @@ def rhs_analytic(self, phi, t):
+ 3/self.eps * phi * (1-phi) * self.force
return self.M * df_dphi

def rhs(self, phi, t):
def rhs(self, t, phi):
r"""Two-phase Allen-Cahn equation

Microstructural evolution of the order parameter ``\phi``
Expand Down Expand Up @@ -422,13 +422,13 @@ def _eval_interaction(self, u, lib):
except TypeError:
return self.interaction(u)

def rhs_analytic(self, u, t):
def rhs_analytic(self, t, u):
interaction = self._eval_interaction(u, sp)
dc_A = self.D_A*spv.laplacian(u[0]) - interaction + self.feed * (1-u[0])
dc_B = self.D_B*spv.laplacian(u[1]) + interaction - self.kill * u[1]
return (dc_A, dc_B)

def rhs(self, u, t):
def rhs(self, t, u):
r"""Two-component reaction-diffusion system

Use batch channels for multiple species:
Expand Down
3 changes: 2 additions & 1 deletion evoxels/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def solve(
self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
frame += 1

u = step(u, time)
u = step(time, u)

end = timer()
time = max_iters * time_increment
Expand Down Expand Up @@ -129,6 +129,7 @@ def _handle_outputs(self, u, frame, time, slice_idx, vtk_out, verbose, plot_boun
filename = self.problem_cls.__name__ + "_" +\
self.fieldnames[0] + f"_{frame:03d}.vtk"
self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)

if verbose == 'plot':
clear_output(wait=True)
self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)
12 changes: 6 additions & 6 deletions evoxels/timesteppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def order(self) -> int:
pass

@abstractmethod
def step(self, u: State, t: float) -> State:
def step(self, t: float, u: State) -> State:
"""
Take one timestep from t to (t+dt).

Args:
u : Current state
t : Current time
u : Current state
Returns:
Updated state at t + dt.
"""
Expand All @@ -39,8 +39,8 @@ class ForwardEuler(TimeStepper):
def order(self) -> int:
return 1

def step(self, u: State, t: float) -> State:
return u + self.dt * self.problem.rhs(u, t)
def step(self, t: float, u: State) -> State:
return u + self.dt * self.problem.rhs(t, u)


@dataclass
Expand Down Expand Up @@ -68,8 +68,8 @@ def __post_init__(self):
def order(self) -> int:
return 1

def step(self, u: State, t: float) -> State:
dc = self.pad(self.problem.rhs(u, t))
def step(self, t: float, u: State) -> State:
dc = self.pad(self.problem.rhs(t, u))
dc_fft = self._fft_prefac * self.problem.vg.rfftn(dc, dc.shape)
update = self.problem.vg.irfftn(dc_fft, dc.shape)[:,:u.shape[1]]
return u + update
Expand Down
10 changes: 5 additions & 5 deletions evoxels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ def rhs_convergence_test(
problem_kwargs["mask"] = mask(*grid)

ODE = ODE_class(vg, **problem_kwargs)
rhs_numeric = ODE.rhs(u, 0)
rhs_numeric = ODE.rhs(0, u)

if n_funcs > 1 and mask is not None:
rhs_analytic = ODE.rhs_analytic(mask_function, test_functions, 0)
rhs_analytic = ODE.rhs_analytic(0, test_functions, mask_function)
elif n_funcs > 1 and mask is None:
rhs_analytic = ODE.rhs_analytic(test_functions, 0)
rhs_analytic = ODE.rhs_analytic(0, test_functions)
elif n_funcs == 1 and mask is not None:
rhs_analytic = [ODE.rhs_analytic(mask_function, test_functions[0], 0)]
rhs_analytic = [ODE.rhs_analytic(0, test_functions[0], mask_function)]
else:
rhs_analytic = [ODE.rhs_analytic(test_functions[0], 0)]
rhs_analytic = [ODE.rhs_analytic(0, test_functions[0])]

# Compute solutions
for j, func in enumerate(test_functions):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_time_solver_multiple_fields():
vf.add_field("a", np.ones(vf.shape))
vf.add_field("b", np.zeros(vf.shape))

def step(u, t):
def step(t, u):
return u + 1

solver = TimeDependentSolver(vf, ["a", "b"], backend="torch", step_fn=step, device="cpu")
Expand Down