22from hashlib import sha256
33from textwrap import dedent , indent
44
5- import numba
65import numpy as np
76from numba .core .extending import overload
87from numpy .lib .array_utils import normalize_axis_index , normalize_axis_tuple
1514)
1615from pytensor .link .numba .dispatch import basic as numba_basic
1716from pytensor .link .numba .dispatch .basic import (
17+ create_tuple_string ,
1818 numba_funcify_and_cache_key ,
1919 register_funcify_and_cache_key ,
2020 register_funcify_default_op_cache_key ,
@@ -126,10 +126,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
126126
127127def create_multiaxis_reducer (
128128 scalar_op ,
129+ * ,
129130 identity ,
130131 axes ,
131132 ndim ,
132- dtype ,
133+ acc_dtype = None ,
134+ out_dtype ,
133135 keepdims : bool = False ,
134136):
135137 r"""Construct a function that reduces multiple axes.
@@ -139,17 +141,46 @@ def create_multiaxis_reducer(
139141 .. code-block:: python
140142
141143 def careduce_add(x):
142- # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
143144 x_shape = x.shape
144- res_shape = x_shape[2]
145- res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
145+ res_shape = (x_shape[0], x_shape[1])
146+ # identity = 0.0
147+ res = np.full(res_shape, identity, dtype=np.float64)
148+ for i0 in range(x_shape[0]):
149+ for i1 in range(x_shape[1]):
150+ for i2 in range(x_shape[2]):
151+ res[i0, i1] += x[i0, i1, i2]
152+ return res
153+
154+ If accumulation dtype differs from output_dtype
155+
156+ .. code-block:: python
146157
158+ def careduce_add(x):
159+ x_shape = x.shape
160+ res_shape = (x_shape[0], x_shape[1])
161+ # identity = 0.0
162+ res = np.full(res_shape, identity, dtype=np.float64)
147163 for i0 in range(x_shape[0]):
148164 for i1 in range(x_shape[1]):
149165 for i2 in range(x_shape[2]):
150- res[i2] += x[i0, i1, i2]
166+ res[i0, i1] += x[i0, i1, i2]
167+ return res.astype(np.int32)
168+
169+ Full reductions accumulate on scalars
170+
171+ .. code-block:: python
172+
173+ def careduce_mul(x):
174+ x_shape = x.shape
175+ res_shape = ()
176+ # identity = 1.0
177+ res = identity
178+ for i0 in range(x_shape[0]):
179+ for i1 in range(x_shape[1]):
180+ for i2 in range(x_shape[2]):
181+ res *= x[i0, i1, i2]
182+ return np.array(res, dtype=np.int32)
151183
152- return res
153184
154185 Parameters
155186 ==========
@@ -161,7 +192,9 @@ def careduce_add(x):
161192 The axes to reduce.
162193 ndim:
163194 The number of dimensions of the input variable.
164- dtype:
195+ acc_dtype: dtype, optional
196+ The data type used during accumulation. Defaults to out_dtype if not provided
197+ out_dtype:
165198 The data type of the result.
166199 keepdims: boolean, default False
167200 Whether to keep the reduced dimensions.
@@ -179,19 +212,23 @@ def careduce_add(x):
179212 "Cannot keep multiple dimensions when reducing multiple axes"
180213 )
181214
215+ out_dtype = np .dtype (out_dtype )
216+ acc_dtype = out_dtype if acc_dtype is None else np .dtype (acc_dtype )
217+ # Numba doesn't allow converting complex to real with a simple `astype`
218+ complex_to_real = acc_dtype .kind == "c" and out_dtype .kind != "c"
219+ out_dtype_str = f"np.{ out_dtype .name } "
220+ acc_dtype_str = f"np.{ acc_dtype .name } "
182221 careduce_fn_name = f"careduce_{ scalar_op } "
183222
184- identity = str (identity )
185- if identity == "inf" :
186- identity = "np.inf"
187- elif identity == "-inf" :
188- identity = "-np.inf"
189-
190- global_env = {
191- "np" : np ,
192- "numba_basic" : numba_basic ,
193- "out_dtype" : dtype ,
194- }
223+ if acc_dtype .kind in "ui" and not np .isfinite (identity ):
224+ if np .isposinf (identity ):
225+ identity = np .iinfo (acc_dtype ).max
226+ else :
227+ identity = np .iinfo (acc_dtype ).min
228+
229+ # Make sure it has the correct dtype
230+ identity = getattr (np , acc_dtype .name )(identity )
231+
195232 complete_reduction = len (axes ) == ndim
196233 kept_axis = tuple (i for i in range (ndim ) if i not in axes )
197234
@@ -209,17 +246,23 @@ def careduce_add(x):
209246 scalar_op , res_indices , "res" , f"x[{ arr_indices } ]"
210247 )
211248
212- res_shape = f"( { ', ' . join ( f' x_shape[{ i } ]' for i in kept_axis ) } )"
249+ res_shape = create_tuple_string ([ f" x_shape[{ i } ]" for i in kept_axis ])
213250 if complete_reduction and ndim > 0 :
214251 # We accumulate on a scalar, not an array
215- res_creator = f"np.asarray( { identity } ).astype(out_dtype).item() "
252+ res_creator = " identity"
216253 inplace_update_stmt = inplace_update_stmt .replace ("res[()]" , "res" )
217- return_obj = "np.asarray(res)"
254+ if complex_to_real :
255+ return_obj = f"np.array(res).real.astype({ out_dtype_str } )"
256+ else :
257+ return_obj = f"np.array(res, dtype={ out_dtype_str } )"
218258 else :
219- res_creator = (
220- f"np.full({ res_shape } , np.asarray({ identity } ).item(), dtype=out_dtype)"
221- )
222- return_obj = "res"
259+ res_creator = f"np.full(res_shape, identity, dtype={ acc_dtype_str } )"
260+ if complex_to_real :
261+ return_obj = f"res.real.astype({ out_dtype_str } )"
262+ else :
263+ return_obj = (
264+ "res" if out_dtype == acc_dtype else f"res.astype({ out_dtype_str } )"
265+ )
223266
224267 if keepdims :
225268 [axis ] = axes
@@ -230,6 +273,7 @@ def careduce_add(x):
230273 def { careduce_fn_name } (x):
231274 x_shape = x.shape
232275 res_shape = { res_shape }
276+ # identity = { identity }
233277 res = { res_creator }
234278 """
235279 )
@@ -239,13 +283,12 @@ def {careduce_fn_name}(x):
239283 " " * (4 + 4 * axis ),
240284 )
241285 careduce_def_src += indent (inplace_update_stmt , " " * (4 + 4 * ndim ))
242- careduce_def_src += "\n \n "
286+ careduce_def_src += "\n "
243287 careduce_def_src += indent (f"return { return_obj } " , " " * 4 )
244288
245289 careduce_fn = compile_numba_function_src (
246- careduce_def_src , careduce_fn_name , { ** globals (), ** global_env }
290+ careduce_def_src , careduce_fn_name , globals () | { "np" : np , "identity" : identity }
247291 )
248-
249292 return careduce_fn
250293
251294
@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
356399 acc_dtype = op .acc_dtype
357400 else :
358401 acc_dtype = node .outputs [0 ].type .dtype
359- np_acc_dtype = np .dtype (acc_dtype )
360-
361- scalar_op_identity = op .scalar_op .identity
362- if np_acc_dtype .kind == "i" and not np .isfinite (scalar_op_identity ):
363- if np .isposinf (scalar_op_identity ):
364- scalar_op_identity = np .iinfo (np_acc_dtype ).max
365- else :
366- scalar_op_identity = np .iinfo (np_acc_dtype ).min
367- # Make sure it has the correct dtype
368- scalar_op_identity = np .array (scalar_op_identity , dtype = np_acc_dtype )
369402
370403 out_dtype = np .dtype (node .outputs [0 ].type .dtype )
371404
372- if isinstance (op , Sum ) and node .inputs [0 ].ndim == len (axes ):
405+ if (
406+ isinstance (op , Sum )
407+ and node .inputs [0 ].ndim == len (axes )
408+ and out_dtype == acc_dtype
409+ ):
373410 # Slightly faster for this case
374411 @numba_basic .numba_njit
375412 def impl_sum (array ):
376- return np .asarray (array .sum (), dtype = np_acc_dtype ). astype ( out_dtype )
413+ return np .array (array .sum ())
377414
378415 careduce_fn = impl_sum # Some tests look for this name
379416
380417 else :
381418 ndim = node .inputs [0 ].ndim
382419 careduce_py_fn = create_multiaxis_reducer (
383420 op .scalar_op ,
384- scalar_op_identity ,
385- axes ,
386- ndim ,
387- out_dtype ,
421+ identity = op .scalar_op .identity ,
422+ axes = axes ,
423+ ndim = ndim ,
424+ acc_dtype = acc_dtype ,
425+ out_dtype = out_dtype ,
388426 )
389427 careduce_fn = numba_basic .numba_njit (careduce_py_fn , boundscheck = False )
390428
429+ cache_version = 1
391430 careduce_key = sha256 (
392431 str (
393- (type (op ), type (op .scalar_op ), axes , acc_dtype , scalar_op_identity .item ())
432+ (
433+ type (op ),
434+ type (op .scalar_op ),
435+ axes ,
436+ out_dtype ,
437+ acc_dtype ,
438+ op .scalar_op .identity ,
439+ cache_version ,
440+ )
394441 ).encode ()
395442 ).hexdigest ()
396443 return careduce_fn , careduce_key
@@ -449,18 +496,26 @@ def dimshuffle(x):
449496
450497@register_funcify_default_op_cache_key (Softmax )
451498def numba_funcify_Softmax (op , node , ** kwargs ):
452- x_at = node .inputs [0 ]
453- x_dtype = x_at .type .numpy_dtype
454- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
499+ ndim = node .inputs [0 ].type .ndim
500+ inp_dtype = node .inputs [0 ].type .numpy_dtype
455501 axis = op .axis
456502
457- if axis is not None :
458- axis = normalize_axis_index (axis , x_at .ndim )
503+ if ndim > 1 and axis is not None :
459504 reduce_max_py = create_multiaxis_reducer (
460- maximum , - np .inf , axis , x_at .ndim , x_dtype , keepdims = True
505+ maximum ,
506+ identity = - np .inf ,
507+ axes = (axis ,),
508+ ndim = ndim ,
509+ out_dtype = inp_dtype ,
510+ keepdims = True ,
461511 )
462512 reduce_sum_py = create_multiaxis_reducer (
463- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
513+ add_as ,
514+ identity = 0.0 ,
515+ axes = (axis ,),
516+ ndim = ndim ,
517+ out_dtype = inp_dtype ,
518+ keepdims = True ,
464519 )
465520
466521 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -470,66 +525,72 @@ def numba_funcify_Softmax(op, node, **kwargs):
470525 reduce_max = np .max
471526 reduce_sum = np .sum
472527
473- def softmax_py_fn (x ):
528+ @numba_basic .numba_njit (boundscheck = False )
529+ def softmax (x ):
474530 z = reduce_max (x )
475531 e_x = np .exp (x - z )
476532 w = reduce_sum (e_x )
477533 sm = e_x / w
478534 return sm
479535
480- softmax = numba_basic .numba_njit (softmax_py_fn , boundscheck = False )
481-
482- return softmax
536+ cache_version = 1
537+ return softmax , cache_version
483538
484539
485540@register_funcify_default_op_cache_key (SoftmaxGrad )
486541def numba_funcify_SoftmaxGrad (op , node , ** kwargs ):
487- sm_at = node .inputs [1 ]
488- sm_dtype = sm_at .type .numpy_dtype
489- sm_dtype = numba .np .numpy_support .from_dtype (sm_dtype )
542+ ndim = node .inputs [0 ].type .ndim
543+ inp_dtype = node .inputs [0 ].type .numpy_dtype
490544
491545 axis = op .axis
492- if axis is not None :
493- axis = normalize_axis_index (axis , sm_at .ndim )
546+ if ndim > 1 and axis is not None :
494547 reduce_sum_py = create_multiaxis_reducer (
495- add_as , 0.0 , (axis ,), sm_at .ndim , sm_dtype , keepdims = True
548+ add_as ,
549+ identity = 0.0 ,
550+ axes = (axis ,),
551+ ndim = ndim ,
552+ out_dtype = inp_dtype ,
553+ keepdims = True ,
496554 )
497555
498556 jit_fn = numba_basic .numba_njit (boundscheck = False )
499557 reduce_sum = jit_fn (reduce_sum_py )
500558 else :
501559 reduce_sum = np .sum
502560
503- def softmax_grad_py_fn (dy , sm ):
561+ @numba_basic .numba_njit (boundscheck = False )
562+ def softmax_grad (dy , sm ):
504563 dy_times_sm = dy * sm
505564 sum_dy_times_sm = reduce_sum (dy_times_sm )
506565 dx = dy_times_sm - sum_dy_times_sm * sm
507566 return dx
508567
509- softmax_grad = numba_basic .numba_njit (softmax_grad_py_fn , boundscheck = False )
510-
511- return softmax_grad
568+ cache_version = 1
569+ return softmax_grad , cache_version
512570
513571
514572@register_funcify_default_op_cache_key (LogSoftmax )
515573def numba_funcify_LogSoftmax (op , node , ** kwargs ):
516- x_at = node .inputs [0 ]
517- x_dtype = x_at .type .numpy_dtype
518- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
574+ ndim = node .inputs [0 ].type .ndim
575+ inp_dtype = node .inputs [0 ].type .numpy_dtype
519576 axis = op .axis
520577
521- if axis is not None :
522- axis = normalize_axis_index (axis , x_at .ndim )
578+ if ndim > 1 and axis is not None :
523579 reduce_max_py = create_multiaxis_reducer (
524580 maximum ,
525- - np .inf ,
526- (axis ,),
527- x_at . ndim ,
528- x_dtype ,
581+ identity = - np .inf ,
582+ axes = (axis ,),
583+ ndim = ndim ,
584+ out_dtype = inp_dtype ,
529585 keepdims = True ,
530586 )
531587 reduce_sum_py = create_multiaxis_reducer (
532- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
588+ add_as ,
589+ identity = 0.0 ,
590+ axes = (axis ,),
591+ ndim = ndim ,
592+ out_dtype = inp_dtype ,
593+ keepdims = True ,
533594 )
534595
535596 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
539600 reduce_max = np .max
540601 reduce_sum = np .sum
541602
542- def log_softmax_py_fn (x ):
603+ @numba_basic .numba_njit (boundscheck = False )
604+ def log_softmax (x ):
543605 xdev = x - reduce_max (x )
544606 lsm = xdev - np .log (reduce_sum (np .exp (xdev )))
545607 return lsm
546608
547- log_softmax = numba_basic . numba_njit ( log_softmax_py_fn , boundscheck = False )
548- return log_softmax
609+ cache_version = 1
610+ return log_softmax , cache_version
549611
550612
551613@register_funcify_default_op_cache_key (Argmax )
0 commit comments