@@ -1946,7 +1946,11 @@ def upcast_to_f8e4m3fn(reg: ir.Value, part: int):
     if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len % 2 == 0:
       new_registers = np.empty_like(self.registers)
       out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
-      for idx, reg in np.ndenumerate(self.registers):
+      # We use packed_registers for consistency, even though the packing is not
+      # really profitable here: the PTX below begins with an op that depends on
+      # the extracted part, so there are no ops that can be shared across packed
+      # parts.
+      for indices, reg in packed_registers(2, if_not_sliced=True):
         # The algorithm here is largely the same as CUTLASS's
         # NumericArrayConverter specialization for int4 -> bf16 casts.
         # We modify it slightly, because we only extract 2 values.
@@ -1962,7 +1966,7 @@ def upcast_to_f8e4m3fn(reg: ir.Value, part: int):
         # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
         def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
           assert 0 <= part < 4
-          return llvm.inline_asm(
+          int_reg = llvm.inline_asm(
               i32,
               [reg, reg_shr],
               f"""
@@ -1976,43 +1980,43 @@ def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
19761980 """ ,
19771981 "=r,r,r" ,
19781982 )
1979- offset = 0
1983+ return utils .bitcast (int_reg , ir .VectorType .get ((2 ,), bf16 ))
1984+ [group_size ] = ir .VectorType (reg .type ).shape
1985+ assert group_size % vector_len == 0
1986+ assert group_size * 4 <= 32
1987+ int_ty = ir .IntegerType .get_signless (group_size * 4 )
1988+ # If the vector originates from a slice (common after relayouts), we
1989+ # can fuse the slicing into the conversion and prevent LLVM from
1990+ # generating a bunch of shifts to align the vector data to the LSB.
1991+ # This also lets us share the right shift among more vectors.
19801992 out_int_regs : list [ir .Value ] = []
1981- # TODO(apaszke): Use packed_registers here.
1982- for group_size in (8 , 4 , 2 ):
1983- int_ty = ir .IntegerType .get_signless (group_size * 4 )
1984- while vector_len - offset >= group_size :
1985- # If the vector originates from a slice (common after relayouts), we
1986- # can fuse the slicing into the conversion and prevent LLVM from
1987- # generating a bunch of shifts to align the vector data to the LSB.
1988- # This also lets us share the right shift among more vectors.
1989- if (isinstance (slice_op := reg .owner .opview , vector .ExtractStridedSliceOp )
1990- and utils .bitwidth (slice_op .source .type ) == 32
1991- and slice_op .strides [0 ].value == 1 ):
1992- slice_offset = slice_op .offsets [0 ].value + offset
1993- reg_int = utils .bitcast (slice_op .source , i32 )
1994- reg_int_shr = arith .shrui (reg_int , c (4 , i32 ))
1995- out_int_regs .extend (
1996- upcast_i4_to_bf16 (reg_int , reg_int_shr , part = (slice_offset // 2 + part ))
1997- for part in range (group_size // 2 )
1998- )
1999- else :
2000- reg_slice = utils .vector_slice (reg , slice (offset , offset + group_size ))
2001- reg_slice_int = utils .bitcast (reg_slice , int_ty )
2002- if int_ty != i32 :
2003- reg_slice_int = arith .extsi (i32 , reg_slice_int )
2004- reg_slice_int_shr = arith .shrui (reg_slice_int , c (4 , i32 ))
2005- out_int_regs .extend (
2006- upcast_i4_to_bf16 (reg_slice_int , reg_slice_int_shr , part = part )
2007- for part in range (group_size // 2 )
2008- )
2009- offset += group_size
2010- assert offset == vector_len
2011- out_vec_int = utils .vector_concat ([
2012- vector .broadcast (ir .VectorType .get ((1 ,), i32 ), reg )
2013- for reg in out_int_regs
2014- ])
2015- new_registers [idx ] = utils .bitcast (out_vec_int , out_vec_ty )
1993+ if regs_from_32bit_slice :
1994+ slice_op = reg .owner .opview
1995+ slice_offset = slice_op .offsets [0 ].value
1996+ reg_int = utils .bitcast (slice_op .source , i32 )
1997+ reg_int_shr = arith .shrui (reg_int , c (4 , i32 ))
1998+ assert slice_offset % 2 == 0
1999+ out_int_regs .extend (
2000+ upcast_i4_to_bf16 (reg_int , reg_int_shr , part = slice_offset // 2 + part )
2001+ for part in range (group_size // 2 )
2002+ )
2003+ else :
2004+ reg_slice_int = utils .bitcast (reg , int_ty )
2005+ if int_ty != i32 :
2006+ reg_slice_int = arith .extsi (i32 , reg_slice_int )
2007+ reg_slice_int_shr = arith .shrui (reg_slice_int , c (4 , i32 ))
2008+ out_int_regs .extend (
2009+ upcast_i4_to_bf16 (reg_slice_int , reg_slice_int_shr , part = part )
2010+ for part in range (group_size // 2 )
2011+ )
2012+ out_reg = utils .vector_concat (out_int_regs )
2013+ offset = 0
2014+ for idx in indices :
2015+ new_registers [idx ] = new_reg = utils .vector_slice (
2016+ out_reg , slice (offset , offset + vector_len )
2017+ )
2018+ offset += vector_len
2019+ assert new_reg .type == out_vec_ty
20162020 return FragmentedArray (
20172021 _registers = new_registers , _layout = self .layout , _is_signed = None
20182022 )
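
The bias arithmetic referenced in the comments above (subtracting 136, i.e. 0x4308 as bf16 bits) is easy to verify outside of MLIR. The sketch below is a standalone NumPy model of the CUTLASS-style trick the PTX implements, not the Mosaic GPU code path itself: flip the int4 sign bit, OR the biased nibble into a bf16 pattern whose high byte is 0x43, and subtract 136. The helper name and the use of ml_dtypes for a bfloat16 view are assumptions of this sketch.

import numpy as np
import ml_dtypes

def upcast_i4_to_bf16_model(nibbles: np.ndarray) -> np.ndarray:
  """Decodes raw int4 nibbles (0..15) into bfloat16 values in [-8, 7]."""
  assert nibbles.dtype == np.uint8 and np.all(nibbles < 16)
  # Flipping the sign bit maps the signed value v in [-8, 7] to v + 8 in [0, 15].
  biased = nibbles.astype(np.uint16) ^ 0x8
  # 0x4300 is 128.0 in bf16, and one mantissa ulp at that exponent is exactly
  # 1.0, so OR-ing in the biased nibble encodes 128 + (v + 8) = v + 136.
  bits = 0x4300 | biased
  # Subtracting the bias of 136 (0x4308 as bits) recovers v exactly.
  return bits.view(ml_dtypes.bfloat16) - ml_dtypes.bfloat16(136.0)

# Exhaustive check over all 16 nibble patterns.
nibbles = np.arange(16, dtype=np.uint8)
signed = np.where(nibbles < 8, nibbles, nibbles.astype(np.int32) - 16)
assert np.array_equal(upcast_i4_to_bf16_model(nibbles).astype(np.int32), signed)
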
@@ -2058,6 +2062,7 @@ def upcast_i4_to_i8(reg: ir.Value, first_valid_nibble: int = 0):
         ])
         [group_size] = ir.VectorType(reg.type).shape
         assert group_size % vector_len == 0
+        assert group_size * 4 <= 32
         int_ty = ir.IntegerType.get_signless(group_size * 4)
         if regs_from_32bit_slice:
           slice_op = reg.owner.opview
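
Both this hunk and the bf16 path above rely on the same slice-fusion idea: when a register was produced by a 32-bit vector.extract_strided_slice, the code bitcasts the slice source to i32 and folds the slice offset into the part index, instead of letting LLVM emit shifts to realign the sliced data to the LSB. A tiny pure-Python model of that index arithmetic follows; the values and the nibble_pair helper are illustrative, not part of the change.

def nibble_pair(word: int, part: int) -> int:
  """Returns the byte holding nibbles 2*part and 2*part + 1 of a packed word."""
  return (word >> (8 * part)) & 0xFF

word = 0x76543210   # eight int4 values packed into one 32-bit register
slice_offset = 4    # the slice starts at nibble 4; must be even, as asserted above
sliced = (word >> (4 * slice_offset)) & 0xFFFF   # the materialized 4-nibble slice
for part in range(2):  # group_size // 2 parts for a 4-nibble group
  # Offsetting the part index on the full word selects the same nibble pair as
  # materializing the slice first and extracting parts from its LSB.
  assert nibble_pair(word, slice_offset // 2 + part) == nibble_pair(sliced, part)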