Description
🚀 The feature, motivation and pitch
When allocating shared memory for DSLRegionOp, we need to assign a shared memory encoding attribute. This attribute currently comes from getPlainMemDesc in lib/Dialect/FlagTree/Transforms/ConvertArgToMemDesc.cpp, which simply builds the order as a reversed range over the given rank.
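For reference, here is a minimal C++ sketch of what a "reversed range over the rank" looks like as an order. The function name makePlainOrder is hypothetical and only illustrates the idea; it is not the actual body of getPlainMemDesc. In Triton's encodings, order[0] names the fastest-varying (most minor) axis, so this fixed order always describes a row-major layout.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Hypothetical helper: for rank R, returns {R-1, ..., 1, 0}.
// order[0] is the most minor axis, so this is always row-major.
std::vector<unsigned> makePlainOrder(unsigned rank) {
  std::vector<unsigned> order(rank);
  std::iota(order.begin(), order.end(), 0u);   // {0, 1, ..., R-1}
  std::reverse(order.begin(), order.end());    // {R-1, ..., 1, 0}
  return order;
}
```

For example, makePlainOrder(3) yields {2, 1, 0}. Because the result never depends on the tensor or its producer, any code consuming this encoding can safely assume row-major strides, which stops being true once the encoding comes from triton::getSharedEncoding.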
We plan to switch the SwizzledSharedEncodingAttr getter back to triton::getSharedEncoding. However, the resulting encoding is not necessarily row-major, so DSLRegion may compute the wrong shared memory strides, which eventually leads to a failure. We therefore also need to update ExtractStridesOpConversion in lib/Conversion/FlagTreeToLLVM/ExtractOpToLLVM.cpp. It currently assumes the shared memory is row-major and computes strides from the leftmost axis to the rightmost. With the planned change, it should first read the order of the memdesc, reorder the axes accordingly, and then compute the correct stride for each axis.
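A minimal sketch of the proposed order-aware stride computation, written against plain std::vector rather than the MLIR types used in ExtractStridesOpConversion. Both function names are hypothetical, and the sketch assumes the Triton convention that order[0] is the fastest-varying axis. It only handles axis ordering and ignores swizzling and any padding.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major assumption (current behavior): the rightmost axis is contiguous.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    strides[i] = running;
    running *= shape[i];
  }
  return strides;
}

// Order-aware version: walk the axes from most minor (order[0]) to most
// major, giving each axis the product of the sizes of all faster axes.
std::vector<int64_t> orderAwareStrides(const std::vector<int64_t> &shape,
                                       const std::vector<unsigned> &order) {
  assert(order.size() == shape.size() && "order must cover every axis");
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (unsigned axis : order) {
    assert(axis < shape.size() && "order entry out of range");
    strides[axis] = running;
    running *= shape[axis];
  }
  return strides;
}
```

For a 4x8 buffer, order {1, 0} gives strides {8, 1}, matching rowMajorStrides, while order {0, 1} (column-major) gives {1, 4}. The row-major assumption is only correct in the first case.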
Alternatives
No response
Additional context
No response