diff --git a/suncal/common/matrix.py b/suncal/common/matrix.py index 20c8b06..aaf4809 100644 --- a/suncal/common/matrix.py +++ b/suncal/common/matrix.py @@ -8,6 +8,22 @@ import sympy +_lambdify_cache = {} + + +def _lambdify(keys, expr): + ''' Return a cached lambdified function for expr over keys. + + Args: + keys: tuple of variable name strings (argument order for lambdify) + expr: sympy expression + ''' + cache_key = (expr, keys) + if cache_key not in _lambdify_cache: + _lambdify_cache[cache_key] = sympy.lambdify(keys, expr, 'numpy') + return _lambdify_cache[cache_key] + + def matmul(a, b): ''' Matrix multiply. Manually looped to preserve units since Pint doesn't allow matrices with different units on each element. @@ -64,11 +80,12 @@ def eval_matrix(U, values): Returns: list of list of floats ''' + keys = tuple(values.keys()) U_eval = [] for row in U: U_row = [] for expr in row: - df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities + df = _lambdify(keys, expr) U_row.append(df(**values)) U_eval.append(U_row) return U_eval @@ -84,9 +101,10 @@ def eval_list(U, values): Returns: list of floats ''' + keys = tuple(values.keys()) U_eval = [] for expr in U: - df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities + df = _lambdify(keys, expr) U_eval.append(df(**values)) return U_eval @@ -101,8 +119,9 @@ def eval_dict(U, values): Returns: dictionary of {name:float} ''' + keys = tuple(values.keys()) U_eval = {} for name, expr in U.items(): - df = sympy.lambdify(values.keys(), expr, 'numpy') # Can't subs() with pint Quantities + df = _lambdify(keys, expr) U_eval[name] = df(**values) return U_eval