Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/scripts/SPM_compare_particle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Compare different discretisations in the particle
#
import argparse
from typing import Any
import numpy as np
import pybamm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -48,7 +49,7 @@
disc.process_model(model)

# solve model
solutions = [None] * len(models)
solutions: list[Any] = [None] * len(models)
t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solutions[i] = model.default_solver.solve(model, t_eval)
Expand Down
21 changes: 11 additions & 10 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
time += dt

# plot
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
if step_solution is not None:
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
3 changes: 2 additions & 1 deletion examples/scripts/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pybamm
import numpy as np
import matplotlib.pyplot as plt
import numpy.typing as npt

# Numerical solution ----------------------------------------------------------

Expand Down Expand Up @@ -106,7 +107,7 @@ def T_exact(x, t):
# Plot ------------------------------------------------------------------------
x_nodes = mesh["rod"].nodes # numerical gridpoints
xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution
plot_times = np.linspace(0, 1, 5)
plot_times: npt.NDArray = np.linspace(0, 1, 5)

plt.figure(figsize=(15, 8))
cmap = plt.get_cmap("inferno")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/minimal_example_of_lookup_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def process_2D(name, data):
D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df)


def D_s_n(sto, T):
def D_s_n_func(sto, T):
name, (x, y) = D_s_n_data
return pybamm.Interpolant(x, y, [T, sto], name)


parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n
parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func

k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"]

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
Copy link
Member

@Saransh-cpp Saransh-cpp Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how big this PR is, it would actually be better to split it into multiple PRs, each one adding a new config option in pyproject.toml.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so should I go ahead then and make a PR on one of the config first? or edit this one accordingly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can keep this PR for one config, and add other configs in subsequent PRs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I'll create seperate PRs for different configs and then keep this one for the end, I think that would be faster for me

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed many of the warn_unreachable error depends on errors from enable_error_code config, they're related to each other and their are total of 77 errors out of which 57 are from enable_error_code so I think creating a seperate PR would still be almost as big as this one, so should I still proceed with it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think anything that reduces the diff and keeps this PR scoped to a specific change (or a few of them) would be great. Thanks for investigating!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created the PR with just enable_error_code config: #4891, sorry for the delay lab tests going on

warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = "^(build/|docs/conf\\.py)$"

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type]
name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
47 changes: 43 additions & 4 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -35,7 +34,7 @@ def _preprocess_binary(
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr]
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -114,6 +113,9 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return self.__class__(self.name, left, right) # pragma: no cover

def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
Expand All @@ -128,7 +130,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self._new_instance(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -225,6 +227,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("**", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Power(left, right)

Comment on lines +230 to +232
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need new _new_instance methods? Could you document this in an inline comment somewhere, or add it in the PR description (better)?

Copy link
Member Author

@Rishab87 Rishab87 Mar 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added new_instance because earlier when we were using __ class __ it showed a third arg was not getting passed:

error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]

but this function was always getting called from child classes of BinaryOperator which don't need to pass 3 arguments, so i thought it was better to make a new_instance method which can be overirded in child classes

so should I update the PR description or think of some other approach?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And for broadcasts.py self.broadcast_domain was not defined in parent class they were params of child classes and _unary_new_copy function was getting called from child class instance always so I just overrided this function in child classes, apart from these changes all other changes are mostly type changes

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -274,6 +279,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("+", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Addition(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) + self.right.diff(variable)
Expand Down Expand Up @@ -301,6 +309,9 @@ def __init__(

super().__init__("-", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Subtraction(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) - self.right.diff(variable)
Expand Down Expand Up @@ -330,6 +341,9 @@ def __init__(

super().__init__("*", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Multiplication(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -370,6 +384,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("@", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return MatrixMultiplication(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# We shouldn't need this
Expand Down Expand Up @@ -419,6 +436,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("/", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Division(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply quotient rule
Expand Down Expand Up @@ -467,6 +487,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("inner product", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Inner(left, right) # pragma: no cover

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -544,6 +567,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("==", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Equality(left, right)

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Equality should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -601,6 +627,10 @@ def __init__(
):
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__(name, left, right)
self.name = name

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return _Heaviside(self.name, left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
Expand Down Expand Up @@ -679,6 +709,9 @@ def __init__(
):
super().__init__("%", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Modulo(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -721,6 +754,9 @@ def __init__(
):
super().__init__("minimum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Minimum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"minimum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -765,6 +801,9 @@ def __init__(
):
super().__init__("maximum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Maximum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"maximum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -1539,7 +1578,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
17 changes: 14 additions & 3 deletions src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _from_json(cls, snippet):
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)
pass # pragma: no cover


class PrimaryBroadcast(Broadcast):
Expand Down Expand Up @@ -191,6 +190,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"""A primary broadcast onto the edges of the domain."""
Expand Down Expand Up @@ -321,6 +324,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"""A secondary broadcast onto the edges of a domain."""
Expand Down Expand Up @@ -438,6 +445,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class TertiaryBroadcastToEdges(TertiaryBroadcast):
"""A tertiary broadcast onto the edges of a domain."""
Expand All @@ -463,7 +474,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
name = intersect(children[0].name, children[1].name) or ""
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
Expand Down Expand Up @@ -515,7 +515,7 @@ def substrings(s: str):
yield s[i : j + 1]


def intersect(s1: str, s2: str):
def intersect(s1: str, s2: str) -> str:
# find all the common strings between two strings
all_intersects = set(substrings(s1)) & set(substrings(s2))
# intersect is the longest such intercept
Expand All @@ -526,7 +526,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children, name: Optional[str] = None):
def simplified_concatenation(*children, name=None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.typing as npt
from scipy import special
import sympy
from typing import Callable
from typing import Callable, cast
from collections.abc import Sequence
from typing_extensions import TypeVar

Expand All @@ -33,7 +33,7 @@ class Function(pybamm.Symbol):
def __init__(
self,
function: Callable,
*children: pybamm.Symbol,
*children: pybamm.Symbol | float | int,
name: str | None = None,
differentiated_function: Callable | None = None,
):
Expand All @@ -43,6 +43,7 @@ def __init__(
if isinstance(child, (float, int, np.number)):
children[idx] = pybamm.Scalar(child)

children = cast(Sequence[pybamm.Symbol], children)
if name is not None:
self.name = name
else:
Expand Down
Loading
Loading