From 68f09547708bf358155ddeba4683af0490305977 Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Sat, 25 Oct 2025 15:15:09 +0530 Subject: [PATCH 1/7] feat(solvers): add max_wall_time for solver timeout handling --- src/pybamm/simulation.py | 11 ++++ src/pybamm/solvers/base_solver.py | 18 +++++++ src/pybamm/solvers/casadi_solver.py | 2 + src/pybamm/solvers/idaklu_solver.py | 2 + tests/unit/test_solvers/test_base_solver.py | 60 +++++++++++++++++++++ 5 files changed, 93 insertions(+) diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index 1fb362f115..815e8e08d1 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -1,6 +1,7 @@ from __future__ import annotations import pickle +import time import warnings from copy import copy from datetime import timedelta @@ -380,6 +381,7 @@ def solve( inputs=None, t_interp=None, initial_conditions=None, + max_wall_time=None, **kwargs, ): """ @@ -448,6 +450,12 @@ def solve( if solver is None: solver = self._solver + if max_wall_time is not None: + solver.max_wall_time = max_wall_time + + if solver.max_wall_time is not None and solver._wall_time_start is None: + solver._wall_time_start = time.time() + if calc_esoh is None: calc_esoh = self._model.calc_esoh else: @@ -764,6 +772,9 @@ def solve( **kwargs, ) except pybamm.SolverError as error: + if "Wall time limit" in str(error): + raise + if ( "non-positive at initial conditions" in error.message and "[experiment]" in error.message diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index ce0599b230..7baa4c565c 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -4,6 +4,7 @@ import numbers import platform import sys +import time import warnings import casadi @@ -58,6 +59,7 @@ def __init__( on_extrapolation=None, on_failure=None, output_variables=None, + max_wall_time=None, ): self.method = method self.rtol = rtol @@ -69,6 +71,8 @@ def __init__( self._on_extrapolation = on_extrapolation or "warn" self._on_failure = on_failure or "raise" self._model_set_up = {} + self.max_wall_time = max_wall_time + self._wall_time_start = None # Defaults, can be overwritten by specific solver self.name = "Base solver" @@ -795,6 +799,11 @@ def solve( """ pybamm.logger.info(f"Start solving {model.name} with {self.name}") + import time + + if self.max_wall_time is not None: + self._wall_time_start = time.time() + # Make sure model isn't empty self._check_empty_model(model) @@ -1273,6 +1282,15 @@ def step( `model.variables = {}`) """ + + if self.max_wall_time is not None and self._wall_time_start is not None: + elapsed = time.time() - self._wall_time_start + if elapsed > self.max_wall_time: + raise pybamm.SolverError( + f"Wall time limit ({self.max_wall_time}s) exceeded " + f"(elapsed: {elapsed:.2f}s)" + ) + if old_solution is None: old_solution = pybamm.EmptySolution() diff --git a/src/pybamm/solvers/casadi_solver.py b/src/pybamm/solvers/casadi_solver.py index efea953901..ebf908b77d 100644 --- a/src/pybamm/solvers/casadi_solver.py +++ b/src/pybamm/solvers/casadi_solver.py @@ -93,6 +93,7 @@ def __init__( return_solution_if_failed_early=False, perturb_algebraic_initial_conditions=None, integrators_maxcount=100, + max_wall_time=None, ): on_extrapolation = on_extrapolation or "error" super().__init__( @@ -103,6 +104,7 @@ def __init__( root_tol=root_tol, extrap_tol=extrap_tol, on_extrapolation=on_extrapolation, + max_wall_time=max_wall_time, ) if mode in ["safe", "fast", "fast with events", "safe without grid"]: self.mode = mode diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 4a0d9edb07..fc54b72770 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -160,6 +160,7 @@ def __init__( output_variables=None, on_failure=None, options=None, + max_wall_time=None, ): # set default options, # (only if user does not supply) @@ -220,6 +221,7 @@ def __init__( output_variables=output_variables, on_extrapolation=on_extrapolation, on_failure=on_failure, + max_wall_time=max_wall_time, ) self.name = "IDA KLU solver" self._supports_interp = True diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index c0bcbf3ce2..5412e7ff8f 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -2,6 +2,9 @@ # Tests for the Base Solver class # +import time +import unittest + import casadi import numpy as np import pytest @@ -10,6 +13,63 @@ import pybamm +class TestWallTimeTimeout(unittest.TestCase): + """Tests for max_wall_time functionality""" + + def test_no_timeout_when_not_specified(self): + """Test that solver works normally without timeout (experiment terminates on event)""" + experiment = pybamm.Experiment([("Discharge at 1C until 2.5V",)]) + + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model, experiment=experiment) + + solution = sim.solve() + assert solution is not None + assert len(solution.t) > 0 # Basic check: some data + + def test_timeout_with_experiment(self): + """Test timeout with IDAKLUSolver and long experiment""" + experiment = pybamm.Experiment( + [ + ( + "Discharge at 1C until 2.5V", + "Rest for 1 hour", + "Charge at 1C until 4.2V", + "Rest for 1 hour", + ) + ] + * 500 + ) + model = pybamm.lithium_ion.SPM() + solver = pybamm.CasadiSolver(max_wall_time=1) + sim = pybamm.Simulation(model, solver=solver, experiment=experiment) + + start = time.time() + with pytest.raises(pybamm.SolverError) as context: + sim.solve() + + elapsed = time.time() - start + assert elapsed < 3 + assert "Wall time limit" in str(context.exception) + + def test_timeout_via_simulation_solve(self): + """Test passing max_wall_time through Simulation.solve()""" + experiment = pybamm.Experiment( + [("Discharge at 1C until 2.5V", "Rest for 1 hour")] * 500 + ) + + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model, experiment=experiment) + + start = time.time() + with pytest.raises(pybamm.SolverError) as context: + sim.solve(max_wall_time=1) # 1s limit + + elapsed = time.time() - start + assert elapsed < 3 + assert "Wall time limit" in str(context.exception) + + class TestBaseSolver: def test_base_solver_init(self): solver = pybamm.BaseSolver(rtol=1e-2, atol=1e-4) From f89b5c42349553c335a63f7525bdb2ecb3c6d6d6 Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Sat, 25 Oct 2025 15:58:44 +0530 Subject: [PATCH 2/7] added changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 010b26d600..7289fd251b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Added the max_wall_time parameter to BaseSolver (and child classes) for optional wall-clock timeout handling during solver steps, raising SolverError on exceedance while preserving partial solutions. ([#5240](https://github.com/pybamm-team/PyBaMM/pull/5240)) - Added the `electrode_phases` kwarg to `plot_voltage_components()` which allows choosing between plotting primary or secondary phase overpotentials. ([#5229](https://github.com/pybamm-team/PyBaMM/pull/5229)) - Added the `num_steps_no_progress` and `t_no_progress` options in the `IDAKLUSolver` to early terminate the simulation if little progress is detected. ([#5201](https://github.com/pybamm-team/PyBaMM/pull/5201)) - EvaluateAt symbol: add support for children evaluated at edges ([#5190](https://github.com/pybamm-team/PyBaMM/pull/5190)) From 1d10ba25665320bb78b22a0c1fa221e121a6fa6d Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Thu, 30 Oct 2025 09:28:58 +0530 Subject: [PATCH 3/7] moved time import Removed unnecessary import of the time module. --- src/pybamm/solvers/base_solver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 7baa4c565c..d531e48b36 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -799,8 +799,6 @@ def solve( """ pybamm.logger.info(f"Start solving {model.name} with {self.name}") - import time - if self.max_wall_time is not None: self._wall_time_start = time.time() From 1bdea3f9ef4d043962ca644cf0227af394bdc0c3 Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Thu, 30 Oct 2025 09:38:01 +0530 Subject: [PATCH 4/7] removed max_wall_time from BaseSolver init --- src/pybamm/solvers/base_solver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index d531e48b36..862b3fe87a 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -71,7 +71,6 @@ def __init__( self._on_extrapolation = on_extrapolation or "warn" self._on_failure = on_failure or "raise" self._model_set_up = {} - self.max_wall_time = max_wall_time self._wall_time_start = None # Defaults, can be overwritten by specific solver From a2b147256d766002b47c0599900be60b447e3a7b Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Thu, 30 Oct 2025 10:08:22 +0530 Subject: [PATCH 5/7] Implemented runtime max_wall_time timeout in BaseSolver.solve with propagation --- src/pybamm/solvers/base_solver.py | 201 ++++++++++++++++-------------- 1 file changed, 110 insertions(+), 91 deletions(-) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 862b3fe87a..ab136572c8 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -735,82 +735,85 @@ def _solve_process_calculate_sensitivities_arg( return calculate_sensitivities_list, sensitivities_have_changed def solve( - self, - model, - t_eval=None, - inputs=None, - nproc=None, - calculate_sensitivities=False, - t_interp=None, - initial_conditions=None, - ): - """ - Execute the solver setup and calculate the solution of the model at - specified times. - - Parameters - ---------- - model : :class:`pybamm.BaseModel` - The model whose solution to calculate. Must have attributes rhs and - initial_conditions. All calls to solve must pass in the same model or - an error is raised - t_eval : None, list or ndarray, optional - The times (in seconds) at which to compute the solution. Defaults to None. - inputs : dict or list, optional - A dictionary or list of dictionaries describing any input parameters to - pass to the model when solving - nproc : int, optional - Number of processes to use when solving for more than one set of input - parameters. Defaults to value returned by "os.cpu_count()". - calculate_sensitivities : list of str or bool, optional - Whether the solver calculates sensitivities of all input parameters. Defaults to False. - If only a subset of sensitivities are required, can also pass a - list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances - so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you - require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities - do not take into account the movement of step-transitions wrt input parameters, so do not use this feature - if the timings of your experimental protocol change rapidly with respect to your input parameters. - t_interp : None, list or ndarray, optional - The times (in seconds) at which to interpolate the solution. Defaults to None. - Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). - initial_conditions : dict, numpy.ndarray, or list, optional - Override the model’s default `y0`. Can be: - - - a dict mapping variable names → values - - a 1D array of length `n_states` - - a list of such overrides (one per parallel solve) - - Only valid for IDAKLU solver. - Returns - ------- - :class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects. - If type of `inputs` is `list`, return a list of corresponding - :class:`pybamm.Solution` objects. + self, + model, + t_eval=None, + inputs=None, + nproc=None, + calculate_sensitivities=False, + t_interp=None, + initial_conditions=None, + max_wall_time=None, +): + """ + Execute the solver setup and calculate the solution of the model at + specified times. - Raises - ------ - :class:`pybamm.ModelError` - If an empty model is passed (`model.rhs = {}` and `model.algebraic={}` and - `model.variables = {}`) - :class:`RuntimeError` - If multiple calls to `solve` pass in different models + Parameters + ---------- + model : :class:`pybamm.BaseModel` + The model whose solution to calculate. Must have attributes rhs and + initial_conditions. All calls to solve must pass in the same model or + an error is raised + t_eval : None, list or ndarray, optional + The times (in seconds) at which to compute the solution. Defaults to None. + inputs : dict or list, optional + A dictionary or list of dictionaries describing any input parameters to + pass to the model when solving + nproc : int, optional + Number of processes to use when solving for more than one set of input + parameters. Defaults to value returned by "os.cpu_count()". + calculate_sensitivities : list of str or bool, optional + Whether the solver calculates sensitivities of all input parameters. Defaults to False. + If only a subset of sensitivities are required, can also pass a + list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances + so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you + require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities + do not take into account the movement of step-transitions wrt input parameters, so do not use this feature + if the timings of your experimental protocol change rapidly with respect to your input parameters. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to None. + Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). + initial_conditions : dict, numpy.ndarray, or list, optional + Override the model’s default `y0`. Can be: + + - a dict mapping variable names → values + - a 1D array of length `n_states` + - a list of such overrides (one per parallel solve) + + Only valid for IDAKLU solver. + Returns + ------- + :class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects. + If type of `inputs` is `list`, return a list of corresponding + :class:`pybamm.Solution` objects. + + Raises + ------ + :class:`pybamm.ModelError` + If an empty model is passed (`model.rhs = {}` and `model.algebraic={}` and + `model.variables = {}`) + :class:`RuntimeError` + If multiple calls to `solve` pass in different models - """ + """ + start_time = time.time() + try: pybamm.logger.info(f"Start solving {model.name} with {self.name}") - - if self.max_wall_time is not None: + + if max_wall_time is not None: self._wall_time_start = time.time() - + # Make sure model isn't empty self._check_empty_model(model) - + # t_eval can only be None if the solver is an algebraic solver. In that case # set it to 0 if t_eval is None: if self._algebraic_solver is False: raise ValueError("t_eval cannot be None") t_eval = np.array([0]) - + # If t_eval is provided as [t0, tf] return the solution at 100 points elif isinstance(t_eval, list): if len(t_eval) == 1 and self._algebraic_solver is True: @@ -824,13 +827,13 @@ def solve( ) elif not self.supports_interp: t_eval = np.linspace(t_eval[0], t_eval[-1], 100) - + # Make sure t_eval is monotonic if (np.diff(t_eval) < 0).any(): raise pybamm.SolverError("t_eval must increase monotonically") - + t_interp = self.process_t_interp(t_interp) - + # Set up inputs # # Argument "inputs" can be either a list of input dicts or @@ -841,18 +844,18 @@ def solve( model_inputs_list = [ self._set_up_model_inputs(model, inputs) for inputs in inputs_list ] - + calculate_sensitivities_list, sensitivities_have_changed = ( BaseSolver._solve_process_calculate_sensitivities_arg( model_inputs_list[0], model, calculate_sensitivities ) ) - + # (Re-)calculate consistent initialization # Assuming initial conditions do not depend on input parameters # when len(inputs_list) > 1, only `model_inputs_list[0]` # is passed to `_set_consistent_initialization`. - # See https://github.com/pybamm-team/PyBaMM/pull/1261 + # See [https://github.com/pybamm-team/PyBaMM/pull/1261](https://github.com/pybamm-team/PyBaMM/pull/1261) if len(model_inputs_list) > 1: all_inputs_names = set( itertools.chain.from_iterable( @@ -871,18 +874,18 @@ def solve( "Input parameters cannot appear in expression " "for initial conditions." ) - + # if any setup configuration has changed, we need to re-set up if sensitivities_have_changed: self._model_set_up.pop(model, None) # CasadiSolver caches its integrators using model, so delete this too if isinstance(self, pybamm.CasadiSolver): self.integrators.pop(model, None) - + # save sensitivity parameters so we can identify them later on # (FYI: this is used in the Solution class) model.calculate_sensitivities = calculate_sensitivities_list - + # Set up (if not done already) timer = pybamm.Timer() # Set the initial conditions @@ -898,7 +901,7 @@ def solve( # up (initial condition, time-scale and length-scale) does # not depend on input parameters. Therefore, only `model_inputs[0]` # is passed to `set_up`. - # See https://github.com/pybamm-team/PyBaMM/pull/1261 + # See [https://github.com/pybamm-team/PyBaMM/pull/1261](https://github.com/pybamm-team/PyBaMM/pull/1261) self.set_up(model, model_inputs_list[0], t_eval) self._model_set_up.update( {model: {"initial conditions": model.concatenated_initial_conditions}} @@ -922,23 +925,23 @@ def solve( else: # Set the standard initial conditions self._set_initial_conditions(model, t_eval[0], model_inputs_list[0]) - + # Solve for the consistent initialization self._set_consistent_initialization(model, t_eval[0], model_inputs_list[0]) - + set_up_time = timer.time() timer.reset() - + # Check initial conditions don't violate events self._check_events_with_initialization(t_eval, model, model_inputs_list[0]) - + # Process discontinuities ( start_indices, end_indices, t_eval, ) = self._get_discontinuity_start_end_indices(model, inputs, t_eval) - + # Integrate separately over each time segment and accumulate into the solution # object, restarting the solver at each discontinuity (and recalculating a # consistent state afterwards if a DAE) @@ -956,6 +959,7 @@ def solve( model_inputs_list, t_interp, initial_conditions, + max_wall_time=max_wall_time ) else: ninputs = len(model_inputs_list) @@ -965,6 +969,7 @@ def solve( t_eval[start_index:end_index], model_inputs_list[0], t_interp=t_interp, + max_wall_time=max_wall_time ) new_solutions = [new_solution] else: @@ -977,7 +982,7 @@ def solve( model_inputs_list, [t_interp] * ninputs, strict=False, - ), + ) + ((max_wall_time,) * ninputs,) ) p.close() p.join() @@ -991,10 +996,10 @@ def solve( else: for i, new_solution in enumerate(new_solutions): solutions[i] = solutions[i] + new_solution - + if solutions[0].termination != "final time": break - + if end_index != len(t_eval): # setup for next integration subsection last_state = solutions[0].y[:, -1] @@ -1006,7 +1011,7 @@ def solve( model, t_eval[end_index], model_inputs_list[0] ) solve_time = timer.time() - + for i, solution in enumerate(solutions): # Check if extrapolation occurred self.check_extrapolation(solution, model.events) @@ -1020,10 +1025,10 @@ def solve( # all solutions get the same solve time, but their integration time # will be different (see https://github.com/pybamm-team/PyBaMM/pull/1261) solutions[i].solve_time = solve_time - + # Restore old y0 model.y0 = old_y0 - + # Report times if len(solutions) == 1: pybamm.logger.info(f"Finish solving {model.name} ({termination})") @@ -1036,7 +1041,7 @@ def solve( pybamm.logger.info( f"Set-up time: {solutions[0].set_up_time}, Solve time: {solutions[0].solve_time}, Total time: {solutions[0].total_time}" ) - + # Raise error if solutions[0] only contains one timestep (except for algebraic # solvers, where we may only expect one time in the solution) if ( @@ -1049,11 +1054,25 @@ def solve( "Check whether simulation terminated too early." ) - # Return solution(s) - if len(solutions) == 1: - return solutions[0] - else: - return solutions + except pybamm.SolverError as e: + if max_wall_time is not None and "Wall time limit exceeded" in str(e): + raise + raise e + + elapsed = time.time() - start_time + if max_wall_time is not None and elapsed > max_wall_time: + partial_sol = solutions[0] if 'solutions' in locals() else None + raise pybamm.SolverError( + f"Wall time limit ({max_wall_time}s) exceeded during solve (took {elapsed:.2f}s)", + solution=partial_sol + ) + + # Return solution(s) + if len(solutions) == 1: + return solutions[0] + else: + return solutions + @staticmethod def filter_discontinuities(t_discon: list, t_eval: list) -> np.ndarray: From 3daebcd59d8e8e7c85edd596fc3daf3de8db3018 Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Sat, 1 Nov 2025 13:10:56 +0530 Subject: [PATCH 6/7] added entry in the docstring for solve method Add max_wall_time parameter to solver options --- src/pybamm/simulation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index 815e8e08d1..04a880b632 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -437,6 +437,10 @@ def solve( t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to None. Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). + max_wall_time : float, optional + Maximum wall-clock time (in seconds) for the entire solve process. If + exceeded during solver steps, raises a :class:`pybamm.SolverError` while + preserving partial solution data. Defaults to None (no limit). **kwargs Additional key-word arguments passed to `solver.solve`. See :meth:`pybamm.BaseSolver.solve`. From aeddfcf7f35969e2d5ece9f137ef30e26625c77d Mon Sep 17 00:00:00 2001 From: r0hansaxena Date: Sat, 1 Nov 2025 13:19:07 +0530 Subject: [PATCH 7/7] docstring update --- src/pybamm/simulation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index 04a880b632..b7937e48f7 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -427,6 +427,8 @@ def solve( Initial State of Charge (SOC) for the simulation. Must be between 0 and 1. If given, overwrites the initial concentrations provided in the parameter set. + direction : str, optional + Direction of the solve ("forward" or "backward"). Defaults to "forward". callbacks : list of callbacks, optional A list of callbacks to be called at each time step. Each callback must implement all the methods defined in :class:`pybamm.callbacks.BaseCallback`. @@ -434,9 +436,15 @@ def solve( Whether to show a progress bar for cycling. If true, shows a progress bar for cycles. Has no effect when not used with an experiment. Default is False. + inputs : dict, optional + Dictionary of input values to override model defaults. If None, uses built-in + values from the parameter set. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to None. Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). + initial_conditions : dict, optional + Dictionary of initial conditions for the variables. If None, the initial + conditions are inferred from the model. max_wall_time : float, optional Maximum wall-clock time (in seconds) for the entire solve process. If exceeded during solver steps, raises a :class:`pybamm.SolverError` while