diff --git a/src/festim/__init__.py b/src/festim/__init__.py index 3dee1e6f5..da68df1ab 100644 --- a/src/festim/__init__.py +++ b/src/festim/__init__.py @@ -40,6 +40,7 @@ ExportBaseClass, VTXSpeciesExport, VTXTemperatureExport, + ReactionRate, CustomFieldExport, ) from .exports.xdmf import XDMFExport diff --git a/src/festim/exports/__init__.py b/src/festim/exports/__init__.py index e18892474..3cab0f300 100644 --- a/src/festim/exports/__init__.py +++ b/src/festim/exports/__init__.py @@ -14,6 +14,7 @@ ExportBaseClass, VTXSpeciesExport, VTXTemperatureExport, + ReactionRate, CustomFieldExport, ) from .xdmf import XDMFExport @@ -28,6 +29,7 @@ "MinimumSurface", "MinimumVolume", "Profile1DExport", + "ReactionRate", "SurfaceFlux", "SurfaceQuantity", "TotalSurface", diff --git a/src/festim/exports/vtx.py b/src/festim/exports/vtx.py index c64cb35a7..7b1e096d5 100644 --- a/src/festim/exports/vtx.py +++ b/src/festim/exports/vtx.py @@ -8,8 +8,10 @@ from dolfinx import fem, io from festim.helpers import get_interpolation_points +from festim import k_B as _k_B from festim.species import Species, ImplicitSpecies from festim.subdomain.volume_subdomain import VolumeSubdomain +from festim.reaction import Reaction class ExportBaseClass: @@ -235,6 +237,25 @@ def __init__( self.checkpoint = checkpoint self.subdomain = subdomain + @property + def mixed_domain(self) -> bool: + """ + Check if we are in a mixed domain/discontinuous case. This is the case if at least + one of the species in species_dependent_value is defined on a subdomain or if the + custom field is defined on a subdomain. + + Returns: + True if we are in a mixed domain/discontinuous case, False otherwise. + """ + all_explicit_species = [ + spe + for spe in self.species_dependent_value.values() + if isinstance(spe, Species) + ] + return any( + spe.subdomain_to_post_processing_solution for spe in all_explicit_species + ) or (self.subdomain.sub_T if self.subdomain else None) + def set_dolfinx_expression( self, temperature: fem.Constant | fem.Function, @@ -249,12 +270,6 @@ def set_dolfinx_expression( temperature: The temperature field to use in the expression time: The time to use in the expression """ - # check if we are in a mixed domain/discontinuous case - mixed_domain = any( - spe.subdomain_to_post_processing_solution - for spe in self.species_dependent_value.values() - ) or (self.subdomain.sub_T if self.subdomain else None) - # get the arguments of the user-provided expression arguments = inspect.signature(self.expression).parameters @@ -266,42 +281,46 @@ def set_dolfinx_expression( x = ufl.SpatialCoordinate(self.function.function_space.mesh) kwargs["x"] = x if "T" in arguments: - if isinstance(temperature, fem.Function) and mixed_domain: + if isinstance(temperature, fem.Function) and self.mixed_domain: # fem.Function in mixed domain/discontinuous case, use sub_T # NOTE I'm not sure that sub_T is updated at every time step kwargs["T"] = self.subdomain.sub_T else: # else use the provided temperature kwargs["T"] = temperature + # check if there are other arguments and if they are in species_dependent_value for arg in arguments: if arg in self.species_dependent_value: - spe = self.species_dependent_value[arg] - if isinstance(spe, ImplicitSpecies): - raise NotImplementedError( - "Custom fields depending on implicit species are not" - "implemented yet." - ) - if mixed_domain: - kwargs[arg] = spe.subdomain_to_post_processing_solution[ - self.subdomain - ] - else: - kwargs[arg] = spe.post_processing_solution + kwargs[arg] = self._get_species_function( + self.species_dependent_value[arg] + ) assert kwargs[arg] is not None, ( f"Argument {arg} not found in species_dependent_value" ) - self.check_valid_inputs(kwargs, mixed_domain) + self.check_valid_inputs(kwargs) - # evaluate the user-provided expression with the appropriate arguments and create a - # dolfinx.fem.Expression + # evaluate the user-provided expression with the appropriate arguments and + # create a dolfinx.fem.Expression self.dolfinx_expression = fem.Expression( self.expression(**kwargs), get_interpolation_points(self.function.function_space.element), ) - def check_valid_inputs(self, kwargs: dict, mixed_domain: bool): + def _get_species_function(self, spe: Species): + if isinstance(spe, ImplicitSpecies): + if self.mixed_domain: + return spe.concentration_submesh(self.subdomain) + else: + return spe.concentration + else: + if self.mixed_domain: + return spe.subdomain_to_post_processing_solution[self.subdomain] + else: + return spe.post_processing_solution + + def check_valid_inputs(self, kwargs: dict): """ Check if we are in the mixed domain/discontinuous case and if the user-provided expression is valid in this case. @@ -315,7 +334,7 @@ def check_valid_inputs(self, kwargs: dict, mixed_domain: bool): # check the domain of all kwargs and check that they are the same - if mixed_domain and "t" in kwargs: + if self.mixed_domain and "t" in kwargs: raise NotImplementedError( "Time-dependent custom fields are not implemented in the case of a " "mixed domain/discontinuous case." @@ -323,3 +342,93 @@ def check_valid_inputs(self, kwargs: dict, mixed_domain: bool): "defined on the parent mesh." "See https://github.com/FEniCS/dolfinx/issues/3207 for more details." ) + + +class ReactionRate(CustomFieldExport): + def __init__( + self, + reaction: Reaction, + filename: str | Path, + direction: str = "both", + times: list[float] | None = None, + subdomain: VolumeSubdomain | None = None, + checkpoint: bool = False, + ): + + reactant_names = [reactant.name for reactant in reaction.reactant] + if isinstance(reaction.product, list): + product_names = [product.name for product in reaction.product] + else: + product_names = [reaction.product.name] + + def expression(T, **kwargs): + _reactant_names = [kwargs[name] for name in reactant_names] + _product_names = [kwargs[name] for name in product_names] + k = reaction.k_0 * ufl.exp(-reaction.E_k / (_k_B * T)) + if reaction.p_0 and reaction.E_p: + p = reaction.p_0 * ufl.exp(-reaction.E_p / (_k_B * T)) + elif reaction.p_0: + p = reaction.p_0 + else: + p = 0.0 + + forward = k * ufl.product(_reactant_names) + backward = p * ufl.product(_product_names) + + if direction == "forward": + return forward + elif direction == "backward": + return backward + else: + return forward - backward + + self.override_signature(expression, reactant_names, product_names) + + reaction_products = ( + reaction.product + if isinstance(reaction.product, list) + else [reaction.product] + ) + + super().__init__( + filename=filename, + expression=expression, + species_dependent_value={ + spe.name: spe for spe in reaction.reactant + reaction_products + }, + times=times, + subdomain=subdomain, + checkpoint=checkpoint, + ) + + def override_signature( + self, expression: Callable, reactant_names: list[str], product_names: list[str] + ): + """ + Override the signature of the expression function. This is needed to ensure that + the expression has the correct arguments for set_dolfinx_expression(). + + Args: + expression: The user-provided expression for the reaction rate. The arguments + of the expression must be T (temperature) and the names of the reactants + and products. + """ + sig_params = [inspect.Parameter("T", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + # Use dict.fromkeys to preserve order and remove duplicates + for name in dict.fromkeys(reactant_names + product_names): + sig_params.append( + inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ) + expression.__signature__ = inspect.Signature(sig_params) + + assert inspect.signature(expression).parameters.keys() == { + "T", + *reactant_names, + *product_names, + }, ( + "The expression for the reaction rate is automatically generated based on the " + "reaction provided. The arguments of the expression must be T (temperature) and " + "the names of the reactants and products. The current expression has arguments " + f"{inspect.signature(expression).parameters.keys()} but should have arguments " + f"T and {reactant_names + product_names}." + ) diff --git a/src/festim/hydrogen_transport_problem.py b/src/festim/hydrogen_transport_problem.py index 18e9b3d1d..a3a4455f0 100644 --- a/src/festim/hydrogen_transport_problem.py +++ b/src/festim/hydrogen_transport_problem.py @@ -1482,9 +1482,6 @@ def create_subdomain_formulation(self, subdomain: _subdomain.VolumeSubdomain): if reaction.volume != subdomain: continue - # TODO remove - # temporarily overide the solution to the one of the subdomain - self.override_solution_attributes(reaction) # reactant for reactant in reaction.reactant: if isinstance(reactant, _species.Species): @@ -1607,28 +1604,6 @@ def create_formulation(self): }, ) - def override_solution_attributes(self, reaction: _reaction.Reaction): - """Reaction.reaction_term() relies on the .solution attribute of the species - however, in the discontinuous class, this attribute doesn't really make sense - since there is one solution per subdomain. - - Therefore we temporarily override the .solution attribute based on the - reactants, - products, and `others` if there are implicit species - """ - list_of_species_to_override = reaction.reactant + reaction.product - - # check if we have implicit species: - for reactant in reaction.reactant: - if isinstance(reactant, _species.ImplicitSpecies): - for other_spe in reactant.others: - if other_spe not in list_of_species_to_override: - list_of_species_to_override.append(other_spe) - - for species in list_of_species_to_override: - if isinstance(species, _species.Species): - species.solution = species.subdomain_to_solution[reaction.volume] - def create_solver(self): if Version(dolfinx.__version__) == Version("0.9.0"): self.solver = BlockedNewtonSolver( diff --git a/src/festim/reaction.py b/src/festim/reaction.py index 946bfc7e2..588fef7ea 100644 --- a/src/festim/reaction.py +++ b/src/festim/reaction.py @@ -137,6 +137,24 @@ def reaction_term( The reaction term to be used in a formulation. """ + # make sure products is a list + products = self.product if isinstance(self.product, list) else [self.product] + + # detect if mixed_domain + mixed_domain = any( + isinstance(reactant, _Species) and reactant.subdomain_to_solution != {} + for reactant in self.reactant + ) or any( + isinstance(product, _Species) and product.subdomain_to_solution != {} + for product in products + ) + + def get_concentration(species): + if mixed_domain: + return species.concentration_submesh(self.volume) + else: + return species.concentration + if self.product == []: if self.p_0 is not None: raise ValueError( @@ -158,8 +176,6 @@ def reaction_term( "E_p cannot be None when reaction products are present." ) - products = self.product if isinstance(self.product, list) else [self.product] - # reaction rates k = self.k_0 * exp(-self.E_k / (_k_B * temperature)) @@ -176,18 +192,22 @@ def reaction_term( assert len(reactant_concentrations) == len(reactants) for i, reactant in enumerate(reactants): if reactant_concentrations[i] is None: - reactant_concentrations[i] = reactant.concentration + reactant_concentrations[i] = get_concentration(reactant) else: - reactant_concentrations = [reactant.concentration for reactant in reactants] + reactant_concentrations = [ + get_concentration(reactant) for reactant in reactants + ] # if product_concentrations is provided, use these concentrations if product_concentrations is not None: assert len(product_concentrations) == len(products) for i, product in enumerate(products): if product_concentrations[i] is None: - product_concentrations[i] = product.concentration + product_concentrations[i] = get_concentration(product) else: - product_concentrations = [product.concentration for product in products] + product_concentrations = [ + get_concentration(product) for product in products + ] # multiply all concentrations to be used in the term product_of_reactants = reactant_concentrations[0] diff --git a/src/festim/species.py b/src/festim/species.py index 8f8d8babb..162bfbe08 100644 --- a/src/festim/species.py +++ b/src/festim/species.py @@ -108,6 +108,12 @@ def __str__(self) -> str: def concentration(self): return self.solution + def concentration_submesh(self, subdomain: _VolumeSubdomain): + assert subdomain in self.subdomains, ( + f"Species {self.name} has no solution on subdomain {subdomain}." + ) + return self.subdomain_to_solution[subdomain] + @property def legacy(self) -> bool: """Check if we are using FESTIM 1.0 implementation or FESTIM 2.0.""" @@ -169,6 +175,17 @@ def concentration(self): ) return self.value_fenics - sum([other.solution for other in self.others]) + def concentration_submesh(self, subdomain: _VolumeSubdomain): + if len(self.others) > 0: + for other in self.others: + assert other.subdomain_to_solution[subdomain], ( + f"Cannot compute concentration of {self.name} because {other.name}" + + f" has no solution on subdomain {subdomain}." + ) + return self.value_fenics - sum( + [other.subdomain_to_solution[subdomain] for other in self.others] + ) + def create_value_fenics(self, mesh, t: fem.Constant): """Creates the value of the density as a fenics object and sets it to self.value_fenics. If the value is a constant, it is converted to a diff --git a/test/test_reaction.py b/test/test_reaction.py index 850214a58..77c1dccd3 100644 --- a/test/test_reaction.py +++ b/test/test_reaction.py @@ -433,20 +433,3 @@ def test_product_setter_raise_error_E_p_no_product(): for subdomain in my_model.volume_subdomains: my_model.define_function_spaces(subdomain) - - -@pytest.mark.parametrize("reaction", [reac1, reac2]) -def test_override_solution_attributes(reaction): - """Tests the HydrogenTransportProblemDiscontinuous.override_solution_attributes - method Checks that the .solution attribute is the expected one based on the volume - of the reaction.""" - - # RUN - my_model.override_solution_attributes(reaction) - - # TEST - relevant_species = reaction.reactant + reaction.product + empty_traps.others - for species in relevant_species: - if isinstance(species, F.Species): - expected_solution = species.subdomain_to_solution[reaction.volume] - assert species.solution == expected_solution diff --git a/test/test_vtx.py b/test/test_vtx.py index 1e147573c..1c39493f5 100644 --- a/test/test_vtx.py +++ b/test/test_vtx.py @@ -404,3 +404,110 @@ def test_custom_field_not_implemented_error(expression): with pytest.raises(NotImplementedError): my_model.initialise() + + +@pytest.mark.parametrize("direction", ["both", "forward", "backward"]) +@pytest.mark.parametrize("product_type", ["list", "single"]) +@pytest.mark.parametrize("p_0, E_p", [(0.01, 0.05), (0.01, 0.0), (0.0, 0.0)]) +def test_reaction_rate_export(tmp_path, direction, product_type, p_0, E_p): + """ + Test ReactionRate export functionality for different directions, product formats, + and reaction configurations. + """ + if p_0 == 0.0 and direction == "backward": + pytest.skip( + "Backward direction export not supported when backward reaction is disabled" + ) + my_model = F.HydrogenTransportProblem() + mat = F.Material(D_0=1, E_D=0, K_S_0=1, E_K_S=0) + vol = F.VolumeSubdomain(id=1, material=mat) + top = F.SurfaceSubdomain(id=1, locator=lambda x: np.isclose(x[1], 1)) + bottom = F.SurfaceSubdomain(id=2, locator=lambda x: np.isclose(x[1], 0)) + left = F.SurfaceSubdomain(id=3, locator=lambda x: np.isclose(x[0], 0)) + right = F.SurfaceSubdomain(id=4, locator=lambda x: np.isclose(x[0], 1)) + + my_model.subdomains = [vol, top, bottom, left, right] + + dolfinx_mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 10, 10) + my_model.mesh = F.Mesh(dolfinx_mesh) + + A = F.Species("A") + B = F.Species("B") + C = F.Species("C") + + my_model.species = [A, B, C] + + my_model.boundary_conditions = [ + F.FixedConcentrationBC(species=A, subdomain=top, value=1), + F.FixedConcentrationBC(species=B, subdomain=left, value=1), + F.FixedConcentrationBC(species=C, subdomain=bottom, value=0), + ] + + reaction = F.Reaction( + reactant=[A, B], + product=[C] if product_type == "list" else C, + k_0=1, + E_k=0.1, + p_0=p_0, + E_p=E_p, + volume=vol, + ) + + my_model.reactions = [reaction] + + my_model.temperature = 300 + + my_model.settings = F.Settings(transient=False, atol=1e-9, rtol=1e-9) + + reaction_rate_export = F.ReactionRate( + filename=tmp_path / f"reaction_rate_{direction}.bp", + reaction=reaction, + direction=direction, + ) + + my_model.exports = [reaction_rate_export] + + my_model.initialise() + my_model.run() + + +def test_reaction_rate_override_signature(): + """ + Test that ReactionRate signature override correctly updates signatures. + """ + mat = F.Material(D_0=1, E_D=0) + vol = F.VolumeSubdomain(id=1, material=mat) + A = F.Species("A") + B = F.Species("B") + reaction = F.Reaction( + reactant=[A], product=[B], k_0=1, E_k=0, p_0=0, E_p=0, volume=vol + ) + + rr = F.ReactionRate(reaction=reaction, filename="dummy.bp") + + def my_expression(**kwargs): + return kwargs.get("x", 0) + kwargs.get("y", 0) + + rr.override_signature(my_expression, ["A"], ["B"]) + import inspect + + sig = inspect.signature(my_expression) + assert set(sig.parameters.keys()) == {"T", "A", "B"} + + +def test_export_base_class_times_and_extension(tmp_path): + """ + Test that ExportBaseClass sorts times and warns when wrong extension is given. + """ + with pytest.warns(UserWarning, match="does not have .bp extension"): + export = F.ExportBaseClass( + filename=tmp_path / "wrong_extension.txt", ext=".bp", times=[3.0, 1.0, 2.0] + ) + + assert export.filename.suffix == ".bp" + assert export.times == [1.0, 2.0, 3.0] + + +def test_export_base_class_no_times(tmp_path): + export = F.ExportBaseClass(filename=tmp_path / "correct.bp", ext=".bp", times=None) + assert export.times is None