@@ -4869,41 +4869,45 @@ def matmul(
48694869 rtol = 0 ,
48704870 )
48714871
4872- @parameterized .parameters (jnp .int8 , jnp .uint8 )
4873- def test_integer_wgmma (self , dtype ):
4872+ @parameterized .product (
4873+ dtype = (jnp .int8 , jnp .uint8 ),
4874+ lhs_in_smem = (False , True ),
4875+ )
4876+ def test_integer_wgmma (self , dtype , lhs_in_smem ):
48744877 m , k , n = 64 , 128 , 64
48754878
48764879 def body (ctx , lhs_gmem , rhs_gmem , result_gmem , scratch ):
48774880 del ctx
4878- lhs , rhs , tma_barrier = scratch
4881+ lhs_smem , rhs_smem , tma_barrier = scratch
48794882
48804883 i32 = ir .IntegerType .get_signless (32 )
48814884 zero = arith .constant (i32 , 0 )
48824885
48834886 tma_barrier .arrive_expect_tx (m * k + k * n )
48844887 mgpu_dialect .async_load (
48854888 source = lhs_gmem ,
4886- destination = lhs ,
4889+ destination = lhs_smem ,
48874890 barrier = tma_barrier .as_barrier_memref (),
48884891 indices = [zero , zero ],
4889- slice_lengths = lhs .type .shape ,
4892+ slice_lengths = lhs_smem .type .shape ,
48904893 collective = ir .ArrayAttr .get ([]),
48914894 )
48924895 mgpu_dialect .async_load (
48934896 source = rhs_gmem ,
4894- destination = rhs ,
4897+ destination = rhs_smem ,
48954898 barrier = tma_barrier .as_barrier_memref (),
48964899 indices = [zero , zero ],
4897- slice_lengths = rhs .type .shape ,
4900+ slice_lengths = rhs_smem .type .shape ,
48984901 collective = ir .ArrayAttr .get ([]),
48994902 )
49004903 tma_barrier .wait ()
49014904
49024905 acc_type = ir .VectorType .get ((m , n ), i32 )
49034906 acc = vector .broadcast (acc_type , zero )
4907+ lhs = lhs_smem if lhs_in_smem else mgpu_dialect .vector_load (lhs_smem )
49044908 # Only f16 WGMMA supports transposes
4905- rhs = utils .memref_transpose (rhs , (1 , 0 ))
4906- result = mgpu_dialect .wgmma (acc , lhs , rhs )
4909+ rhs_smem = utils .memref_transpose (rhs_smem , (1 , 0 ))
4910+ result = mgpu_dialect .wgmma (acc , lhs , rhs_smem )
49074911 nvvm .wgmma_commit_group_sync_aligned ()
49084912 nvvm .wgmma_wait_group_sync_aligned (0 )
49094913 mgpu_dialect .vector_store (result , result_gmem )
0 commit comments