From ad17bc671da7957898e78e8be4820ef249592b07 Mon Sep 17 00:00:00 2001 From: rishab Date: Sat, 1 Mar 2025 19:35:03 +0530 Subject: [PATCH 1/7] changes in pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bf7e908acd..0dfb2d8c96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -252,6 +252,9 @@ concurrency = ["multiprocessing"] ignore_missing_imports = true allow_redefinition = true disable_error_code = ["call-overload", "operator"] +strict = false +warn_unreachable = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] [[tool.mypy.overrides]] module = [ From ba437d6b754e549619162595c8730a55e2a64c0a Mon Sep 17 00:00:00 2001 From: rishab Date: Sun, 2 Mar 2025 13:21:34 +0530 Subject: [PATCH 2/7] fixed around 40 mypy errors --- docs/conf.py | 10 +++---- examples/scripts/SPM_compare_particle_grid.py | 2 +- examples/scripts/SPMe_step.py | 21 +++++++------- examples/scripts/heat_equation.py | 2 +- .../minimal_example_of_lookup_tables.py | 4 +-- pyproject.toml | 1 + src/pybamm/experiment/experiment.py | 4 ++- src/pybamm/experiment/step/base_step.py | 2 +- .../expression_tree/binary_operators.py | 9 +++--- src/pybamm/expression_tree/broadcasts.py | 2 +- src/pybamm/expression_tree/concatenations.py | 2 +- .../expression_tree/operations/serialise.py | 7 ++--- src/pybamm/expression_tree/symbol.py | 8 +++-- src/pybamm/expression_tree/unary_operators.py | 4 +-- src/pybamm/expression_tree/variable.py | 10 +++---- src/pybamm/models/base_model.py | 2 ++ src/pybamm/plotting/quick_plot.py | 2 +- src/pybamm/solvers/base_solver.py | 10 +++---- src/pybamm/solvers/solution.py | 2 +- src/pybamm/solvers/summary_variable.py | 29 ++++++++++++------- src/pybamm/telemetry.py | 6 ++-- src/pybamm/util.py | 2 +- .../test_binary_operators.py | 2 +- tests/unit/test_solvers/test_solution.py | 2 +- 24 files changed, 80 insertions(+), 65 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 4e18b1bd92..88e863538d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -107,7 +107,7 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".ipynb_checkpoints"] # Suppress warnings generated by Sphinx and/or by Sphinx extensions -suppress_warnings = [] +suppress_warnings = [] # type: list[str] # -- Options for HTML output ------------------------------------------------- @@ -174,7 +174,7 @@ html_title = f"{project} v{version} Manual" html_last_updated_fmt = "%Y-%m-%d" html_css_files = ["pybamm.css"] -html_context = {"default_mode": "light"} +html_context = {"default_mode": "light"} # type: dict[str, str | bool | None | ParameterSets] html_use_modindex = True html_copy_source = False html_domain_indices = False @@ -195,7 +195,7 @@ ) # Set canonical URL from the Read the Docs Domain -html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "") +html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "") # type: str # Tell Jinja2 templates the build is running on Read the Docs if os.getenv("READTHEDOCS") == "True": @@ -231,7 +231,7 @@ # Latex figure (float) alignment # # 'figure_align': 'htbp', -} +} # type: dict[str, str] # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, @@ -321,7 +321,7 @@ # made to a notebook, if any. # On local builds, the version is not set, so we use "latest". -notebooks_version = version +notebooks_version = version # type: str | None append_to_url = f"blob/v{notebooks_version}" if (os.environ.get("READTHEDOCS_VERSION") == "latest") or ( diff --git a/examples/scripts/SPM_compare_particle_grid.py b/examples/scripts/SPM_compare_particle_grid.py index 44a6f84edb..f04da20c7a 100644 --- a/examples/scripts/SPM_compare_particle_grid.py +++ b/examples/scripts/SPM_compare_particle_grid.py @@ -48,7 +48,7 @@ disc.process_model(model) # solve model -solutions = [None] * len(models) +solutions = [None] * len(models) # type: Any t_eval = np.linspace(0, 3600, 100) for i, model in enumerate(models): solutions[i] = model.default_solver.solve(model, t_eval) diff --git a/examples/scripts/SPMe_step.py b/examples/scripts/SPMe_step.py index f277c0e790..56ff16d2b2 100644 --- a/examples/scripts/SPMe_step.py +++ b/examples/scripts/SPMe_step.py @@ -43,13 +43,14 @@ time += dt # plot -time_in_seconds = solution["Time [s]"].entries -step_time_in_seconds = step_solution["Time [s]"].entries -voltage = solution["Voltage [V]"].entries -step_voltage = step_solution["Voltage [V]"].entries -plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)") -plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)") -plt.xlabel(r"$t$") -plt.ylabel("Voltage [V]") -plt.legend() -plt.show() +if step_solution is not None: + time_in_seconds = solution["Time [s]"].entries + step_time_in_seconds = step_solution["Time [s]"].entries + voltage = solution["Voltage [V]"].entries + step_voltage = step_solution["Voltage [V]"].entries + plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)") + plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)") + plt.xlabel(r"$t$") + plt.ylabel("Voltage [V]") + plt.legend() + plt.show() diff --git a/examples/scripts/heat_equation.py b/examples/scripts/heat_equation.py index fd01b37f97..0a2655e502 100644 --- a/examples/scripts/heat_equation.py +++ b/examples/scripts/heat_equation.py @@ -106,7 +106,7 @@ def T_exact(x, t): # Plot ------------------------------------------------------------------------ x_nodes = mesh["rod"].nodes # numerical gridpoints xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution -plot_times = np.linspace(0, 1, 5) +plot_times = np.linspace(0, 1, 5) # type: np.ndarray plt.figure(figsize=(15, 8)) cmap = plt.get_cmap("inferno") diff --git a/examples/scripts/minimal_example_of_lookup_tables.py b/examples/scripts/minimal_example_of_lookup_tables.py index 335e9961ac..8ceda74b23 100644 --- a/examples/scripts/minimal_example_of_lookup_tables.py +++ b/examples/scripts/minimal_example_of_lookup_tables.py @@ -34,12 +34,12 @@ def process_2D(name, data): D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df) -def D_s_n(sto, T): +def D_s_n_func(sto, T): name, (x, y) = D_s_n_data return pybamm.Interpolant(x, y, [T, sto], name) -parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n +parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"] diff --git a/pyproject.toml b/pyproject.toml index 0dfb2d8c96..f8b3ff66d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,7 @@ disable_error_code = ["call-overload", "operator"] strict = false warn_unreachable = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +exclude = 'build/' [[tool.mypy.overrides]] module = [ diff --git a/src/pybamm/experiment/experiment.py b/src/pybamm/experiment/experiment.py index ce44457cb2..b9cd4b4d9d 100644 --- a/src/pybamm/experiment/experiment.py +++ b/src/pybamm/experiment/experiment.py @@ -40,7 +40,9 @@ class Experiment: def __init__( self, - operating_conditions: list[str | tuple[str] | BaseStep], + operating_conditions: list[ + str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep + ], period: str | None = None, temperature: float | None = None, termination: list[str] | None = None, diff --git a/src/pybamm/experiment/step/base_step.py b/src/pybamm/experiment/step/base_step.py index 0895cdfaa6..46a367f4f5 100644 --- a/src/pybamm/experiment/step/base_step.py +++ b/src/pybamm/experiment/step/base_step.py @@ -140,7 +140,7 @@ def __init__( self.value = pybamm.Interpolant( t, y, - pybamm.t - pybamm.InputParameter("start time"), + pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type] name="Drive Cycle", ) self.period = np.diff(t).min() diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index be3df653ad..19c7e0e0f4 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -2,7 +2,6 @@ # Binary operator classes # from __future__ import annotations -import numbers import numpy as np import sympy @@ -33,8 +32,8 @@ def _preprocess_binary( raise ValueError("right must be a 1D array") right = pybamm.Vector(right) - # Check both left and right are pybamm Symbols - if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): + # Check right is pybamm Symbol + if not isinstance(right, pybamm.Symbol): raise NotImplementedError( f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}" ) @@ -127,7 +126,7 @@ def create_copy( children = self._children_for_copying(new_children) if not perform_simplifications: - out = self.__class__(children[0], children[1]) + out = self.__class__(*children) else: # creates a new instance using the overloaded binary operator to perform # additional simplifications, rather than just calling the constructor @@ -1538,7 +1537,7 @@ def source( corresponding to a source term in the bulk. """ # Broadcast if left is number - if isinstance(left, numbers.Number): + if isinstance(left, (int, float)): left = pybamm.PrimaryBroadcast(left, "current collector") # force type cast for mypy diff --git a/src/pybamm/expression_tree/broadcasts.py b/src/pybamm/expression_tree/broadcasts.py index 6045c3f3e8..b8fb0ac4fc 100644 --- a/src/pybamm/expression_tree/broadcasts.py +++ b/src/pybamm/expression_tree/broadcasts.py @@ -463,7 +463,7 @@ def __init__( self, child_input: Numeric | pybamm.Symbol, broadcast_domain: DomainType = None, - auxiliary_domains: AuxiliaryDomainType = None, + auxiliary_domains: AuxiliaryDomainType | str = None, broadcast_domains: DomainsType = None, name: str | None = None, ): diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index 62adac5265..b04eccc715 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -473,7 +473,7 @@ def __init__(self, *children, name: Optional[str] = None): if name is None: # Name is the intersection of the children names (should usually make sense # if the children have been named consistently) - name = intersect(children[0].name, children[1].name) + name = intersect(children[0].name, children[1].name) or "" for child in children[2:]: name = intersect(name, child.name) if len(name) == 0: diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 0507b3304e..6b320f1e91 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -20,7 +20,7 @@ def __init__(self): class _SymbolEncoder(json.JSONEncoder): """Converts PyBaMM symbols into a JSON-serialisable format""" - def default(self, node: dict): + def default(self, node: dict | pybamm.Symbol): node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} if isinstance(node, pybamm.Symbol): node_dict.update(node.to_json()) # this doesn't include children @@ -46,7 +46,7 @@ def default(self, node: dict): class _MeshEncoder(json.JSONEncoder): """Converts PyBaMM meshes into a JSON-serialisable format""" - def default(self, node: pybamm.Mesh): + def default(self, node: pybamm.Mesh | pybamm.SubMesh): node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} if isinstance(node, pybamm.Mesh): node_dict.update(node.to_json()) @@ -64,9 +64,6 @@ def default(self, node: pybamm.Mesh): node_dict.update(node.to_json()) return node_dict - node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover - return node_dict # pragma: no cover - class _Empty: """A dummy class to aid deserialisation""" diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index 3f695b768e..bb9081a077 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -66,7 +66,9 @@ def create_object_of_size(size: int, typ="vector"): return np.nan * np.ones((size, size)) -def evaluate_for_shape_using_domain(domains: dict[str, list[str] | str], typ="vector"): +def evaluate_for_shape_using_domain( + domains: dict[str, list[str] | str] | list[str], typ="vector" +): """ Return a vector of the appropriate shape, based on the domains. Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do. @@ -964,7 +966,9 @@ def to_casadi( """ return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs) - def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol: + def _children_for_copying( + self, children: list[Symbol] | None = None + ) -> list[Symbol]: """ Gets existing children for a symbol being copied if they aren't provided. """ diff --git a/src/pybamm/expression_tree/unary_operators.py b/src/pybamm/expression_tree/unary_operators.py index 85c2bc5c54..b814f2467b 100644 --- a/src/pybamm/expression_tree/unary_operators.py +++ b/src/pybamm/expression_tree/unary_operators.py @@ -8,7 +8,7 @@ import sympy import pybamm from pybamm.util import import_optional_dependency -from pybamm.type_definitions import DomainsType +from pybamm.type_definitions import DomainsType, Numeric class UnaryOperator(pybamm.Symbol): @@ -31,7 +31,7 @@ class UnaryOperator(pybamm.Symbol): def __init__( self, name: str, - child: pybamm.Symbol, + child: pybamm.Symbol | Numeric, domains: DomainsType = None, ): if isinstance(child, (float, int, np.number)): diff --git a/src/pybamm/expression_tree/variable.py b/src/pybamm/expression_tree/variable.py index 4d08686245..062f10b6df 100644 --- a/src/pybamm/expression_tree/variable.py +++ b/src/pybamm/expression_tree/variable.py @@ -61,12 +61,12 @@ def __init__( domains: DomainsType = None, bounds: tuple[pybamm.Symbol] | None = None, print_name: str | None = None, - scale: float | pybamm.Symbol | None = 1, - reference: float | pybamm.Symbol | None = 0, + scale: float | int | pybamm.Symbol | None = 1, + reference: float | int | pybamm.Symbol | None = 0, ): - if isinstance(scale, numbers.Number): + if isinstance(scale, (float, int)): scale = pybamm.Scalar(scale) - if isinstance(reference, numbers.Number): + if isinstance(reference, (float, int)): reference = pybamm.Scalar(reference) self._scale = scale self._reference = reference @@ -88,7 +88,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values: tuple[Numeric, Numeric]): + def bounds(self, values: tuple[Numeric, Numeric] | None): if values is None: values = (-np.inf, np.inf) else: diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index b5670320b4..c8258869f9 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"): self.use_jacobian = True self.convert_to_format = "casadi" + self.calculate_sensitivities = [] + # Model is not initially discretised self.is_discretised = False self.y_slices = None diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index ee146a2002..6a4fce9a04 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -126,7 +126,7 @@ def __init__( # Set colors, linestyles, figsize, axis limits # call LoopList to make sure list index never runs out if colors is None: - self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"]) + self.colors = LoopList(["r", "b", "k", "g", "m", "c"]) else: self.colors = LoopList(colors) self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."]) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 49e9b928ae..87ab298ea0 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -94,8 +94,8 @@ def supports_parallel_solve(self): def requires_explicit_sensitivities(self): return True - @root_method.setter - def root_method(self, method): + @root_method.setter # type: ignore[attr-defined, no-redef] + def root_method(self, method) -> None: if method == "casadi": method = pybamm.CasadiAlgebraicSolver(self.root_tol) elif isinstance(method, str): @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from( """ ninputs = len(model.calculate_sensitivities) - initial_conditions = tuple([] for _ in range(ninputs)) + initial_conditions = tuple([] for _ in range(ninputs)) # type: tuple solution = solution.last_state for var in model.initial_conditions: final_state = solution[var.name] @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from( slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()] # sort equations according to slices - concatenated_initial_conditions = [ + concatenated_initial_conditions = tuple( casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))]) for init in initial_conditions - ] + ) return concatenated_initial_conditions def process_t_interp(self, t_interp): diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index 4f17c60d94..d96a667344 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -160,7 +160,7 @@ def __init__( def has_sensitivities(self) -> bool: if isinstance(self._all_sensitivities, bool): return self._all_sensitivities - elif isinstance(self._all_sensitivities, dict): + else: return len(self._all_sensitivities) > 0 def extract_explicit_sensitivities(self): diff --git a/src/pybamm/solvers/summary_variable.py b/src/pybamm/solvers/summary_variable.py index 4c3da92a42..85a1d778f2 100644 --- a/src/pybamm/solvers/summary_variable.py +++ b/src/pybamm/solvers/summary_variable.py @@ -4,7 +4,7 @@ from __future__ import annotations import pybamm import numpy as np -from typing import Any +from typing import Any, cast class SummaryVariables: @@ -40,8 +40,11 @@ def __init__( ): self.user_inputs = user_inputs or {} self.esoh_solver = esoh_solver - self._variables = {} # Store computed variables + + # Store computed variables + self._variables = {} # type: dict[str, float | list[float]] self.cycle_number = np.array([]) + self.cycles: list[SummaryVariables] | None = None model = solution.all_models[0] self._possible_variables = model.summary_variables # minus esoh variables @@ -69,7 +72,7 @@ def _initialize_for_cycles(self, cycle_summary_variables: list[SummaryVariables] self.first_state = None self.last_state = None self.cycles = cycle_summary_variables - self.cycle_number = np.arange(1, len(self.cycles) + 1) + self.cycle_number = np.arange(1, len(self.cycles) + 1, dtype=float) first_cycle = self.cycles[0] self.calc_esoh = first_cycle.calc_esoh self.esoh_solver = first_cycle.esoh_solver @@ -99,7 +102,11 @@ def all_variables(self) -> list[str]: @property def esoh_variables(self) -> list[str] | None: """Return names of all eSOH variables.""" - if self.calc_esoh and self._esoh_variables is None: + if ( + self.esoh_solver is not None + and self.calc_esoh + and self._esoh_variables is None + ): esoh_model = self.esoh_solver._get_electrode_soh_sims_full().model esoh_vars = list(esoh_model.variables.keys()) self._esoh_variables = esoh_vars @@ -123,7 +130,7 @@ def __getitem__(self, key: str) -> float | list[float]: # return it if it exists return self._variables[key] elif key == "Cycle number": - return self.cycle_number + return cast(list[float], self.cycle_number.tolist()) elif key not in self.all_variables: # check it's listed as a summary variable raise KeyError(f"Variable '{key}' is not a summary variable.") @@ -148,10 +155,11 @@ def update(self, var: str): def _update_multiple_cycles(self, var: str, var_lowercase: str): """Creates aggregated summary variables for where more than one cycle exists.""" - var_cycle = [cycle[var] for cycle in self.cycles] - change_var_cycle = [ - cycle[f"Change in {var_lowercase}"] for cycle in self.cycles - ] + cycles = cast(list[SummaryVariables], self.cycles) + var_cycle = cast(list[float], [cycle[var] for cycle in cycles]) + change_var_cycle = cast( + list[float], [cycle[f"Change in {var_lowercase}"] for cycle in cycles] + ) self._variables[var] = var_cycle self._variables[f"Change in {var_lowercase}"] = change_var_cycle @@ -180,8 +188,9 @@ def _get_esoh_variables(self) -> dict[str, float]: Q_p = self.last_state["Positive electrode capacity [A.h]"].data[0] Q_Li = self.last_state["Total lithium capacity in particles [A.h]"].data[0] all_inputs = {**self.user_inputs, "Q_n": Q_n, "Q_p": Q_p, "Q_Li": Q_Li} + esoh_solver = cast(pybamm.lithium_ion.ElectrodeSOHSolver, self.esoh_solver) try: - esoh_sol = self.esoh_solver.solve(inputs=all_inputs) + esoh_sol = esoh_solver.solve(inputs=all_inputs) except pybamm.SolverError as error: # pragma: no cover raise pybamm.SolverError( "Could not solve for eSOH summary variables" diff --git a/src/pybamm/telemetry.py b/src/pybamm/telemetry.py index 3825738d47..d3806492c9 100644 --- a/src/pybamm/telemetry.py +++ b/src/pybamm/telemetry.py @@ -12,15 +12,15 @@ def capture(**kwargs): # pragma: no cover pass -if pybamm.config.check_opt_out(): - _posthog = MockTelemetry() -else: # pragma: no cover +if not pybamm.config.check_opt_out(): _posthog = Posthog( # this is the public, write only API key, so it's ok to include it here project_api_key="phc_acTt7KxmvBsAxaE0NyRd5WfJyNxGvBq1U9HnlQSztmb", host="https://us.i.posthog.com", ) _posthog.log.setLevel("CRITICAL") +else: # pragma: no cover + _posthog = MockTelemetry() def disable(): diff --git a/src/pybamm/util.py b/src/pybamm/util.py index dcab37b0dc..8143558248 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -154,7 +154,7 @@ def search( if not isinstance(keys, (str, list)) or not all( isinstance(k, str) for k in keys - ): + ): # type: ignore[redundant-expr] msg = f"'keys' must be a string or a list of strings, got {type(keys)}" raise TypeError(msg) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index eba3ca1bbd..7bfcb9f83b 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -16,7 +16,7 @@ "secondary": [], "tertiary": [], "quaternary": [], -} +} # type: dict[str, list[str]] class TestBinaryOperators: diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 6f7460dc4d..42ea0b011d 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -257,7 +257,7 @@ def test_copy_with_computed_variables(self): assert ( sol1._variables[k] == sol2._variables[k] for k in sol1._variables.keys() - ) + ) is not None assert sol2.variables_returned is True def test_last_state(self): From 77863b879058f298445e81d0ef7695f04a24d044 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 3 Mar 2025 14:00:11 +0530 Subject: [PATCH 3/7] fixed remaining issues --- CHANGELOG.md | 1 + docs/conf.py | 10 ++--- examples/scripts/SPM_compare_particle_grid.py | 3 +- examples/scripts/heat_equation.py | 2 +- pyproject.toml | 2 +- .../expression_tree/binary_operators.py | 41 ++++++++++++++++++- src/pybamm/expression_tree/broadcasts.py | 15 ++++++- src/pybamm/expression_tree/concatenations.py | 4 +- src/pybamm/expression_tree/functions.py | 5 ++- .../expression_tree/operations/serialise.py | 8 +++- src/pybamm/expression_tree/parameter.py | 6 ++- src/pybamm/solvers/base_solver.py | 2 +- src/pybamm/solvers/idaklu_jax.py | 8 ++-- .../processed_variable_time_integral.py | 2 +- src/pybamm/solvers/summary_variable.py | 26 ++++++------ src/pybamm/telemetry.py | 9 ++-- src/pybamm/util.py | 4 +- .../test_binary_operators.py | 4 +- 18 files changed, 104 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2938a9781..ea9df8b7f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - Moved concentration inside x-averaged when calculating LLI due to LAM variables ([#4858](https://github.com/pybamm-team/PyBaMM/pull/4858)) - Fixed a bug that caused the variable `"Loss of lithium due to {domain} lithium plating"`to have the domain `"current collector"` (should not have any domain at all) if the `"x-average side reactions"` option was set to `"true"`. ([#4844](https://github.com/pybamm-team/PyBaMM/pull/4844)) - Fixed interpolation bug in `pybamm.QuickPlot` with spatial variables. ([#4841](https://github.com/pybamm-team/PyBaMM/pull/4841)) +- Fixed mypy sp check guidelines ([#4887](https://github.com/pybamm-team/PyBaMM/pull/4887)) ## Optimizations diff --git a/docs/conf.py b/docs/conf.py index 88e863538d..4e18b1bd92 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -107,7 +107,7 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".ipynb_checkpoints"] # Suppress warnings generated by Sphinx and/or by Sphinx extensions -suppress_warnings = [] # type: list[str] +suppress_warnings = [] # -- Options for HTML output ------------------------------------------------- @@ -174,7 +174,7 @@ html_title = f"{project} v{version} Manual" html_last_updated_fmt = "%Y-%m-%d" html_css_files = ["pybamm.css"] -html_context = {"default_mode": "light"} # type: dict[str, str | bool | None | ParameterSets] +html_context = {"default_mode": "light"} html_use_modindex = True html_copy_source = False html_domain_indices = False @@ -195,7 +195,7 @@ ) # Set canonical URL from the Read the Docs Domain -html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "") # type: str +html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "") # Tell Jinja2 templates the build is running on Read the Docs if os.getenv("READTHEDOCS") == "True": @@ -231,7 +231,7 @@ # Latex figure (float) alignment # # 'figure_align': 'htbp', -} # type: dict[str, str] +} # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, @@ -321,7 +321,7 @@ # made to a notebook, if any. # On local builds, the version is not set, so we use "latest". -notebooks_version = version # type: str | None +notebooks_version = version append_to_url = f"blob/v{notebooks_version}" if (os.environ.get("READTHEDOCS_VERSION") == "latest") or ( diff --git a/examples/scripts/SPM_compare_particle_grid.py b/examples/scripts/SPM_compare_particle_grid.py index f04da20c7a..a27bb202c4 100644 --- a/examples/scripts/SPM_compare_particle_grid.py +++ b/examples/scripts/SPM_compare_particle_grid.py @@ -2,6 +2,7 @@ # Compare different discretisations in the particle # import argparse +from typing import Any import numpy as np import pybamm import matplotlib.pyplot as plt @@ -48,7 +49,7 @@ disc.process_model(model) # solve model -solutions = [None] * len(models) # type: Any +solutions: Any = [None] * len(models) t_eval = np.linspace(0, 3600, 100) for i, model in enumerate(models): solutions[i] = model.default_solver.solve(model, t_eval) diff --git a/examples/scripts/heat_equation.py b/examples/scripts/heat_equation.py index 0a2655e502..5c19fb3939 100644 --- a/examples/scripts/heat_equation.py +++ b/examples/scripts/heat_equation.py @@ -106,7 +106,7 @@ def T_exact(x, t): # Plot ------------------------------------------------------------------------ x_nodes = mesh["rod"].nodes # numerical gridpoints xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution -plot_times = np.linspace(0, 1, 5) # type: np.ndarray +plot_times: np.ndarray = np.linspace(0, 1, 5) plt.figure(figsize=(15, 8)) cmap = plt.get_cmap("inferno") diff --git a/pyproject.toml b/pyproject.toml index f8b3ff66d0..2d3e28bc7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,7 +255,7 @@ disable_error_code = ["call-overload", "operator"] strict = false warn_unreachable = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] -exclude = 'build/' +exclude = "^(build/|docs/conf\\.py)$" [[tool.mypy.overrides]] module = [ diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index 19c7e0e0f4..1824ae0906 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -112,6 +112,9 @@ def __str__(self): right_str = f"{self.right!s}" return f"{left_str} {self.name} {right_str}" + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return self.__class__(self.name, left, right) + def create_copy( self, new_children: list[pybamm.Symbol] | None = None, @@ -126,7 +129,7 @@ def create_copy( children = self._children_for_copying(new_children) if not perform_simplifications: - out = self.__class__(*children) + out = self._new_instance(children[0], children[1]) else: # creates a new instance using the overloaded binary operator to perform # additional simplifications, rather than just calling the constructor @@ -223,6 +226,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Power(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -272,6 +278,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Addition(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -299,6 +308,9 @@ def __init__( super().__init__("-", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Subtraction(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -328,6 +340,9 @@ def __init__( super().__init__("*", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Multiplication(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -368,6 +383,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return MatrixMultiplication(left, right) + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # We shouldn't need this @@ -417,6 +435,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Division(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule @@ -465,6 +486,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Inner(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -542,6 +566,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Equality should always be multiplied by something else so hopefully don't @@ -600,6 +627,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Heaviside should always be multiplied by something else so hopefully don't @@ -677,6 +707,9 @@ def __init__( ): super().__init__("%", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -719,6 +752,9 @@ def __init__( ): super().__init__("minimum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"minimum({self.left!s}, {self.right!s})" @@ -763,6 +799,9 @@ def __init__( ): super().__init__("maximum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"maximum({self.left!s}, {self.right!s})" diff --git a/src/pybamm/expression_tree/broadcasts.py b/src/pybamm/expression_tree/broadcasts.py index b8fb0ac4fc..8b1d02d3fc 100644 --- a/src/pybamm/expression_tree/broadcasts.py +++ b/src/pybamm/expression_tree/broadcasts.py @@ -78,8 +78,7 @@ def _from_json(cls, snippet): ) def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): - """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" - return self.__class__(child, self.broadcast_domain) + pass class PrimaryBroadcast(Broadcast): @@ -191,6 +190,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class PrimaryBroadcastToEdges(PrimaryBroadcast): """A primary broadcast onto the edges of the domain.""" @@ -321,6 +324,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class SecondaryBroadcastToEdges(SecondaryBroadcast): """A secondary broadcast onto the edges of a domain.""" @@ -438,6 +445,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" raise NotImplementedError + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class TertiaryBroadcastToEdges(TertiaryBroadcast): """A tertiary broadcast onto the edges of a domain.""" diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index b04eccc715..de1761c2fa 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -514,7 +514,7 @@ def substrings(s: str): yield s[i : j + 1] -def intersect(s1: str, s2: str): +def intersect(s1: str, s2: str) -> str: # find all the common strings between two strings all_intersects = set(substrings(s1)) & set(substrings(s2)) # intersect is the longest such intercept @@ -525,7 +525,7 @@ def intersect(s1: str, s2: str): return intersect.lstrip().rstrip() -def simplified_concatenation(*children, name: Optional[str] = None): +def simplified_concatenation(*children, name=None): """Perform simplifications on a concatenation.""" # remove children that are None children = list(filter(lambda x: x is not None, children)) diff --git a/src/pybamm/expression_tree/functions.py b/src/pybamm/expression_tree/functions.py index 4e087e9725..80ba241a94 100644 --- a/src/pybamm/expression_tree/functions.py +++ b/src/pybamm/expression_tree/functions.py @@ -6,7 +6,7 @@ import numpy as np from scipy import special import sympy -from typing import Callable +from typing import Callable, cast from collections.abc import Sequence from typing_extensions import TypeVar @@ -32,7 +32,7 @@ class Function(pybamm.Symbol): def __init__( self, function: Callable, - *children: pybamm.Symbol, + *children: pybamm.Symbol | float | int, name: str | None = None, differentiated_function: Callable | None = None, ): @@ -42,6 +42,7 @@ def __init__( if isinstance(child, (float, int, np.number)): children[idx] = pybamm.Scalar(child) + children = cast(Sequence[pybamm.Symbol], children) if name is not None: self.name = name else: diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 6b320f1e91..153a9f52f8 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Any import pybamm from datetime import datetime @@ -21,7 +22,10 @@ class _SymbolEncoder(json.JSONEncoder): """Converts PyBaMM symbols into a JSON-serialisable format""" def default(self, node: dict | pybamm.Symbol): - node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + node_dict: dict[str, Any] = { + "py/object": str(type(node))[8:-2], + "py/id": id(node), + } if isinstance(node, pybamm.Symbol): node_dict.update(node.to_json()) # this doesn't include children node_dict["children"] = [] @@ -61,7 +65,7 @@ def default(self, node: pybamm.Mesh | pybamm.SubMesh): return node_dict if isinstance(node, pybamm.SubMesh): - node_dict.update(node.to_json()) + node_dict.update(node.to_json()) # type: ignore[attr-defined] return node_dict class _Empty: diff --git a/src/pybamm/expression_tree/parameter.py b/src/pybamm/expression_tree/parameter.py index 14560da0b8..176a775443 100644 --- a/src/pybamm/expression_tree/parameter.py +++ b/src/pybamm/expression_tree/parameter.py @@ -5,7 +5,8 @@ import sys import numpy as np -from typing import Literal +from typing import Literal, cast +from collections.abc import Sequence import sympy @@ -97,7 +98,7 @@ class FunctionParameter(pybamm.Symbol): def __init__( self, name: str, - inputs: dict[str, pybamm.Symbol], + inputs: dict[str, pybamm.Symbol | float | int], diff_variable: pybamm.Symbol | None = None, print_name="calculate", ) -> None: @@ -110,6 +111,7 @@ def __init__( if isinstance(child, (float, int, np.number)): children_list[idx] = pybamm.Scalar(child) + children_list = cast(Sequence[pybamm.Symbol], children_list) domains = self.get_children_domains(children_list) super().__init__(name, children=children_list, domains=domains) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 87ab298ea0..3c4014323b 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from( """ ninputs = len(model.calculate_sensitivities) - initial_conditions = tuple([] for _ in range(ninputs)) # type: tuple + initial_conditions: tuple = tuple([] for _ in range(ninputs)) solution = solution.last_state for var in model.initial_conditions: final_state = solution[var.name] diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index df6a056750..50208e1bf8 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -258,7 +258,7 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: np.ndarray = None, + t: Union[np.ndarray, None] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -291,7 +291,7 @@ def jax_value( def jax_grad( self, - t: np.ndarray = None, + t: Union[np.ndarray, None] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -464,13 +464,11 @@ def _jax_vjp_impl( logger.debug(f" py:invar: {type(invar)}, {invar}") logger.debug(f" py:primals: {type(primals)}, {primals}") - t = primals[0] + t = np.asarray(primals[0]) inputs = primals[1:] if isinstance(invar, float): invar = round(invar) - if isinstance(t, float): - t = np.array(t) if t.ndim == 0 or (t.ndim == 1 and t.shape[0] == 1): # scalar time input diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index 4fcdfb56ba..f801602d98 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -7,7 +7,7 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: np.ndarray + initial_condition: np.ndarray | float discrete_times: Optional[np.ndarray] @staticmethod diff --git a/src/pybamm/solvers/summary_variable.py b/src/pybamm/solvers/summary_variable.py index 85a1d778f2..2594ce3f67 100644 --- a/src/pybamm/solvers/summary_variable.py +++ b/src/pybamm/solvers/summary_variable.py @@ -42,13 +42,13 @@ def __init__( self.esoh_solver = esoh_solver # Store computed variables - self._variables = {} # type: dict[str, float | list[float]] + self._variables: dict[str, float | list[float]] = {} self.cycle_number = np.array([]) self.cycles: list[SummaryVariables] | None = None - + self._all_variables: list[str] | None = None model = solution.all_models[0] self._possible_variables = model.summary_variables # minus esoh variables - self._esoh_variables = None # Store eSOH variable names + self._esoh_variables: list[str] | None = None # Store eSOH variable names # Flag if eSOH calculations are needed self.calc_esoh = ( @@ -84,20 +84,18 @@ def all_variables(self) -> list[str]: Return names of all possible summary variables, including eSOH variables if appropriate. """ - try: + if self._all_variables is not None: return self._all_variables - except AttributeError: - base_vars = self._possible_variables.copy() - base_vars.extend( - f"Change in {var[0].lower() + var[1:]}" - for var in self._possible_variables - ) + base_vars = self._possible_variables.copy() + base_vars.extend( + f"Change in {var[0].lower() + var[1:]}" for var in self._possible_variables + ) - if self.calc_esoh: - base_vars.extend(self.esoh_variables) + if self.calc_esoh: + base_vars.extend(self.esoh_variables) - self._all_variables = base_vars - return self._all_variables + self._all_variables = cast(list[str], base_vars) + return self._all_variables @property def esoh_variables(self) -> list[str] | None: diff --git a/src/pybamm/telemetry.py b/src/pybamm/telemetry.py index d3806492c9..ac5103139b 100644 --- a/src/pybamm/telemetry.py +++ b/src/pybamm/telemetry.py @@ -1,3 +1,4 @@ +from typing import cast from posthog import Posthog import pybamm import sys @@ -12,15 +13,15 @@ def capture(**kwargs): # pragma: no cover pass -if not pybamm.config.check_opt_out(): +if pybamm.config.check_opt_out(): + _posthog = MockTelemetry() +else: # pragma: no cover _posthog = Posthog( # this is the public, write only API key, so it's ok to include it here project_api_key="phc_acTt7KxmvBsAxaE0NyRd5WfJyNxGvBq1U9HnlQSztmb", host="https://us.i.posthog.com", ) - _posthog.log.setLevel("CRITICAL") -else: # pragma: no cover - _posthog = MockTelemetry() + cast(Posthog, _posthog).log.setLevel("CRITICAL") def disable(): diff --git a/src/pybamm/util.py b/src/pybamm/util.py index 8143558248..b1a7c0dd80 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -152,9 +152,9 @@ def search( Default is 0.4 """ - if not isinstance(keys, (str, list)) or not all( + if not isinstance(keys, (str, list)) or not all( # type: ignore[redundant-expr] isinstance(k, str) for k in keys - ): # type: ignore[redundant-expr] + ): msg = f"'keys' must be a string or a list of strings, got {type(keys)}" raise TypeError(msg) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 7bfcb9f83b..ceed90fda2 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -11,12 +11,12 @@ import pybamm import sympy -EMPTY_DOMAINS = { +EMPTY_DOMAINS: dict[str, list[str]] = { "primary": [], "secondary": [], "tertiary": [], "quaternary": [], -} # type: dict[str, list[str]] +} class TestBinaryOperators: From 541382819deabcbb01a03fa1de60ee29345d6138 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 3 Mar 2025 18:43:37 +0530 Subject: [PATCH 4/7] fixed issue in py 3.9 and modified tests --- src/pybamm/expression_tree/binary_operators.py | 14 +++++++------- src/pybamm/expression_tree/broadcasts.py | 2 +- .../solvers/processed_variable_time_integral.py | 2 +- .../test_operations/test_copy.py | 6 +++++- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index 1824ae0906..d58b211ae5 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -113,7 +113,7 @@ def __str__(self): return f"{left_str} {self.name} {right_str}" def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return self.__class__(self.name, left, right) + return self.__class__(self.name, left, right) # pragma: no cover def create_copy( self, @@ -384,7 +384,7 @@ def __init__( super().__init__("@", left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return MatrixMultiplication(left, right) + return MatrixMultiplication(left, right) # pragma: no cover def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" @@ -487,7 +487,7 @@ def __init__( super().__init__("inner product", left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return Inner(left, right) + return Inner(left, right) # pragma: no cover def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" @@ -628,7 +628,7 @@ def __init__( super().__init__(name, left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return Equality(left, right) + return _Heaviside(left, right) # pragma: no cover def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" @@ -708,7 +708,7 @@ def __init__( super().__init__("%", left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return Equality(left, right) + return Modulo(left, right) def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" @@ -753,7 +753,7 @@ def __init__( super().__init__("minimum", left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return Equality(left, right) + return Minimum(left, right) def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" @@ -800,7 +800,7 @@ def __init__( super().__init__("maximum", left, right) def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return Equality(left, right) + return Maximum(left, right) def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" diff --git a/src/pybamm/expression_tree/broadcasts.py b/src/pybamm/expression_tree/broadcasts.py index 8b1d02d3fc..1fabef127c 100644 --- a/src/pybamm/expression_tree/broadcasts.py +++ b/src/pybamm/expression_tree/broadcasts.py @@ -78,7 +78,7 @@ def _from_json(cls, snippet): ) def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): - pass + pass # pragma: no cover class PrimaryBroadcast(Broadcast): diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index f801602d98..124cc3e407 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -7,7 +7,7 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: np.ndarray | float + initial_condition: Union[np.ndarray, float] discrete_times: Optional[np.ndarray] @staticmethod diff --git a/tests/unit/test_expression_tree/test_operations/test_copy.py b/tests/unit/test_expression_tree/test_operations/test_copy.py index f0d59a1fe1..67a884771c 100644 --- a/tests/unit/test_expression_tree/test_operations/test_copy.py +++ b/tests/unit/test_expression_tree/test_operations/test_copy.py @@ -79,6 +79,7 @@ def test_symbol_create_copy_new_children(self): a * b, a / b, a**b, + b % a, pybamm.minimum(a, b), pybamm.maximum(a, b), pybamm.Equality(a, b), @@ -89,12 +90,15 @@ def test_symbol_create_copy_new_children(self): b * a, b / a, b**a, + b % a, pybamm.minimum(b, a), pybamm.maximum(b, a), pybamm.Equality(b, a), ], ): - new_symbol = symbol_ab.create_copy(new_children=[b, a]) + new_symbol = symbol_ab.create_copy( + new_children=[b, a], perform_simplifications=False + ) assert new_symbol == symbol_ba assert new_symbol.print_name == symbol_ba.print_name From f7beff0c7e922799ef1d50f479304c27c103817f Mon Sep 17 00:00:00 2001 From: Rishab Kumar Jha Date: Mon, 3 Mar 2025 19:33:39 +0530 Subject: [PATCH 5/7] Update CHANGELOG.md Co-authored-by: Saransh Chopra --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea9df8b7f4..d2938a9781 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,6 @@ - Moved concentration inside x-averaged when calculating LLI due to LAM variables ([#4858](https://github.com/pybamm-team/PyBaMM/pull/4858)) - Fixed a bug that caused the variable `"Loss of lithium due to {domain} lithium plating"`to have the domain `"current collector"` (should not have any domain at all) if the `"x-average side reactions"` option was set to `"true"`. ([#4844](https://github.com/pybamm-team/PyBaMM/pull/4844)) - Fixed interpolation bug in `pybamm.QuickPlot` with spatial variables. ([#4841](https://github.com/pybamm-team/PyBaMM/pull/4841)) -- Fixed mypy sp check guidelines ([#4887](https://github.com/pybamm-team/PyBaMM/pull/4887)) ## Optimizations From fba51e24db3474a86bd7e0d1a90db45f7585d407 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 3 Mar 2025 20:07:04 +0530 Subject: [PATCH 6/7] minor changes --- src/pybamm/expression_tree/binary_operators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index d58b211ae5..2452c47be8 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -626,9 +626,10 @@ def __init__( ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) + self.name = name def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: - return _Heaviside(left, right) # pragma: no cover + return _Heaviside(self.name, left, right) # pragma: no cover def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" From e1b1aec26336786fe3afd7e9fd0d15c87e3aa372 Mon Sep 17 00:00:00 2001 From: rishab Date: Sun, 9 Mar 2025 10:39:25 +0530 Subject: [PATCH 7/7] minor changes --- examples/scripts/SPM_compare_particle_grid.py | 2 +- examples/scripts/heat_equation.py | 3 ++- src/pybamm/expression_tree/binary_operators.py | 4 ++-- src/pybamm/solvers/processed_variable_time_integral.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/scripts/SPM_compare_particle_grid.py b/examples/scripts/SPM_compare_particle_grid.py index a27bb202c4..0fa7cdb6c1 100644 --- a/examples/scripts/SPM_compare_particle_grid.py +++ b/examples/scripts/SPM_compare_particle_grid.py @@ -49,7 +49,7 @@ disc.process_model(model) # solve model -solutions: Any = [None] * len(models) +solutions: list[Any] = [None] * len(models) t_eval = np.linspace(0, 3600, 100) for i, model in enumerate(models): solutions[i] = model.default_solver.solve(model, t_eval) diff --git a/examples/scripts/heat_equation.py b/examples/scripts/heat_equation.py index 5c19fb3939..4c2ac99ca4 100644 --- a/examples/scripts/heat_equation.py +++ b/examples/scripts/heat_equation.py @@ -5,6 +5,7 @@ import pybamm import numpy as np import matplotlib.pyplot as plt +import numpy.typing as npt # Numerical solution ---------------------------------------------------------- @@ -106,7 +107,7 @@ def T_exact(x, t): # Plot ------------------------------------------------------------------------ x_nodes = mesh["rod"].nodes # numerical gridpoints xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution -plot_times: np.ndarray = np.linspace(0, 1, 5) +plot_times: npt.NDArray = np.linspace(0, 1, 5) plt.figure(figsize=(15, 8)) cmap = plt.get_cmap("inferno") diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index c0dd6017bd..4dd82a2c71 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -33,8 +33,8 @@ def _preprocess_binary( raise ValueError("right must be a 1D array") right = pybamm.Vector(right) - # Check right is pybamm Symbol - if not isinstance(right, pybamm.Symbol): + # Check both left and right are pybamm Symbols + if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr] raise NotImplementedError( f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}" ) diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index 37f76ae2be..077b079775 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -7,7 +7,7 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: Union[npt.NDArraynp, float] + initial_condition: Union[npt.NDArray, float] discrete_times: Optional[npt.NDArray] @staticmethod