diff --git a/pyproject.toml b/pyproject.toml index 01357c2ebe..a921c4c068 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,7 +183,7 @@ extend-select = [ # "PGH", # pygrep-hooks # "PIE", # flake8-pie # "PL", # pylint - # "PT", # flake8-pytest-style + "PT", # flake8-pytest-style # "PTH", # flake8-use-pathlib # "RET", # flake8-return "RUF", # Ruff-specific diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py index fb0e5f71d2..969d799734 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py @@ -13,7 +13,7 @@ def optimtest(): class TestLeadAcidFull: @pytest.mark.parametrize( - "options, t_eval", + ("options", "t_eval"), [ ({"thermal": "isothermal"}, np.linspace(0, 3600 * 17)), ( @@ -76,7 +76,7 @@ def test_set_up(self): optimtest.set_up_model(to_python=False) @pytest.mark.parametrize( - "options, param_update", + ("options", "param_update"), [ ({"thermal": "lumped"}, {"Current function [A]": 1.7}), ({"thermal": "x-full"}, None), diff --git a/tests/unit/test_batch_study.py b/tests/unit/test_batch_study.py index a7c00df16e..cf6bcf1727 100644 --- a/tests/unit/test_batch_study.py +++ b/tests/unit/test_batch_study.py @@ -38,7 +38,7 @@ def test_solve(self): # Tests for exceptions for name in pybamm.BatchStudy.INPUT_LIST: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Either provide no"): pybamm.BatchStudy( models={"SPM": spm, "SPM uniform": spm_uniform}, **{name: {None}} ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 980c51eb25..5fc1b4f4b6 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -35,7 +35,9 @@ def test_write_read_uuid(self, tmp_path, write_opt_in): else: assert config_dict["enable_telemetry"] is False - @pytest.mark.parametrize("user_opted_in, user_input", [(True, "y"), (False, "n")]) + @pytest.mark.parametrize( + ("user_opted_in", "user_input"), [(True, "y"), (False, "n")] + ) def test_ask_user_opt_in(self, monkeypatch, capsys, user_opted_in, user_input): # Mock select.select to simulate user input def mock_select(*args, **kwargs): diff --git a/tests/unit/test_experiments/test_base_step.py b/tests/unit/test_experiments/test_base_step.py index 0250c0622f..f2850512f0 100644 --- a/tests/unit/test_experiments/test_base_step.py +++ b/tests/unit/test_experiments/test_base_step.py @@ -3,7 +3,7 @@ @pytest.mark.parametrize( - "test_string, unit_string", + ("test_string", "unit_string"), [ ("123e-1 W", "W"), ("123K", "K"), diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py index c46404fb97..339ade4a37 100644 --- a/tests/unit/test_expression_tree/test_coupled_variable.py +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -89,7 +89,7 @@ def test_setter(self): coupled_variables = {"a": a} model.coupled_variables = coupled_variables assert model.coupled_variables == coupled_variables + coupled_variables = {"b": a} with pytest.raises(ValueError, match="Coupled variable with name"): - coupled_variables = {"b": a} model.coupled_variables = coupled_variables diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index 79c5ab9ea2..16552aff3f 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -23,13 +23,13 @@ def test_time(self): t = pybamm.Time() assert t.name == "time" assert t.evaluate(4) == 4 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="t must be provided"): t.evaluate(None) t = pybamm.t assert t.name == "time" assert t.evaluate(4) == 4 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="t must be provided"): t.evaluate(None) assert t.evaluate_for_shape() == 0 diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 277e2d8995..82c457e309 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -188,8 +188,8 @@ def f(x, y): value = interp._function_evaluate(evaluated_children) # Test evaluation fails with different child shapes - with pytest.raises(ValueError, match="All children must"): - evaluated_children = [np.array([[1, 1]]), np.array([7])] + evaluated_children = [np.array([[1, 1]]), np.array([7])] + with pytest.raises(ValueError, match="All children must have the same shape"): value = interp._function_evaluate(evaluated_children) # Test runs when all children are scalars @@ -295,8 +295,8 @@ def f(x, y, z): value = interp._function_evaluate(evaluated_children) # Test evaluation fails with different child shapes - with pytest.raises(ValueError, match="All children must"): - evaluated_children = [np.array([[1, 1]]), np.ones(()) * 4, np.array([[7]])] + evaluated_children = [np.array([[1, 1]]), np.ones(()) * 4, np.array([[7]])] + with pytest.raises(ValueError, match="All children must have the same shape"): value = interp._function_evaluate(evaluated_children) # Test runs when all children are scalsrs diff --git a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py index 3012111dc8..ca06b5afe7 100644 --- a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py +++ b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py @@ -191,7 +191,7 @@ def test_interpolation(self): # error for not recognized interpolator with pytest.raises(ValueError, match="interpolator"): interp = pybamm.Interpolant(x, data, y, interpolator="idonotexist") - interp_casadi = interp.to_casadi(y=casadi_y) + interp_casadi = interp.to_casadi(y=casadi_y) # error for converted children count y4 = ( @@ -205,7 +205,7 @@ def test_interpolation(self): data4 = 2 * x4 # np.tile(2 * x3, (10, 1)).T with pytest.raises(ValueError, match="Invalid dimension of x"): interp = pybamm.Interpolant(x4_, data4, y4, interpolator="linear") - interp_casadi = interp.to_casadi(y=casadi_y) + interp_casadi = interp.to_casadi(y=casadi_y) def test_interpolation_2d(self): x_ = [np.linspace(0, 1), np.linspace(0, 1)] @@ -249,7 +249,7 @@ def test_interpolation_2d(self): # error for pchip interpolator with pytest.raises(ValueError, match="interpolator should be"): interp = pybamm.Interpolant(x_, Y, y, interpolator="pchip") - interp_casadi = interp.to_casadi(y=casadi_y) + interp_casadi = interp.to_casadi(y=casadi_y) def test_interpolation_3d(self): def f(x, y, z): diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index 127ac4d814..9b74e0df18 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -86,12 +86,12 @@ def test_set_input_names(self): assert func.input_names == new_input_names + new_input_names = {"wrong": "input type"} with pytest.raises(TypeError): - new_input_names = {"wrong": "input type"} func.input_names = new_input_names + new_input_names = [var] with pytest.raises(TypeError): - new_input_names = [var] func.input_names = new_input_names def test_print_name(self): diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 735724ef11..93b8b870ac 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -417,7 +417,7 @@ def test_symbol_visualise(self, tmp_path): sym.visualise(str(temp_file)) assert temp_file.exists() - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid file extension"): sym.visualise(str(temp_file.with_suffix(""))) def test_has_spatial_derivatives(self): @@ -559,5 +559,4 @@ def test_bool(self): bool(a) # if statement calls Boolean with pytest.raises(NotImplementedError, match="Boolean"): - if a > 1: - print("a is greater than 1") + bool(a > 1) diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index c8128c785b..e0b52556bf 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -660,8 +660,11 @@ def test_boundary_value(self): # error if boundary value on tabs and domain is not "current collector" var = pybamm.Variable("var", domain=["negative electrode"]) + with pytest.raises(pybamm.ModelError, match="Can only take boundary"): pybamm.boundary_value(var, "negative tab") + + with pytest.raises(pybamm.ModelError, match="Can only take boundary"): pybamm.boundary_value(var, "positive tab") # boundary value of symbol that evaluates on edges raises error diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 06e2444c16..f42ad283e4 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -31,5 +31,5 @@ def test_logger(self): pybamm.set_logging_level("WARNING") def test_exceptions(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="filename must be specified"): pybamm.get_new_logger("test", None) diff --git a/tests/unit/test_meshes/test_one_dimensional_submesh.py b/tests/unit/test_meshes/test_one_dimensional_submesh.py index 0f5aaafc74..b18beb5b11 100644 --- a/tests/unit/test_meshes/test_one_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_one_dimensional_submesh.py @@ -3,7 +3,7 @@ import numpy as np -@pytest.fixture() +@pytest.fixture def r(): r = pybamm.SpatialVariable( "r", domain=["negative particle"], coord_sys="spherical polar" @@ -11,14 +11,14 @@ def r(): return r -@pytest.fixture() +@pytest.fixture def x(): return pybamm.SpatialVariable( "x", domain=["negative electrode"], coord_sys="cartesian" ) -@pytest.fixture() +@pytest.fixture def geometry(r): geometry = { "negative particle": {r: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}} diff --git a/tests/unit/test_meshes/test_scikit_fem_submesh.py b/tests/unit/test_meshes/test_scikit_fem_submesh.py index 30c45510e4..085a5ac9f5 100644 --- a/tests/unit/test_meshes/test_scikit_fem_submesh.py +++ b/tests/unit/test_meshes/test_scikit_fem_submesh.py @@ -7,7 +7,7 @@ import numpy as np -@pytest.fixture() +@pytest.fixture def param(): return pybamm.ParameterValues( { diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 2b1f162455..c2c493d21d 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -191,7 +191,7 @@ def test_get_parameter_info(self, symbols): assert parameter_info["g"][1] == "Parameter" @pytest.mark.parametrize( - "sub, key, parameter_value", + ("sub", "key", "parameter_value"), [ ("sub1", "a", "InputParameter"), ("sub1", "w", "InputParameter"), diff --git a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py index 00b0be42b1..e6549f8d31 100644 --- a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py @@ -478,7 +478,10 @@ def test_save_load_model(self): ) # raises error if variables are saved without mesh - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Serialisation: Please provide the mesh if variables are required", + ): model.save_model( filename="test_base_battery_model", variables=model.variables ) diff --git a/tests/unit/test_models/test_full_battery_models/test_equivalent_circuit/test_thevenin.py b/tests/unit/test_models/test_full_battery_models/test_equivalent_circuit/test_thevenin.py index 4f54db7035..ecd897774d 100644 --- a/tests/unit/test_models/test_full_battery_models/test_equivalent_circuit/test_thevenin.py +++ b/tests/unit/test_models/test_full_battery_models/test_equivalent_circuit/test_thevenin.py @@ -46,8 +46,8 @@ def test_changing_number_of_rcs(self): model = pybamm.equivalent_circuit.Thevenin(options=options) model.check_well_posedness() - with pytest.raises(pybamm.OptionError, match="natural numbers"): - options = {"number of rc elements": -1} + options = {"number of rc elements": -1} + with pytest.raises(pybamm.OptionError, match="natural numbers"): # noqa: PT012 model = pybamm.equivalent_circuit.Thevenin(options=options) model.check_well_posedness() diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py index a3de8dc4fe..b0ab15ec80 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py @@ -39,7 +39,7 @@ def test_model_well_posedness(self): model.check_well_posedness() @pytest.mark.parametrize( - "options, expected_solver", + ("options", "expected_solver"), [ ( {"hydrolysis": "true", "surface form": "differential"}, diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_loqs.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_loqs.py index 098f8ef8a5..f122bf0bf4 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_loqs.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_loqs.py @@ -36,7 +36,7 @@ def test_default_geometry(self): ) @pytest.mark.parametrize( - "dimensionality, spatial_method, submesh_type", + ("dimensionality", "spatial_method", "submesh_type"), [ (1, pybamm.FiniteVolume, pybamm.Uniform1DSubMesh), (2, pybamm.ScikitFiniteElement, pybamm.ScikitUniform2DSubMesh), diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_electrode_soh.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_electrode_soh.py index 80ff155369..4ac0ab1421 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_electrode_soh.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_electrode_soh.py @@ -7,7 +7,7 @@ # Fixture for TestElectrodeSOHMSMR, TestCalculateTheoreticalEnergy and TestGetInitialOCPMSMR class. -@pytest.fixture() +@pytest.fixture def options(): options = { "open-circuit potential": "MSMR", @@ -363,13 +363,11 @@ def test_error(self): parameter_values, known_value="something else" ) + param_MSMR = pybamm.lithium_ion.MSMR({"number of MSMR reactions": "3"}).param with pytest.raises( ValueError, match="Known value must be cell capacity or cyclable lithium capacity", ): - param_MSMR = pybamm.lithium_ion.MSMR( - {"number of MSMR reactions": "3"} - ).param pybamm.models.full_battery_models.lithium_ion.electrode_soh._ElectrodeSOHMSMR( param=param_MSMR, known_value="something else" ) diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 3078e13ab4..9a66d863b4 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -325,16 +325,13 @@ def test_process_symbol(self): # not found with pytest.raises(KeyError): - x = pybamm.Parameter("x") - parameter_values.process_symbol(x) + parameter_values.process_symbol(pybamm.Parameter("x")) parameter_values = pybamm.ParameterValues({"x": np.nan}) with pytest.raises(ValueError, match="Parameter 'x' not found"): - x = pybamm.Parameter("x") - parameter_values.process_symbol(x) + parameter_values.process_symbol(pybamm.Parameter("x")) with pytest.raises(ValueError, match="possibly a function"): - x = pybamm.FunctionParameter("x", {}) - parameter_values.process_symbol(x) + parameter_values.process_symbol(pybamm.FunctionParameter("x", {})) def test_process_parameter_in_parameter(self): parameter_values = pybamm.ParameterValues( @@ -1003,7 +1000,9 @@ def test_evaluate(self): ) y = pybamm.StateVector(slice(0, 1)) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="symbol must evaluate to a constant scalar or array" + ): parameter_values.evaluate(y) def test_exchange_current_density_plating(self): diff --git a/tests/unit/test_pybamm_data.py b/tests/unit/test_pybamm_data.py index 6d73c633f8..308ea7c843 100644 --- a/tests/unit/test_pybamm_data.py +++ b/tests/unit/test_pybamm_data.py @@ -21,7 +21,9 @@ def test_fetch(): ) def test_fetch_fake(): # Try to fetch a fake file not present in the registry - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="File 'NotAfile.json' is not in the registry." + ): data_loader.get_data("NotAfile.json") diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 2c38f47b6f..5a680b41a8 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -304,15 +304,16 @@ def test_get_pybamm_class(self, mocker): assert isinstance(mesh_class, pybamm.Mesh) + unrecognised_symbol = { + "py/id": mocker.ANY, + "py/object": "pybamm.expression_tree.scalar.Scale", + "name": "5.0", + "id": mocker.ANY, + "value": 5.0, + "children": [], + } + with pytest.raises(AttributeError): - unrecognised_symbol = { - "py/id": mocker.ANY, - "py/object": "pybamm.expression_tree.scalar.Scale", - "name": "5.0", - "id": mocker.ANY, - "value": 5.0, - "children": [], - } Serialise()._get_pybamm_class(unrecognised_symbol) def test_reconstruct_symbol(self, mocker): diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 6573929ad9..0680b06394 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -34,11 +34,11 @@ def test_smoothing_parameters(self): pybamm.settings.set_smoothing_parameters("exact") # Test errors + pybamm.settings.min_max_mode = "smooth" with pytest.raises(ValueError, match="greater than 1"): - pybamm.settings.min_max_mode = "smooth" pybamm.settings.min_max_smoothing = 0.9 + pybamm.settings.min_max_mode = "soft" with pytest.raises(ValueError, match="positive number"): - pybamm.settings.min_max_mode = "soft" pybamm.settings.min_max_smoothing = -10 with pytest.raises(ValueError, match="positive number"): pybamm.settings.heaviside_smoothing = -10 diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index 87947edfb9..76f6203781 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -466,13 +466,19 @@ def oscillating(t): def f(t, x=x): return x + t - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Input function must return a real number output at t = 0", + ): operating_mode(f) def g(t, y): return t - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match="Input function must have only 1 positional argument for time", + ): operating_mode(g) def test_save_load(self, tmp_path): @@ -575,7 +581,10 @@ def test_plot(self): sim = pybamm.Simulation(pybamm.lithium_ion.SPM()) # test exception if not solved - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Model has not been solved, please solve the model before plotting.", + ): sim.plot() # now solve and plot diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index 9550884444..5d1cc2d910 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -199,7 +199,7 @@ def test_no_inputs(self): # Scalar evaluation @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_f_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval[k], inputs) @@ -208,7 +208,7 @@ def test_f_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_f_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval, inputs) @@ -217,7 +217,7 @@ def test_f_vector(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_f_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.vmap(f, in_axes=in_axes))(t_eval, inputs) @@ -226,7 +226,7 @@ def test_f_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_f_batch_over_inputs(self, output_variables, idaklu_jax_solver, f, wrapper): inputs_mock = np.array([1.0, 2.0, 3.0]) @@ -236,22 +236,22 @@ def test_f_batch_over_inputs(self, output_variables, idaklu_jax_solver, f, wrapp # Get all vars (should mirror test_f_* [above]) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvars_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid call signature"): idaklu_jax_solver.get_vars() # no variable name specified idaklu_jax_solver.get_vars(output_variables) # (okay) idaklu_jax_solver.get_vars(f, output_variables) # (okay) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid call signature"): idaklu_jax_solver.get_vars(1, 2, 3) # too many arguments @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvars_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval[k], inputs) @@ -260,7 +260,7 @@ def test_getvars_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvars_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval, inputs) @@ -269,7 +269,7 @@ def test_getvars_vector(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvars_vector_array( self, output_variables, idaklu_jax_solver, f, wrapper @@ -281,7 +281,7 @@ def test_getvars_vector_array( np.testing.assert_allclose(out, array) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( @@ -297,22 +297,22 @@ def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Isolate single output variable @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid call signature"): idaklu_jax_solver.get_var() # no variable name specified idaklu_jax_solver.get_var(output_variables[0]) # (okay) idaklu_jax_solver.get_var(f, output_variables[0]) # (okay) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid call signature"): idaklu_jax_solver.get_var(1, 2, 3) # too many arguments @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_scalar_float_jaxpr( self, output_variables, idaklu_jax_solver, f, wrapper @@ -323,7 +323,7 @@ def test_getvar_scalar_float_jaxpr( np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_scalar_float_f( self, output_variables, idaklu_jax_solver, f, wrapper @@ -336,7 +336,7 @@ def test_getvar_scalar_float_f( np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_scalar_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) @@ -345,7 +345,7 @@ def test_getvar_scalar_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapp np.testing.assert_allclose(out, sim[outvar](t_eval[k])) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_scalar_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) @@ -354,7 +354,7 @@ def test_getvar_scalar_f(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(out, sim[outvar](t_eval[k])) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_vector_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) @@ -363,7 +363,7 @@ def test_getvar_vector_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapp np.testing.assert_allclose(out, sim[outvar](t_eval)) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_vector_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) @@ -372,7 +372,7 @@ def test_getvar_vector_f(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(out, sim[outvar](t_eval)) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided np.ndarray @@ -384,7 +384,7 @@ def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapp np.testing.assert_allclose(out, sim[outvar](t_eval)) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -399,7 +399,7 @@ def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Differentiation rules (jacfwd) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval[k], inputs) @@ -415,7 +415,7 @@ def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval, inputs) @@ -434,7 +434,7 @@ def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( @@ -455,7 +455,7 @@ def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vmap_wrt_time( self, output_variables, idaklu_jax_solver, f, wrapper @@ -469,7 +469,7 @@ def test_jacfwd_vmap_wrt_time( )(t_eval, inputs) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper @@ -486,7 +486,7 @@ def test_jacfwd_batch_over_inputs( # Differentiation rules (jacrev) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval[k], inputs) @@ -502,7 +502,7 @@ def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval, inputs) @@ -518,7 +518,7 @@ def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( @@ -539,7 +539,7 @@ def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper @@ -556,7 +556,7 @@ def test_jacrev_batch_over_inputs( # Forward differentiation rules with get_vars (multiple) and get_var (singular) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper @@ -581,7 +581,7 @@ def test_jacfwd_scalar_getvars( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper @@ -602,7 +602,7 @@ def test_jacfwd_scalar_getvar( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper @@ -628,7 +628,7 @@ def test_jacfwd_vector_getvars( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper @@ -649,7 +649,7 @@ def test_jacfwd_vector_getvar( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( @@ -670,7 +670,7 @@ def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -691,7 +691,7 @@ def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Reverse differentiation rules with get_vars (multiple) and get_var (singular) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper @@ -716,7 +716,7 @@ def test_jacrev_scalar_getvars( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper @@ -739,7 +739,7 @@ def test_jacrev_scalar_getvar( ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper @@ -765,7 +765,7 @@ def test_jacrev_vector_getvars( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper @@ -786,7 +786,7 @@ def test_jacrev_vector_getvar( np.testing.assert_allclose(flat_out, flat_check) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( @@ -807,7 +807,7 @@ def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -828,7 +828,7 @@ def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Gradient rule (takes single variable) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -844,7 +844,7 @@ def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrappe np.testing.assert_allclose(flat_out, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -865,7 +865,7 @@ def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper) # Value and gradient (takes single variable) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_value_and_grad_scalar( self, output_variables, idaklu_jax_solver, f, wrapper @@ -887,7 +887,7 @@ def test_value_and_grad_scalar( np.testing.assert_allclose(flat_t, check.flatten()) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: @@ -912,7 +912,7 @@ def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapp # Helper functions - These return values (not jaxexprs) so cannot be JITed @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: @@ -929,7 +929,7 @@ def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): ) @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: @@ -948,7 +948,7 @@ def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): # Wrap jaxified expression in another function and take the gradient @pytest.mark.parametrize( - "output_variables,idaklu_jax_solver,f,wrapper", make_test_cases() + ("output_variables", "idaklu_jax_solver", "f", "wrapper"), make_test_cases() ) def test_grad_wrapper_sse(self, output_variables, idaklu_jax_solver, f, wrapper): # Use surrogate for experimental data diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index daa3557d83..e9cf4265a2 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -689,7 +689,8 @@ def test_failures(self): solver = pybamm.IDAKLUSolver() t_eval = [0, 3] - with pytest.raises(ValueError): + # raises `std::exception` on UNIX and `IDA failed with flag -22` on Windows + with pytest.raises(ValueError): # noqa: PT011 solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): @@ -828,7 +829,7 @@ def test_setup_options(self): soln.y, soln_base.y, rtol=1e-5, atol=1e-4 ) else: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Unknown"): soln = solver.solve(model, t_eval, t_interp=t_interp) def test_solver_options(self): @@ -894,7 +895,8 @@ def test_solver_options(self): options = {option: options_fail[option]} solver = pybamm.IDAKLUSolver(options=options) - with pytest.raises(ValueError): + # raises `std::exception` on UNIX and `IDA failed with flag -22` on Windows + with pytest.raises(ValueError): # noqa: PT011 solver.solve(model, t_eval) def test_with_output_variables(self): diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 8f43eda3c7..d17712f1be 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -218,7 +218,7 @@ def test_get_solve(self): disc.process_model(model) # test that another method string gives error - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="method must be one of"): solver = pybamm.JaxSolver(method="not_real") # Solve diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 95600a7965..0269b880d4 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -373,7 +373,10 @@ def test_save(self, tmp_path): solution = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) # test save data - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Solution does not have any data.", + ): solution.save_data(f"{test_stub}.pickle") # set variables first then save diff --git a/tests/unit/test_spatial_methods/test_scikit_finite_element.py b/tests/unit/test_spatial_methods/test_scikit_finite_element.py index 7ccb06502a..bc2a41fcb2 100644 --- a/tests/unit/test_spatial_methods/test_scikit_finite_element.py +++ b/tests/unit/test_spatial_methods/test_scikit_finite_element.py @@ -100,7 +100,10 @@ def test_discretise_equations(self): "positive tab": (pybamm.Scalar(1), "Other BC"), } } - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="boundary condition must be Dirichlet or Neumann, not 'Other BC'", + ): eqn_disc = disc.process_symbol(eqn) disc.bcs = { var: { @@ -108,7 +111,10 @@ def test_discretise_equations(self): "positive tab": (pybamm.Scalar(1), "Neumann"), } } - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="boundary condition must be Dirichlet or Neumann, not 'Other BC'", + ): eqn_disc = disc.process_symbol(eqn) # raise ModelError if no BCs provided diff --git a/tests/unit/test_spatial_methods/test_spectral_volume.py b/tests/unit/test_spatial_methods/test_spectral_volume.py index f333b10d57..c1959d3adc 100644 --- a/tests/unit/test_spatial_methods/test_spectral_volume.py +++ b/tests/unit/test_spatial_methods/test_spectral_volume.py @@ -90,7 +90,10 @@ def get_1p1d_mesh_for_testing( class TestSpectralVolume: def test_exceptions(self): sp_meth = pybamm.SpectralVolume() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Too many degrees of differentiation.", + ): sp_meth.chebyshev_differentiation_matrices(3, 3) mesh = get_mesh_for_testing()