Numba CAReduce: respect acc_dtype #1773
Merged · +190 −83
@@ -2,7 +2,6 @@
from hashlib import sha256
from textwrap import dedent, indent

import numba
import numpy as np
from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple

@@ -14,6 +13,7 @@
)
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
    create_tuple_string,
    numba_funcify_and_cache_key,
    register_funcify_and_cache_key,
    register_funcify_default_op_cache_key,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):

def create_multiaxis_reducer(
    scalar_op,
    *,
    identity,
    axes,
    ndim,
    dtype,
    acc_dtype=None,
    out_dtype,
    keepdims: bool = False,
):
    r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,46 @@ def create_multiaxis_reducer(
    .. code-block:: python

        def careduce_add(x):
            # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
            x_shape = x.shape
            res_shape = x_shape[2]
            res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
            res_shape = (x_shape[0], x_shape[1])
            # identity = 0.0
            res = np.full(res_shape, identity, dtype=np.float64)
            for i0 in range(x_shape[0]):
                for i1 in range(x_shape[1]):
                    for i2 in range(x_shape[2]):
                        res[i0, i1] += x[i0, i1, i2]
            return res

    If accumulation dtype differs from output_dtype

    .. code-block:: python

        def careduce_add(x):
            x_shape = x.shape
            res_shape = (x_shape[0], x_shape[1])
            # identity = 0.0
            res = np.full(res_shape, identity, dtype=np.float64)
            for i0 in range(x_shape[0]):
                for i1 in range(x_shape[1]):
                    for i2 in range(x_shape[2]):
                        res[i2] += x[i0, i1, i2]
                        res[i0, i1] += x[i0, i1, i2]
            return res.astype(np.int32)

    Full reductions accumulate on scalars

    .. code-block:: python

        def careduce_mul(x):
            x_shape = x.shape
            res_shape = ()
            # identity = 1.0
            res = identity
            for i0 in range(x_shape[0]):
                for i1 in range(x_shape[1]):
                    for i2 in range(x_shape[2]):
                        res *= x[i0, i1, i2]
            return np.array(res, dtype=np.int32)
            return res

    Parameters
    ==========
@@ -160,7 +191,9 @@ def careduce_add(x):
        The axes to reduce.
    ndim:
        The number of dimensions of the input variable.
    dtype:
    acc_dtype: dtype, optional
        The data type used during accumulation. Defaults to out_dtype if not provided
    out_dtype:
        The data type of the result.
    keepdims: boolean, default False
        Whether to keep the reduced dimensions.
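As an aside (not part of the diff), a hypothetical call using the new keyword-only signature, mirroring how numba_funcify_CAReduce invokes it further down; `add_as` and `numba_basic` are names already available in this module, and the dtypes are illustrative:

    careduce_py_fn = create_multiaxis_reducer(
        add_as,               # the Add scalar op used elsewhere in this module
        identity=0.0,
        axes=(0, 1),          # reduce the first two axes of a 3d input
        ndim=3,
        acc_dtype="float64",  # accumulate in float64 ...
        out_dtype="float32",  # ... but return float32
    )
    careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)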
@@ -178,19 +211,23 @@ def careduce_add(x):
            "Cannot keep multiple dimensions when reducing multiple axes"
        )

    out_dtype = np.dtype(out_dtype)
    acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)
    # Numba doesn't allow converting complex to real with a simple `astype`
    complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c"
    out_dtype_str = f"np.{out_dtype.name}"
    acc_dtype_str = f"np.{acc_dtype.name}"
    careduce_fn_name = f"careduce_{scalar_op}"

    identity = str(identity)
    if identity == "inf":
        identity = "np.inf"
    elif identity == "-inf":
        identity = "-np.inf"

    global_env = {
        "np": np,
        "numba_basic": numba_basic,
        "out_dtype": dtype,
    }
    if acc_dtype.kind in "ui" and not np.isfinite(identity):
        if np.isposinf(identity):
            identity = np.iinfo(acc_dtype).max
        else:
            identity = np.iinfo(acc_dtype).min

    # Make sure it has the correct dtype
    identity = getattr(np, acc_dtype.name)(identity)

    complete_reduction = len(axes) == ndim
    kept_axis = tuple(i for i in range(ndim) if i not in axes)
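For context (not from the PR), a minimal standalone NumPy sketch of the identity clamping done above: integer accumulation dtypes cannot hold an infinite identity (e.g. -inf for Maximum), so it is clamped to the dtype's bounds and cast to the accumulation dtype:

    import numpy as np

    def clamp_identity(identity, acc_dtype):
        # Sketch only: +/-inf cannot be represented in an integer dtype, so clamp
        # to the dtype bounds, then cast so the accumulator starts with acc_dtype.
        acc_dtype = np.dtype(acc_dtype)
        if acc_dtype.kind in "ui" and not np.isfinite(identity):
            identity = (
                np.iinfo(acc_dtype).max if np.isposinf(identity) else np.iinfo(acc_dtype).min
            )
        return getattr(np, acc_dtype.name)(identity)

    print(clamp_identity(-np.inf, "uint8"))  # 0 (identity of Maximum over uint8)
    print(clamp_identity(np.inf, "int32"))   # 2147483647 (identity of Minimum over int32)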
@@ -208,17 +245,23 @@ def careduce_add(x):
        scalar_op, res_indices, "res", f"x[{arr_indices}]"
    )

    res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
    res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis])
    if complete_reduction and ndim > 0:
        # We accumulate on a scalar, not an array
        res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
        res_creator = "identity"
        inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
        return_obj = "np.asarray(res)"
        if complex_to_real:
            return_obj = f"np.array(res).real.astype({out_dtype_str})"
        else:
            return_obj = f"np.array(res, dtype={out_dtype_str})"
    else:
        res_creator = (
            f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
        )
        return_obj = "res"
        res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})"
        if complex_to_real:
            return_obj = f"res.real.astype({out_dtype_str})"
        else:
            return_obj = (
                "res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})"
            )

    if keepdims:
        [axis] = axes
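To illustrate the complex-to-real branch, a hypothetical example of what the generated source could look like for a full Add reduction of a 2d input with acc_dtype=complex128 and out_dtype=float64 (a sketch of the template above, not the exact codegen output):

    # Hypothetical generated source (illustrative only); `identity` and `np`
    # are provided through the globals passed to compile_numba_function_src.
    def careduce_add(x):
        x_shape = x.shape
        res_shape = ()
        # identity = 0j
        res = identity  # scalar accumulator in the complex accumulation dtype
        for i0 in range(x_shape[0]):
            for i1 in range(x_shape[1]):
                res += x[i0, i1]
        # Numba rejects a direct complex -> real astype, hence .real first
        return np.array(res).real.astype(np.float64)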
@@ -229,6 +272,7 @@ def careduce_add(x):
        def {careduce_fn_name}(x):
            x_shape = x.shape
            res_shape = {res_shape}
            # identity = {identity}
            res = {res_creator}
        """
    )

Review thread on the `# identity = {identity}` line — Member: "I don't get it." Author: "just for readability".
@@ -238,13 +282,12 @@ def {careduce_fn_name}(x):
            " " * (4 + 4 * axis),
        )
    careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
    careduce_def_src += "\n\n"
    careduce_def_src += "\n"
    careduce_def_src += indent(f"return {return_obj}", " " * 4)

    careduce_fn = compile_numba_function_src(
        careduce_def_src, careduce_fn_name, {**globals(), **global_env}
        careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity}
    )

    return careduce_fn
@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
        acc_dtype = op.acc_dtype
    else:
        acc_dtype = node.outputs[0].type.dtype
    np_acc_dtype = np.dtype(acc_dtype)

    scalar_op_identity = op.scalar_op.identity
    if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
        if np.isposinf(scalar_op_identity):
            scalar_op_identity = np.iinfo(np_acc_dtype).max
        else:
            scalar_op_identity = np.iinfo(np_acc_dtype).min
    # Make sure it has the correct dtype
    scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)

    out_dtype = np.dtype(node.outputs[0].type.dtype)

    if isinstance(op, Sum) and node.inputs[0].ndim == len(axes):
    if (
        isinstance(op, Sum)
        and node.inputs[0].ndim == len(axes)
        and out_dtype == acc_dtype
    ):
        # Slightly faster for this case
        @numba_basic.numba_njit
        def impl_sum(array):
            return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
            return np.array(array.sum())

        careduce_fn = impl_sum  # Some tests look for this name

    else:
        ndim = node.inputs[0].ndim
        careduce_py_fn = create_multiaxis_reducer(
            op.scalar_op,
            scalar_op_identity,
            axes,
            ndim,
            out_dtype,
            identity=op.scalar_op.identity,
            axes=axes,
            ndim=ndim,
            acc_dtype=acc_dtype,
            out_dtype=out_dtype,
        )
        careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)

    cache_version = 1
    careduce_key = sha256(
        str(
            (type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item())
            (
                type(op),
                type(op.scalar_op),
                axes,
                out_dtype,
                acc_dtype,
                op.scalar_op.identity,
                cache_version,
            )
        ).encode()
    ).hexdigest()
    return careduce_fn, careduce_key
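For reference, the cache-key scheme above amounts to hashing everything that affects the generated code plus a manual version bump; a runnable sketch with illustrative field values (the real key hashes the op and scalar-op types themselves):

    from hashlib import sha256

    cache_version = 1
    # Illustrative stand-ins for type(op), type(op.scalar_op), axes, dtypes, identity
    key_fields = ("CAReduce", "Add", (0, 1), "float32", "float64", 0.0, cache_version)
    careduce_key = sha256(str(key_fields).encode()).hexdigest()
    print(careduce_key[:16])  # stable digest; bump cache_version to invalidate caches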
@@ -449,18 +496,26 @@ def dimshuffle(x):

@register_funcify_default_op_cache_key(Softmax)
def numba_funcify_Softmax(op, node, **kwargs):
    x_at = node.inputs[0]
    x_dtype = x_at.type.numpy_dtype
    x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
    ndim = node.inputs[0].type.ndim
    inp_dtype = node.inputs[0].type.numpy_dtype
    axis = op.axis

    if axis is not None:
        axis = normalize_axis_index(axis, x_at.ndim)
    if ndim > 1 and axis is not None:
        reduce_max_py = create_multiaxis_reducer(
            maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
            maximum,
            identity=-np.inf,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )
        reduce_sum_py = create_multiaxis_reducer(
            add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
            add_as,
            identity=0.0,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )

        jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -470,66 +525,72 @@ def numba_funcify_Softmax(op, node, **kwargs):
        reduce_max = np.max
        reduce_sum = np.sum

    def softmax_py_fn(x):
    @numba_basic.numba_njit(boundscheck=False)
    def softmax(x):
        z = reduce_max(x)
        e_x = np.exp(x - z)
        w = reduce_sum(e_x)
        sm = e_x / w
        return sm

    softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False)

    return softmax
    cache_version = 1
    return softmax, cache_version
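For readers unfamiliar with the trick, a plain-NumPy sketch of the softmax computed above (not the generated reducers): subtracting the per-axis maximum before exponentiating keeps exp() from overflowing:

    import numpy as np

    def softmax(x, axis=-1):
        z = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - z)  # the largest exponent is 0, so exp() cannot overflow
        return e_x / np.sum(e_x, axis=axis, keepdims=True)

    print(softmax(np.array([[1.0, 2.0, 3.0]])))  # each row sums to 1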
@register_funcify_default_op_cache_key(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
    sm_at = node.inputs[1]
    sm_dtype = sm_at.type.numpy_dtype
    sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
    ndim = node.inputs[0].type.ndim
    inp_dtype = node.inputs[0].type.numpy_dtype

    axis = op.axis
    if axis is not None:
        axis = normalize_axis_index(axis, sm_at.ndim)
    if ndim > 1 and axis is not None:
        reduce_sum_py = create_multiaxis_reducer(
            add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
            add_as,
            identity=0.0,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )

        jit_fn = numba_basic.numba_njit(boundscheck=False)
        reduce_sum = jit_fn(reduce_sum_py)
    else:
        reduce_sum = np.sum

    def softmax_grad_py_fn(dy, sm):
    @numba_basic.numba_njit(boundscheck=False)
    def softmax_grad(dy, sm):
        dy_times_sm = dy * sm
        sum_dy_times_sm = reduce_sum(dy_times_sm)
        dx = dy_times_sm - sum_dy_times_sm * sm
        return dx

    softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False)

    return softmax_grad
    cache_version = 1
    return softmax_grad, cache_version
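The gradient formula above written out as a plain-NumPy sketch, with keepdims sums standing in for the keepdims=True reducers:

    import numpy as np

    def softmax_grad(dy, sm, axis=-1):
        # dx = dy * sm - sum(dy * sm) * sm, summing over the softmax axis
        dy_times_sm = dy * sm
        return dy_times_sm - np.sum(dy_times_sm, axis=axis, keepdims=True) * sm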
@register_funcify_default_op_cache_key(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
    x_at = node.inputs[0]
    x_dtype = x_at.type.numpy_dtype
    x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
    ndim = node.inputs[0].type.ndim
    inp_dtype = node.inputs[0].type.numpy_dtype
    axis = op.axis

    if axis is not None:
        axis = normalize_axis_index(axis, x_at.ndim)
    if ndim > 1 and axis is not None:
        reduce_max_py = create_multiaxis_reducer(
            maximum,
            -np.inf,
            (axis,),
            x_at.ndim,
            x_dtype,
            identity=-np.inf,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )
        reduce_sum_py = create_multiaxis_reducer(
            add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
            add_as,
            identity=0.0,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )

        jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
        reduce_max = np.max
        reduce_sum = np.sum

    def log_softmax_py_fn(x):
    @numba_basic.numba_njit(boundscheck=False)
    def log_softmax(x):
        xdev = x - reduce_max(x)
        lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
        return lsm

    log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False)
    return log_softmax
    cache_version = 1
    return log_softmax, cache_version
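And the log-softmax variant above as a plain-NumPy sketch: shift by the max, then subtract the log-sum-exp of the shifted values:

    import numpy as np

    def log_softmax(x, axis=-1):
        xdev = x - np.max(x, axis=axis, keepdims=True)
        return xdev - np.log(np.sum(np.exp(xdev), axis=axis, keepdims=True))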
@register_funcify_default_op_cache_key(Argmax)
Review thread on the `# identity = {identity}` codegen line — Member: "Why is this commented?" Author: "It's to make codegen readable when debugging, otherwise you just see a global identity being used but won't know its value."