@@ -50,104 +50,125 @@ namespace mlir::tpu {
 
 namespace {
 
-void optimizeLoadReshape(int hardware_generation,
-                         std::array<int64_t, 2> target_shape,
-                         Operation& raw_op) {
-  // Below, we try to look for reshapes that flatten multiple dims into the
-  // lane dimension. If the source of the reshape originates from a load of a
-  // ref with 128 minor dimension (effectively untiled), we can replace the
-  // load/reshape sequence with an efficient strided load. In essence, the
-  // strided load creates vregs with a narrow slice along the target minor
-  // dimension, but with the 2nd minor dim after the reshape already in
-  // sublanes. The results of strided load can be concatenated to form the
-  // final vector result.
-  //
-  // A little extra care needs to be applied to packed types, which we handle by
-  // briefly extending to 32-bit and repacking them after concatenation.
-  TypedValue<VectorType> src;
-  VectorType tgt_ty;
-  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
-    src = op.getSource();
-    tgt_ty = op.getResult().getType();
-  } else {
-    return;
-  }
-  VectorType src_ty = src.getType();
-  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
-    return;
+std::optional<int64_t> canOptimizeReshapeMemory(
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    TypedValue<MemRefType> ref, VectorType expanded_ty,
+    VectorType collapsed_ty) {
+  if (expanded_ty.getRank() < 2 || collapsed_ty.getRank() < 1) {
+    return std::nullopt;
   }
-  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int bitwidth = expanded_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
   if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
-  // This rewrite might not be profitable if the load has other users.
-  if (!load_op || !load_op.getBase().hasOneUse()) {
-    return;
+    return std::nullopt;
   }
 
-  TypedValue<MemRefType> ref = load_op.getBase();
-  MemRefType ref_ty = getMemRefType(ref);
   // The reshape below might be invalid if the memref is not contiguous, but it
   // is an overly conservative check (we don't need all dims to be contiguous).
   if (!isContiguousMemref(ref)) {
-    return;
+    return std::nullopt;
   }
 
   const int64_t lane = target_shape[1];
-  auto src_shape = src_ty.getShape();
-  auto tgt_shape = tgt_ty.getShape();
+  int64_t collapsed_minor = collapsed_ty.getShape().back();
+  int64_t expanded_minor = expanded_ty.getShape().back();
   // Only handle the cases where the minor dim starts out as the number of lanes
   // and we fold at least the second minor dim into it, in a way that changes
   // its shape.
-  if (src_shape.back() != lane ||
-      tgt_shape.back() % (packing * lane) != 0 ||
-      tgt_shape.back() == src_shape.back() ||
-      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
-    return;
+  if (expanded_minor != lane ||
+      collapsed_minor % (packing * lane) != 0 ||
+      collapsed_minor == expanded_minor ||
+      collapsed_minor < llvm::product_of(expanded_ty.getShape().take_back(2))) {
+    return std::nullopt;
   }
 
   // We don't handle memrefs with padding.
+  MemRefType ref_ty = getMemRefType(ref);
   auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
   if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
+    return std::nullopt;
   }
   ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
   ArrayRef<int64_t> ref_tiled_shape =
       ref_ty.getShape().take_back(front_tile.size());
   for (int i = 0; i < front_tile.size(); ++i) {
     if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
+      return std::nullopt;
     }
   }
 
   // NOTE: We could generalize this to allow only flattening part of a dimension
   int folded_dims = 0;
   {
     int suffix_size = 1;
-    auto sizes_it = src_shape.rbegin();
-    while (suffix_size < tgt_shape.back()) {
+    auto sizes_it = expanded_ty.getShape().rbegin();
+    while (suffix_size < collapsed_minor) {
       suffix_size *= *(sizes_it++);
     }
     // Make sure that the minor dim is folded only from entire major dims, not
     // from a part of some minor dim.
-    if (suffix_size != tgt_shape.back()) {
-      return;
+    if (suffix_size != collapsed_minor) {
+      return std::nullopt;
     }
-    folded_dims = sizes_it - src_shape.rbegin();
+    folded_dims = sizes_it - expanded_ty.getShape().rbegin();
   }
   DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
 
   // We don't handle slicing in the folded dims at the moment.
   if (ref_ty.getShape().take_back(folded_dims) !=
-      src_ty.getShape().take_back(folded_dims)) {
+      expanded_ty.getShape().take_back(folded_dims)) {
+    return std::nullopt;
+  }
+
+  return folded_dims;
+}
+
+void optimizeLoadReshape(int hardware_generation,
+                         std::array<int64_t, 2> target_shape,
+                         Operation& raw_op) {
+  // Below, we try to look for reshapes that flatten multiple dims into the
+  // lane dimension. If the source of the reshape originates from a load of a
+  // ref with 128 minor dimension (effectively untiled), we can replace the
+  // load/reshape sequence with an efficient strided load. In essence, the
+  // strided load creates vregs with a narrow slice along the target minor
+  // dimension, but with the 2nd minor dim after the reshape already in
+  // sublanes. The results of strided load can be concatenated to form the
+  // final vector result.
+  //
+  // A little extra care needs to be applied to packed types, which we handle by
+  // briefly extending to 32-bit and repacking them after concatenation.
+  TypedValue<VectorType> src;
+  VectorType tgt_ty;
+  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else {
     return;
   }
+  VectorType src_ty = src.getType();
+  ArrayRef<int64_t> src_shape = src_ty.getShape();
+  ArrayRef<int64_t> tgt_shape = tgt_ty.getShape();
+  const int lane = target_shape[1];
+  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int packing = 32 / bitwidth;
+
+  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
+  // This rewrite might not be profitable if the load has other users.
+  if (!load_op || !load_op.getBase().hasOneUse()) {
+    return;
+  }
+  TypedValue<MemRefType> ref = load_op.getBase();
+  MemRefType ref_ty = getMemRefType(ref);
+
+  auto maybe_folded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, ref, src_ty, tgt_ty);
+  if (!maybe_folded_dims.has_value()) {
+    return;
+  }
+  int folded_dims = *maybe_folded_dims;
 
   Location loc = raw_op.getLoc();
   ImplicitLocOpBuilder b(loc, &raw_op);
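
The trickiest part of the extracted helper is the suffix-folding check. Below is a minimal standalone sketch of that logic, not part of the patch: the foldedDims helper and the shapes are made up purely to mirror the loop in canOptimizeReshapeMemory, which accepts a collapsed minor dim only if it is the product of a whole suffix of the expanded shape.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical standalone mirror of the suffix-folding loop: the collapsed
// minor dim must equal the product of a whole suffix of the expanded shape.
std::optional<int> foldedDims(const std::vector<int64_t>& expanded_shape,
                              int64_t collapsed_minor) {
  int64_t suffix_size = 1;
  auto sizes_it = expanded_shape.rbegin();
  while (suffix_size < collapsed_minor && sizes_it != expanded_shape.rend()) {
    suffix_size *= *(sizes_it++);
  }
  // Reject folds that would have to split a dimension instead of consuming
  // whole trailing dims.
  if (suffix_size != collapsed_minor) {
    return std::nullopt;
  }
  return static_cast<int>(sizes_it - expanded_shape.rbegin());
}

int main() {
  // {4, 8, 128} collapsed to a 1024-wide minor dim folds exactly 2 dims.
  std::cout << foldedDims({4, 8, 128}, 1024).value_or(-1) << "\n";  // prints 2
  // {4, 6, 128} cannot produce a 256-wide minor dim without splitting the 6.
  std::cout << foldedDims({4, 6, 128}, 256).value_or(-1) << "\n";   // prints -1
}

The second call is rejected because building a 256-wide minor dim out of {4, 6, 128} would have to split the 6, which the pass does not support.
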
@@ -277,68 +298,18 @@ void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
   MemRefType ref_ty = getMemRefType(base);
   VectorType src_ty = shape_cast_op.getSource().getType();
   VectorType tgt_ty = shape_cast_op.getResult().getType();
-  if (src_ty.getRank() < 1 || tgt_ty.getRank() < 2) {
-    return;
-  }
   auto src_shape = src_ty.getShape();
   auto tgt_shape = tgt_ty.getShape();
-
+  const int64_t lane = target_shape[1];
   const int bitwidth = src_ty.getElementTypeBitWidth();
   const int packing = 32 / bitwidth;
-  if (hardware_generation < 4 && packing > 1) {
-    return;
-  }
-
-  // The reshape below might be invalid if the memref is not contiguous, but it
-  // is an overly conservative check (we don't need all dims to be contiguous).
-  if (!isContiguousMemref(base)) {
-    return;
-  }
-  const int64_t lane = target_shape[1];
-  // Only handle the cases where the minor dim starts out as the number of lanes
-  // and we fold at least the second minor dim into it, in a way that changes
-  // its shape.
-  if (tgt_shape.back() != lane ||
-      src_shape.back() % (packing * lane) != 0 ||
-      src_shape.back() == tgt_shape.back() ||
-      src_shape.back() < llvm::product_of(tgt_shape.take_back(2))) {
-    return;
-  }
-  // We don't handle memrefs with padding.
-  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
-  if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return;
-  }
-  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
-  ArrayRef<int64_t> ref_tiled_shape =
-      ref_ty.getShape().take_back(front_tile.size());
-  for (int i = 0; i < front_tile.size(); ++i) {
-    if (ref_tiled_shape[i] % front_tile[i]) {
-      return;
-    }
-  }
-
-  int expanded_dims = 0;
-  {
-    int suffix_size = 1;
-    auto sizes_it = tgt_shape.rbegin();
-    while (suffix_size < src_shape.back()) {
-      suffix_size *= *(sizes_it++);
-    }
-    // Make sure the minor dim is expanded into its own dims and not folded into
-    // other major dims.
-    if (suffix_size != src_shape.back()) {
-      return;
-    }
-    expanded_dims = sizes_it - tgt_shape.rbegin();
-  }
-  DCHECK_GE(expanded_dims, 2);  // Minor should expand at least into 2 dims.
 
-  // We don't support slicing in the expanded dims at the moment.
-  if (tgt_ty.getShape().take_back(expanded_dims) !=
-      ref_ty.getShape().take_back(expanded_dims)) {
+  std::optional<int> maybe_expanded_dims = canOptimizeReshapeMemory(
+      hardware_generation, target_shape, base, tgt_ty, src_ty);
+  if (!maybe_expanded_dims.has_value()) {
     return;
   }
+  int expanded_dims = *maybe_expanded_dims;
 
   ImplicitLocOpBuilder b(raw_op.getLoc(), &raw_op);
   auto loc = raw_op.getLoc();
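
Both call sites rely on the same minor-dim eligibility conditions; the load path passes (src_ty, tgt_ty) as (expanded, collapsed), while the store path swaps the order. The sketch below is only an illustration of those conditions: minorDimsEligible is a hypothetical helper, the numbers (lane = 128, bf16 packing = 2) are made up, and the real check in canOptimizeReshapeMemory additionally compares against the product of the two trailing dims and verifies memref tiling and contiguity.

#include <cstdint>
#include <iostream>

// Hypothetical mirror of the minor-dim conditions: the expanded minor dim must
// be exactly one lane wide and the collapsed minor dim must be a non-trivial
// multiple of packing * lane.
bool minorDimsEligible(int64_t expanded_minor, int64_t collapsed_minor,
                       int64_t lane, int bitwidth) {
  const int packing = 32 / bitwidth;
  return expanded_minor == lane &&
         collapsed_minor % (packing * lane) == 0 &&
         collapsed_minor != expanded_minor;
}

int main() {
  // bf16 (packing = 2) with 128 lanes: folding 8x128 into 1024 qualifies,
  // while 3x128 = 384 does not, since 384 is not a multiple of 256.
  std::cout << minorDimsEligible(128, 1024, 128, 16) << "\n";  // prints 1
  std::cout << minorDimsEligible(128, 384, 128, 16) << "\n";   // prints 0
}
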