Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions suncal/common/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
import sympy


_lambdify_cache = {}


def _lambdify(keys, expr):
''' Return a cached lambdified function for expr over keys.

Args:
keys: tuple of variable name strings (argument order for lambdify)
expr: sympy expression
'''
cache_key = (expr, keys)
if cache_key not in _lambdify_cache:
_lambdify_cache[cache_key] = sympy.lambdify(keys, expr, 'numpy')
return _lambdify_cache[cache_key]


def matmul(a, b):
''' Matrix multiply. Manually looped to preserve units since Pint
doesn't allow matrices with different units on each element.
Expand Down Expand Up @@ -64,11 +80,12 @@ def eval_matrix(U, values):
Returns:
list of list of floats
'''
keys = tuple(values.keys())
U_eval = []
for row in U:
U_row = []
for expr in row:
df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities
df = _lambdify(keys, expr)
U_row.append(df(**values))
U_eval.append(U_row)
return U_eval
Expand All @@ -84,9 +101,10 @@ def eval_list(U, values):
Returns:
list of floats
'''
keys = tuple(values.keys())
U_eval = []
for expr in U:
df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities
df = _lambdify(keys, expr)
U_eval.append(df(**values))
return U_eval

Expand All @@ -101,8 +119,9 @@ def eval_dict(U, values):
Returns:
dictionary of {name:float}
'''
keys = tuple(values.keys())
U_eval = {}
for name, expr in U.items():
df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities
df = _lambdify(keys, expr)
U_eval[name] = df(**values)
return U_eval