Skip to content

Commit 2f33f11

Browse files
committed
[mlir][NVVM] Add ldmatrix op to NVVM dialect
Differential Revision: https://reviews.llvm.org/D121347
1 parent c7f25b6 commit 2f33f11

File tree

6 files changed

+148
-0
lines changed

6 files changed

+148
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,4 +634,48 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
634634
let hasVerifier = 1;
635635
}
636636

637+
// Definition of `nvvm.ldmatrix`: takes one pointer operand plus the `num` and
// `layout` attributes and produces a single result. Extra constraints on the
// pointer address space, on `num` and on the result type are enforced by the
// C++ verifier enabled through `hasVerifier`.
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
  Results<(outs AnyType:$res)>,
  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {

  let summary = "cooperative matrix load";

  // LLVM IR translation: `layout` and `num` pick the intrinsic; the pointer
  // type of the first operand is used as the overload type of the call.
  string llvmBuilder = [{
      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
      auto intId = getLdMatrixIntrinsicId($layout, $num);
      $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
  }];

  // NOTE(review): this text is stored in a custom `baseDescription` field
  // rather than in `description`; confirm it reaches the generated op docs.
  string baseDescription = [{
    The `nvvm.ldmatrix` operation collectively loads one or more matrices across
    all threads in a warp from the location indicated by the address operand
    `ptr` from shared memory.

    The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
    It must be 1, 2 or 4.

    The attribute `layout` selects between the plain and the transposed
    (`trans`) variant of the underlying ldmatrix intrinsic.

    The result is an `i32` when `num` is 1, and an LLVM struct of `num` `i32`
    values when `num` is 2 or 4.

    All the threads in the warp must execute the same ldmatrix operation.

    Each row of 8 elements needs to be consecutive in memory. Each lane of the
    warp contains the start address of a row of 8 elements laid out as below:

    ```
    num | Threads 0--7  | Threads 8--15  | Threads 16--31
    1   | addr0--addr7  |                |
    2   | addr0--addr7  | addr8--addr15  |
    4   | addr0--addr7  | addr8--addr15  | addr16--addr31
    ```

    Example:
    ```mlir
    %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
      (!llvm.ptr<i32, 3>) -> i32
    %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
      (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
    ```
  }];

  let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
  let hasVerifier = 1;
}
680+
637681
#endif // NVVMIR_OPS

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,28 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
219219
return success();
220220
}
221221

222+
/// Verifier for `nvvm.ldmatrix`.
///
/// The op is rejected unless all of the following hold, checked in this order:
///  * the pointer operand lives in address space 3;
///  * `num` is 1, 2 or 4;
///  * the result type is `i32` for `num == 1`, or a literal LLVM struct made
///    of `num` `i32` members for `num == 2` and `num == 4`.
LogicalResult NVVM::LdMatrixOp::verify() {
  auto srcPtrType = ptr().getType().cast<LLVM::LLVMPointerType>();
  if (srcPtrType.getAddressSpace() != 3)
    return emitOpError("expected source pointer in memory space 3");

  // Read the attribute once; every later check depends on it.
  const auto matrixCount = num();
  Type i32Type = IntegerType::get(getContext(), 32);

  switch (matrixCount) {
  case 1:
    // A single matrix is returned as a bare i32.
    if (getType() != i32Type)
      return emitOpError("expected destination type is i32");
    break;
  case 2:
  case 4: {
    // Several matrices are returned as a struct with one i32 per matrix.
    Type expectedType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(matrixCount, i32Type));
    if (getType() != expectedType)
      return emitOpError("expected destination type is a structure of ")
             << matrixCount << " elements of type i32";
    break;
  }
  default:
    return emitOpError("expected num attribute to be 1, 2 or 4");
  }

  return success();
}
243+
222244
//===----------------------------------------------------------------------===//
223245
// NVVMDialect initialization, type parsing, and registration.
224246
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,35 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
6464
llvm_unreachable("unknown shuffle kind");
6565
}
6666

67+
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
68+
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
69+
int32_t num) {
70+
if (layout == NVVM::MMALayout::col) {
71+
switch (num) {
72+
case 1:
73+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
74+
case 2:
75+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
76+
case 4:
77+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
78+
default:
79+
llvm_unreachable("unsupported number of matrix");
80+
}
81+
82+
} else {
83+
switch (num) {
84+
case 1:
85+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
86+
case 2:
87+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
88+
case 4:
89+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
90+
default:
91+
llvm_unreachable("unsupported number of matrix");
92+
}
93+
}
94+
}
95+
6796
namespace {
6897
/// Implementation of the dialect interface that converts operations belonging
6998
/// to the NVVM dialect to LLVM IR.

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,38 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
11911191

11921192
// -----
11931193

1194+
// Negative test: the pointer operand type does not name address space 3, so
// the op must be rejected with the "memory space 3" diagnostic.
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32>) {
1195+
// expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
1196+
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32>) -> i32
1197+
llvm.return
1198+
}
1199+
1200+
// -----
1201+
1202+
// Negative test: `num = 3` is outside the accepted set {1, 2, 4}, so the op
// must be rejected with the "num attribute" diagnostic.
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
1203+
// expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
1204+
%l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
1205+
llvm.return
1206+
}
1207+
1208+
// -----
1209+
1210+
// Negative test: with `num = 1` the result must be a bare i32; a one-element
// struct of i32 is expected to be rejected.
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
1211+
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
1212+
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32)>
1213+
llvm.return
1214+
}
1215+
1216+
// -----
1217+
1218+
// Negative test: with `num = 4` the result must be a struct of four i32
// values; a two-element struct is expected to be rejected.
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
1219+
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
1220+
%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
1221+
llvm.return
1222+
}
1223+
1224+
// -----
1225+
11941226
llvm.func @caller() {
11951227
// expected-error @below {{expected function call to produce a value}}
11961228
llvm.call @callee() : () -> ()

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
105105
llvm.return
106106
}
107107

108+
// Positive test for the textual form of `nvvm.ldmatrix` with num = 1, 2 and 4
// (row layout only). The CHECK lines expect the attributes printed in the
// order `layout`, `num` and the result types i32, struct of 2 i32 and struct
// of 4 i32. Presumably driven by a parse/print RUN line not visible here.
// CHECK-LABEL: llvm.func @ld_matrix
109+
llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
110+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<i32, 3>) -> i32
111+
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
112+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
113+
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
114+
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
115+
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
116+
llvm.return
117+
}
108118
// -----
109119

110120
// expected-error@below {{attribute attached to unexpected op}}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,17 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
176176
llvm.return
177177
}
178178

179+
// Translation test for `nvvm.ldmatrix` with num = 1, 2 and 4. Only the row
// layout is exercised; the CHECK lines expect the x1/x2/x4 `.trans.b16`
// intrinsics overloaded on an address-space-3 i32 pointer, returning i32, a
// 2-element and a 4-element i32 struct respectively.
// CHECK-LABEL: @ld_matrix(
180+
llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
181+
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
182+
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
183+
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
184+
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
185+
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
186+
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
187+
llvm.return
188+
}
189+
179190
// This function has the "kernel" attribute attached and should appear in the
180191
// NVVM annotations after conversion.
181192
llvm.func @kernel_func() attributes {nvvm.kernel} {

0 commit comments

Comments
 (0)