
Commit 0c4524a

Add support for global scale to scaled_matmul_wrapper
1 parent 75e95e4 commit 0c4524a

3 files changed: 52 additions, 31 deletions

jax/_src/cudnn/scaled_matmul_stablehlo.py

Lines changed: 40 additions & 25 deletions
```diff
@@ -60,14 +60,18 @@ def element_type_to_backend_config_type(dtype):
   return _element_type_to_backend_config_type_mapping[dtype]
 
 
-def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type):
+def _scaled_matmul_impl(a, b, a_scale, b_scale, global_scale,
+                        preferred_element_type, has_global_scale):
   return _scaled_matmul_p.bind(
-      a, b, a_scale, b_scale, preferred_element_type=preferred_element_type
+      a, b, a_scale, b_scale, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale
   )
 
 
 def _scaled_matmul_cuda_lowering(
-    ctx, a, b, a_scales, b_scales, preferred_element_type
+    ctx, a, b, a_scales, b_scales, global_scale, preferred_element_type,
+    has_global_scale
 ):
   lhs_type = ir.RankedTensorType(a.type)
   lhs_shape = lhs_type.shape
@@ -82,6 +86,8 @@ def _scaled_matmul_cuda_lowering(
   result_types = [ir.RankedTensorType.get(result_shape, out_type)]
 
   operands = [a, b, a_scales, b_scales]
+  if has_global_scale:
+    operands.append(global_scale)
   backend_config = {
       "scaled_dot_backend_config": {
           "lhs_batch_dimensions": [0],
@@ -104,7 +110,8 @@ def _scaled_matmul_cuda_lowering(
   return [out.result]
 
 
-def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
+def _scaled_matmul_abstract(a, b, a_scale, b_scale, global_scale,
+                            *, preferred_element_type, has_global_scale):
   batch, non_contracting_lhs, contracting_lhs = a.shape
   _, non_contracting_rhs, _ = b.shape
   output_shape = (batch, non_contracting_lhs, non_contracting_rhs)
@@ -296,7 +303,7 @@ def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
 
 
 _scaled_matmul_lower = custom_partitioning(
-    _scaled_matmul_impl, static_argnums=(4,)
+    _scaled_matmul_impl, static_argnums=(5, 6)
 )
 
 _scaled_matmul_lower.def_partition(
@@ -306,16 +313,13 @@ def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
 )
 
 
-def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
-  assert len(batch_dims) == 4
-  assert (
-      batch_dims[0] == batch_dims[1]
-      and batch_dims[0] == batch_dims[2]
-      and batch_dims[0] == batch_dims[3]
-  )
+def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type,
+                           has_global_scale):
+  assert len(batch_dims) == 5
+  assert len(set(batch_dims[:4])) == 1 and batch_dims[4] is None
   lhs_bdims = batch_dims[0]
   out_bdims = (batch_dims[0],)
-  lhs, rhs, lhs_scales, rhs_scales = batched_args
+  lhs, rhs, lhs_scales, rhs_scales, global_scale = batched_args
   *batch, lhs_non_contracting, contracting = lhs.shape
   *_, _, scales_contracting = lhs_scales.shape
   *_, rhs_non_contracting, _ = rhs.shape
@@ -336,7 +340,9 @@ def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
           rhs,
           lhs_scales,
           rhs_scales,
+          global_scale,
           preferred_element_type=preferred_element_type,
+          has_global_scale=has_global_scale,
       )[0],
       (*batch, lhs_non_contracting, rhs_non_contracting),
   )
@@ -355,17 +361,20 @@ def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
 batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher
 
 
-@api.jit(static_argnames=("preferred_element_type",))
+@api.jit(static_argnames=("preferred_element_type", "has_global_scale"))
 def _scaled_matmul(
     lhs: Array,
     rhs: Array,
     lhs_scales: Array,
     rhs_scales: Array,
+    global_scale: Array,
     preferred_element_type: DTypeLike = np.dtype('float32'),
+    has_global_scale: bool = False,
 ) -> Array:
   output = _scaled_matmul_p_wrapper.bind(
-      lhs, rhs, lhs_scales, rhs_scales,
-      preferred_element_type=preferred_element_type
+      lhs, rhs, lhs_scales, rhs_scales, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale
   )
   return output[0]
 
@@ -374,7 +383,9 @@ def scaled_matmul_wrapper(
     rhs: Array,
     lhs_scales: Array,
     rhs_scales: Array,
+    global_scale: Array,
    preferred_element_type: DTypeLike = np.dtype('float32'),
+    has_global_scale: bool = False,
 ) -> Array:
   """
   Performs scaled matrix multiplication between two 3D arrays, with scaling
@@ -385,8 +396,10 @@ def scaled_matmul_wrapper(
       rhs (Array): A 3D array of shape (B, N, K).
       lhs_scales (Array): A 3D array of shape (B, M, K_block).
       rhs_scales (Array): A 3D array of shape (B, N, K_block).
+      global_scale (Array): A 0D array (scalar).
       preferred_element_type (DTypeLike, optional): The preferred data type
           for the computation. Defaults to `jnp.float32`.
+      has_global_scale (bool, optional): Whether to use a global scale.
 
   Returns:
       Array: A 3D array of shape (B, M, N) representing the scaled matrix
@@ -416,7 +429,9 @@ def scaled_matmul_wrapper(
       rhs,
       lhs_scales,
       rhs_scales,
+      global_scale,
       preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale,
   )
 
   return out
```
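The updated `scaled_matmul_wrapper` signature can be exercised roughly as in the sketch below. This is not part of the commit: the fp4/fp8 dtype names, the 16-wide NVFP4 scale blocks, and the zero/one-filled arrays are illustrative assumptions, and running it requires a CUDA device with a cuDNN build that supports block-scaled matmul. Real callers obtain the quantized operands and block scales from the library's own quantization path.

```python
# Sketch only (not from the commit): calling scaled_matmul_wrapper with the new
# global_scale operand. Dtypes, block width, and data are assumptions.
import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper

B, M, N, K, block = 1, 128, 128, 256, 16
lhs_q = jnp.zeros((B, M, K), dtype=jnp.float4_e2m1fn)               # quantized lhs
rhs_q = jnp.zeros((B, N, K), dtype=jnp.float4_e2m1fn)               # quantized rhs
lhs_scales = jnp.ones((B, M, K // block), dtype=jnp.float8_e4m3fn)  # block scales
rhs_scales = jnp.ones((B, N, K // block), dtype=jnp.float8_e4m3fn)
global_scale = jnp.array(0.5, dtype=jnp.float32)                    # product of per-tensor scales

out = scaled_matmul_wrapper(
    lhs_q, rhs_q, lhs_scales, rhs_scales, global_scale,
    preferred_element_type=jnp.float32,
    has_global_scale=True,
)
```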
```diff
@@ -577,18 +592,17 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
   lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
   rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)
 
-  out_dtype = preferred_element_type
-  if configs[0].mode == 'nvfp4':
-    out_dtype = np.float32
+  has_global_scale = configs[0].mode == 'nvfp4'
+  global_scale = jnp.array(
+      configs[0].global_scale * configs[1].global_scale
+      if has_global_scale else 0, dtype=preferred_element_type)
 
   out = scaled_matmul_wrapper(
-      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype
+      lhs_q, rhs_q, lhs_scales, rhs_scales, global_scale,
+      preferred_element_type=preferred_element_type,
+      has_global_scale=has_global_scale,
   )
 
-  if configs[0].mode == 'nvfp4':
-    out *= (configs[0].global_scale * configs[1].global_scale)
-    out = out.astype(preferred_element_type)
-
   expanded_out_shape = compute_dot_output_shape(
       lhs.shape, rhs.shape, lhs_dn, rhs_dn
   )
```
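The hunk above moves the nvfp4 global scaling out of the caller: instead of running the matmul in float32 and post-multiplying by the product of the two per-tensor global scales, the product is now built once and passed to the kernel as `global_scale`, which returns `preferred_element_type` directly. A rough sketch of the equivalence, using plain float32 stand-ins rather than real quantized tensors:

```python
# Illustrative only: plain float32 stand-ins, not real nvfp4 operands.
import jax.numpy as jnp

kernel_out = jnp.ones((1, 4, 4), dtype=jnp.float32)   # pretend block-scaled matmul result
gs_lhs, gs_rhs = 0.25, 0.5                            # per-tensor global scales

# Pre-commit path: scale and cast after the matmul.
old = (kernel_out * (gs_lhs * gs_rhs)).astype(jnp.bfloat16)

# Post-commit path: the scale product is handed to the kernel as global_scale;
# conceptually the kernel applies it internally.
global_scale = jnp.array(gs_lhs * gs_rhs, dtype=jnp.bfloat16)
new = (kernel_out * global_scale).astype(jnp.bfloat16)

assert jnp.allclose(old, new)
```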
```diff
@@ -625,7 +639,8 @@ def scaled_dot_general_transpose_lhs(
     y_q, y_scales = quantize(y_3d, y_config)
 
     out = scaled_matmul_wrapper(
-        g_q, y_q, g_scales, y_scales, preferred_element_type
+        g_q, y_q, g_scales, y_scales, jnp.array(0),
+        preferred_element_type, has_global_scale=False
     )
   else:
     out = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)), preferred_element_type=preferred_element_type)
```

jax/_src/nn/functions.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1249,6 +1249,7 @@ def scaled_matmul(
     rhs: Array,
     lhs_scales: Array,
     rhs_scales: Array,
+    global_scale: Array | None = None,
     preferred_element_type: DTypeLike = np.float32,
 ) -> Array:
   r"""Scaled matrix multiplication function.
@@ -1269,6 +1270,7 @@ def scaled_matmul(
       rhs (Array): Operand b, shape (B, N, K).
       lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
       rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
+      global_scale (Array, optional): Scalar scaling factor.
       preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.
 
   Returns:
@@ -1348,7 +1350,9 @@ def scaled_matmul(
       b,
       a_scales,
       b_scales,
+      global_scale or jnp.array(0, dtype=preferred_element_type),
       preferred_element_type=preferred_element_type,
+      has_global_scale=global_scale is not None,
   )
   return out
 
```
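In the public entry point, `global_scale` is optional and defaults to `None`, so existing callers are unaffected; when it is supplied, `has_global_scale=True` is forwarded to the wrapper. A hedged usage sketch follows, assuming the function is exposed as `jax.nn.scaled_matmul`; the fp8 dtypes, 32-wide blocks, and ones-filled inputs are illustrative assumptions, and whether a given dtype/scale combination is accepted depends on the cuDNN backend.

```python
# Sketch of the public API after this change (not from the commit).
import jax.numpy as jnp
from jax import nn

B, M, N, K, block = 1, 64, 64, 128, 32
lhs = jnp.ones((B, M, K), dtype=jnp.float8_e4m3fn)
rhs = jnp.ones((B, N, K), dtype=jnp.float8_e4m3fn)
lhs_scales = jnp.ones((B, M, K // block), dtype=jnp.float8_e8m0fnu)
rhs_scales = jnp.ones((B, N, K // block), dtype=jnp.float8_e8m0fnu)

# Existing callers keep working: global_scale defaults to None.
out = nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)

# New optional global scale, applied on top of the block scales.
out_scaled = nn.scaled_matmul(
    lhs, rhs, lhs_scales, rhs_scales,
    global_scale=jnp.array(0.5, dtype=jnp.float32),
)
```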

tests/scaled_matmul_stablehlo_test.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -303,7 +303,7 @@ def test_collectives(self, in_shardings, block_scale_configs):
   @jtu.sample_product(
       contract=[160, 96],
       lhs_non_contract=[240, 100],
-      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
+      dtype=[jnp.float32, jnp.bfloat16],
   )
   @jtu.run_on_devices("cuda")
   def test_scaled_matmul_nvfp4(
@@ -322,15 +322,15 @@ def test_scaled_matmul_nvfp4(
     b_gs = block_scale_configs[1].global_scale
 
     def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
-      out = scaled_matmul_wrapper(
+      return scaled_matmul_wrapper(
           lhs,
           rhs,
           lhs_scales,
           rhs_scales,
-          preferred_element_type=jnp.float32,
+          jnp.array(a_gs * b_gs, dtype=out_type),
+          preferred_element_type=out_type,
+          has_global_scale=True,
       )
-      gs = a_gs * b_gs
-      return (out * gs).astype(out_type)
 
     j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
     hlo_text = (
@@ -373,7 +373,9 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
           rhs,
           lhs_scales,
           rhs_scales,
+          np.array(0),
           preferred_element_type=out_type,
+          has_global_scale=False,
       )
 
     j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
@@ -587,7 +589,7 @@ def fwd(a, b, use_normalized=False):
               True,
           ),
       ],
-      output_type=[jnp.float32, jnp.float16, jnp.bfloat16],
+      output_type=[jnp.float32, jnp.bfloat16],
   )
   @jtu.run_on_devices("cuda")
   def test_dot_general_nvfp4(self, configs, output_type):
```
