@@ -60,14 +60,18 @@ def element_type_to_backend_config_type(dtype):
   return _element_type_to_backend_config_type_mapping[dtype]
 
 
-def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type):
+def _scaled_matmul_impl(a, b, a_scale, b_scale, global_scale,
+                        preferred_element_type, has_global_scale):
   return _scaled_matmul_p.bind(
-      a, b, a_scale, b_scale, preferred_element_type=preferred_element_type
+      a, b, a_scale, b_scale, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale
   )
 
 
 def _scaled_matmul_cuda_lowering(
-    ctx, a, b, a_scales, b_scales, preferred_element_type
+    ctx, a, b, a_scales, b_scales, global_scale, preferred_element_type,
+    has_global_scale
 ):
   lhs_type = ir.RankedTensorType(a.type)
   lhs_shape = lhs_type.shape
@@ -82,6 +86,8 @@ def _scaled_matmul_cuda_lowering(
   result_types = [ir.RankedTensorType.get(result_shape, out_type)]
 
   operands = [a, b, a_scales, b_scales]
+  if has_global_scale:
+    operands.append(global_scale)
   backend_config = {
       "scaled_dot_backend_config": {
           "lhs_batch_dimensions": [0],
@@ -104,7 +110,8 @@ def _scaled_matmul_cuda_lowering(
   return [out.result]
 
 
-def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
+def _scaled_matmul_abstract(a, b, a_scale, b_scale, global_scale,
+                            *, preferred_element_type, has_global_scale):
   batch, non_contracting_lhs, contracting_lhs = a.shape
   _, non_contracting_rhs, _ = b.shape
   output_shape = (batch, non_contracting_lhs, non_contracting_rhs)
@@ -296,7 +303,7 @@ def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
 
 
 _scaled_matmul_lower = custom_partitioning(
-    _scaled_matmul_impl, static_argnums=(4,)
+    _scaled_matmul_impl, static_argnums=(5, 6)
 )
 
 _scaled_matmul_lower.def_partition(
@@ -306,16 +313,13 @@ def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
 )
 
 
-def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
-  assert len(batch_dims) == 4
-  assert (
-      batch_dims[0] == batch_dims[1]
-      and batch_dims[0] == batch_dims[2]
-      and batch_dims[0] == batch_dims[3]
-  )
+def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type,
+                           has_global_scale):
+  assert len(batch_dims) == 5
+  assert len(set(batch_dims[:4])) == 1 and batch_dims[4] is None
   lhs_bdims = batch_dims[0]
   out_bdims = (batch_dims[0],)
-  lhs, rhs, lhs_scales, rhs_scales = batched_args
+  lhs, rhs, lhs_scales, rhs_scales, global_scale = batched_args
   *batch, lhs_non_contracting, contracting = lhs.shape
   *_, _, scales_contracting = lhs_scales.shape
   *_, rhs_non_contracting, _ = rhs.shape
@@ -336,7 +340,9 @@ def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
           rhs,
           lhs_scales,
           rhs_scales,
+          global_scale,
           preferred_element_type=preferred_element_type,
+          has_global_scale=has_global_scale,
       )[0],
       (*batch, lhs_non_contracting, rhs_non_contracting),
   )
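
A note on the batching rule above: the assertion now requires the four tensor operands to share a batch dimension while the scalar `global_scale` stays unbatched (`batch_dims[4] is None`); the vmapped axis on the tensors is folded together with the existing batch axis for the kernel call and split back out on the result, which is what the surrounding reshape to `(*batch, lhs_non_contracting, rhs_non_contracting)` suggests. A toy sketch of that reshape pattern (shapes and names are made up, not the batcher's real variables):

```python
# Toy shapes only: fold a vmapped axis into the existing batch axis for the
# kernel call, then restore it on the output.
import jax.numpy as jnp

vmapped, batch, M, K = 2, 3, 8, 16
lhs = jnp.ones((vmapped, batch, M, K))

lhs_flat = lhs.reshape(vmapped * batch, M, K)      # single batch axis for the kernel
# ... the scaled-matmul primitive would run on lhs_flat here ...
restored = lhs_flat.reshape(vmapped, batch, M, K)  # back to the vmapped layout
assert restored.shape == lhs.shape
```
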
@@ -355,17 +361,20 @@ def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
 batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher
 
 
-@api.jit(static_argnames=("preferred_element_type",))
+@api.jit(static_argnames=("preferred_element_type", "has_global_scale"))
 def _scaled_matmul(
     lhs: Array,
     rhs: Array,
     lhs_scales: Array,
     rhs_scales: Array,
+    global_scale: Array,
     preferred_element_type: DTypeLike = np.dtype('float32'),
+    has_global_scale: bool = False,
 ) -> Array:
   output = _scaled_matmul_p_wrapper.bind(
-      lhs, rhs, lhs_scales, rhs_scales,
-      preferred_element_type=preferred_element_type
+      lhs, rhs, lhs_scales, rhs_scales, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale
   )
   return output[0]
 
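
An aside on the decorator change: `has_global_scale` joins `preferred_element_type` in `static_argnames` because it controls whether the `global_scale` operand is forwarded to the primitive (see the lowering hunk above), so it must be a Python-level constant at trace time; `jit` then specializes and retraces per distinct static value. A hypothetical toy example of the same pattern, unrelated to this file:

```python
# Hypothetical toy function showing a static bool under jax.jit: the branch is
# resolved while tracing, and jit retraces for each distinct static value.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("use_scale",))
def toy_matmul(a, b, scale, use_scale: bool = False):
  out = a @ b
  return out * scale if use_scale else out

a = jnp.ones((4, 8))
b = jnp.ones((8, 3))
print(toy_matmul(a, b, jnp.float32(2.0), use_scale=True).shape)  # (4, 3)
```
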
@@ -374,7 +383,9 @@ def scaled_matmul_wrapper(
     rhs: Array,
     lhs_scales: Array,
     rhs_scales: Array,
+    global_scale: Array,
     preferred_element_type: DTypeLike = np.dtype('float32'),
+    has_global_scale: bool = False,
 ) -> Array:
   """
   Performs scaled matrix multiplication between two 3D arrays, with scaling
@@ -385,8 +396,10 @@ def scaled_matmul_wrapper(
       rhs (Array): A 3D array of shape (B, N, K).
       lhs_scales (Array): A 3D array of shape (B, M, K_block).
       rhs_scales (Array): A 3D array of shape (B, N, K_block).
+      global_scale (Array): A 0D array (scalar) holding the global scale.
       preferred_element_type (DTypeLike, optional): The preferred data type
           for the computation. Defaults to `jnp.float32`.
+      has_global_scale (bool, optional): Whether to apply `global_scale`.
 
   Returns:
       Array: A 3D array of shape (B, M, N) representing the scaled matrix
@@ -416,7 +429,9 @@ def scaled_matmul_wrapper(
       rhs,
       lhs_scales,
       rhs_scales,
+      global_scale,
       preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale,
   )
 
   return out
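
To make the new signature concrete, here is a hedged usage sketch of `scaled_matmul_wrapper`. The dtypes (`float8_e4m3fn` values with `float8_e8m0fnu` block scales), the block size of 32, and the shapes are illustrative assumptions, and running it requires a GPU/cuDNN build that supports block-scaled matmul:

```python
# Illustrative only: assumes scaled_matmul_wrapper from this module is in scope
# and that the backend supports block-scaled matmul.
import jax.numpy as jnp

B, M, N, K, BLOCK = 1, 128, 128, 256, 32              # assumed sizes
lhs = jnp.zeros((B, M, K), dtype=jnp.float8_e4m3fn)
rhs = jnp.zeros((B, N, K), dtype=jnp.float8_e4m3fn)
lhs_scales = jnp.ones((B, M, K // BLOCK), dtype=jnp.float8_e8m0fnu)
rhs_scales = jnp.ones((B, N, K // BLOCK), dtype=jnp.float8_e8m0fnu)
global_scale = jnp.array(1.0, dtype=jnp.float32)      # 0D scalar operand

out = scaled_matmul_wrapper(
    lhs, rhs, lhs_scales, rhs_scales, global_scale,
    preferred_element_type=jnp.float32,
    has_global_scale=True,
)
assert out.shape == (B, M, N)
```
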
@@ -577,18 +592,17 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
   lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
   rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)
 
-  out_dtype = preferred_element_type
-  if configs[0].mode == 'nvfp4':
-    out_dtype = np.float32
+  has_global_scale = configs[0].mode == 'nvfp4'
+  global_scale = jnp.array(
+      configs[0].global_scale * configs[1].global_scale
+      if has_global_scale else 0, dtype=preferred_element_type)
 
   out = scaled_matmul_wrapper(
-      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype
+      lhs_q, rhs_q, lhs_scales, rhs_scales, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale,
   )
 
-  if configs[0].mode == 'nvfp4':
-    out *= (configs[0].global_scale * configs[1].global_scale)
-    out = out.astype(preferred_element_type)
-
   expanded_out_shape = compute_dot_output_shape(
       lhs.shape, rhs.shape, lhs_dn, rhs_dn
   )
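
The hunk above drops the post-hoc nvfp4 correction (multiplying the output by `configs[0].global_scale * configs[1].global_scale` and then casting) and instead folds the product of the two per-tensor global scales into a single scalar operand that the kernel applies itself; when no global scale is needed, a placeholder zero travels with `has_global_scale=False`. As a rough, hypothetical reference model of the fused semantics (plain NumPy, invented helper names, block size of 32 assumed):

```python
# Hypothetical reference model, not the library's implementation: dequantize
# each operand with its per-block scales, contract over K, then apply the
# single fused global scale once.
import numpy as np

def dequant(q, scales, block=32):
  # Expand per-block scales along the contracting axis and rescale the values.
  return q.astype(np.float32) * np.repeat(scales.astype(np.float32), block, axis=-1)

def reference_scaled_matmul(lhs_q, rhs_q, lhs_scales, rhs_scales, global_scale):
  lhs = dequant(lhs_q, lhs_scales)            # (B, M, K)
  rhs = dequant(rhs_q, rhs_scales)            # (B, N, K)
  out = np.einsum("bmk,bnk->bmn", lhs, rhs)   # contract over K
  return out * global_scale                   # fused scalar, applied once
```
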
@@ -625,7 +639,8 @@ def scaled_dot_general_transpose_lhs(
     y_q, y_scales = quantize(y_3d, y_config)
 
     out = scaled_matmul_wrapper(
-        g_q, y_q, g_scales, y_scales, preferred_element_type
+        g_q, y_q, g_scales, y_scales, jnp.array(0),
+        preferred_element_type, has_global_scale=False
     )
   else:
     out = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)), preferred_element_type=preferred_element_type)