@@ -131,6 +131,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
   return findThreadSpecificMarkers(node);
 }
 
+struct FullSchedule;
+
 /*
  * Transform schedule bands into a union_map.
  * Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -139,7 +141,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
  * current leaves and transforms them into union maps.
  * Mapping filters are ignored.
  */
-isl::union_map fullSchedule(const detail::ScheduleTree* root) {
+isl::UnionMap<Domain, FullSchedule> fullSchedule(
+    const detail::ScheduleTree* root) {
   using namespace tc::polyhedral::detail;
 
   if (!root->elemAs<ScheduleTreeElemDomain>()) {
@@ -182,7 +185,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
       throw promotion::PromotionLogicError(ss.str());
     }
   }
-  return schedule;
+  return isl::UnionMap<Domain, FullSchedule>(schedule);
 }
 
 /*
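Some context on the new incomplete types: `FullSchedule` (and `Unrolled` further down) carry no data; they exist only as compile-time tags naming isl spaces, so that compositions over mismatched spaces fail to compile. The sketch below illustrates the idea, assuming isl's stock C++ bindings underneath; the class name, `asUnionMap`, and the tag plumbing are illustrative, not TC's actual interop layer. `fullSchedule` above wraps its untyped result in exactly this way on return.

```cpp
#include <utility>
#include <isl/cpp.h>

// Tag types carry no data; an incomplete type is enough to name a space.
struct Domain;
struct FullSchedule;

// Phantom-typed wrapper (illustrative): DomainTag/RangeTag exist only in
// the type system and are erased at runtime.
template <typename DomainTag, typename RangeTag>
class TypedUnionMap {
 public:
  explicit TypedUnionMap(isl::union_map map) : map_(std::move(map)) {}

  // Only well-typed when the tags line up:
  // (D -> R).apply_domain(D -> D2) yields (D2 -> R).
  template <typename DomainTag2>
  TypedUnionMap<DomainTag2, RangeTag> apply_domain(
      const TypedUnionMap<DomainTag, DomainTag2>& other) const {
    return TypedUnionMap<DomainTag2, RangeTag>(
        map_.apply_domain(other.asUnionMap()));
  }

  // Escape hatch back to the untyped world.
  isl::union_map asUnionMap() const {
    return map_;
  }

 private:
  isl::union_map map_;
};
```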
@@ -263,7 +266,7 @@ bool promotionImprovesCoalescing(
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* node,
     const TensorReferenceGroup& group,
-    isl::union_map schedule) {
+    isl::UnionMap<Domain, FullSchedule> schedule) {
   auto originalAccesses = group.originalAccesses();
 
   auto markers = collectBranchMarkers(root, node);
@@ -313,6 +316,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
   return mapping;
 }
 
+struct Unrolled;
+
 /*
 * Check that only unrolled loops may appear in access subscripts.
 * Because the scoping point can be above a branching tree, descend into each
@@ -343,11 +348,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
  * different references may have different values, but all of them remain
  * independent of non-unrolled loop iterators.
  */
+template <typename Outer>
 bool accessSubscriptsAreUnrolledLoops(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outerSchedule) {
+    isl::MultiUnionPwAff<Domain, Outer> outerSchedule) {
   using namespace detail;
 
   auto nodes = ScheduleTree::collect(scope);
@@ -366,7 +372,7 @@ bool accessSubscriptsAreUnrolledLoops(
 
   auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
   for (auto node : ancestors) {
-    auto band = node->elemAs<detail::ScheduleTreeElemBand>();
+    auto band = node->template elemAs<detail::ScheduleTreeElemBand>();
     if (!band) {
       continue;
     }
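The added `template` keyword is the standard C++ dependent-name disambiguator: now that `accessSubscriptsAreUnrolledLoops` is itself a function template, the call through `node` is presumably treated as dependent, and without the keyword the compiler parses `<` as less-than. A self-contained toy of the rule (names are illustrative, not TC's):

```cpp
struct Band {};

struct Node {
  template <typename T>
  T* elemAs() {
    return nullptr;  // toy stand-in for TC's checked downcast
  }
};

template <typename T>
void visit(T* node) {
  // auto band = node->elemAs<Band>();         // ill-formed: '<' parsed
  //                                           // as less-than
  auto band = node->template elemAs<Band>();   // OK: member template call
  (void)band;
}
```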
@@ -383,7 +389,8 @@ bool accessSubscriptsAreUnrolledLoops(
 
   auto space = isl::space(leaf->ctx_, 0, unrolledDims.n())
                    .align_params(subdomain.get_space());
-  auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
+  auto unrolledDimsMupa =
+      isl::MultiUnionPwAff<Domain, Unrolled>(space, unrolledDims);
 
   // It is possible that no loops are unrolled, in which case
   // unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -392,10 +399,11 @@ bool accessSubscriptsAreUnrolledLoops(
       unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
 
   auto accesses = group.originalAccesses();
-  auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
-  accesses = accesses.apply_domain(isl::union_map::from(schedule));
+  auto schedule = outerSchedule.range_product(unrolledDimsMupa);
+  auto scheduleMap = schedule.asUnionMap();
+  auto scheduledAccesses = accesses.apply_domain(scheduleMap);
 
-  if (!accesses.is_single_valued()) {
+  if (!scheduledAccesses.is_single_valued()) {
     return false;
   }
 }
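The move from `flat_range_product` to `range_product` here is what lets the static space tags survive: `flat_range_product` concatenates both ranges into a single anonymous flat space, whereas `range_product` builds the nested pair space `[Outer -> Unrolled]`, which a typed wrapper can still name; `asUnionMap()` then drops back to an untyped map for `apply_domain`. The difference is easy to see on raw isl objects (a sketch with made-up statement and space names, assuming isl's stock C++ bindings):

```cpp
#include <isl/cpp.h>

void compareProducts(isl::ctx ctx) {
  auto outer = isl::union_map(ctx, "{ S[i] -> O[i] }");
  auto unrolled = isl::union_map(ctx, "{ S[i] -> U[i] }");

  // range_product keeps a nested, named pair space:
  //   { S[i] -> [O[i] -> U[i]] }
  auto nested = outer.range_product(unrolled);

  // flat_range_product concatenates into one anonymous tuple:
  //   { S[i] -> [i, i] }
  auto flat = outer.flat_range_product(unrolled);
  (void)nested;
  (void)flat;
}
```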
@@ -415,23 +423,25 @@ bool accessSubscriptsAreUnrolledLoops(
 * thread associated to a given pair of tensor element and outer schedule
 * iteration.
 */
+template <typename Outer>
 bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outer,
-    isl::multi_union_pw_aff thread) {
+    isl::MultiUnionPwAff<Domain, Outer> outer,
+    isl::MultiUnionPwAff<Domain, Thread> thread) {
   if (!accessSubscriptsAreUnrolledLoops(
-          group, root, scope, outer.flat_range_product(thread))) {
+          group, root, scope, outer.range_product(thread))) {
     return false;
   }
 
   auto originalAccesses = group.originalAccesses();
-  auto map = isl::union_map::from(outer);
-  map = map.range_product(originalAccesses);
-  map = map.apply_domain(isl::union_map::from(thread));
+  auto outerMap = isl::UnionMap<Domain, Outer>::from(outer);
+  auto pair = outerMap.range_product(originalAccesses);
+  auto threadMap = isl::UnionMap<Domain, Thread>::from(thread);
+  auto threadToPair = pair.apply_domain(threadMap);
 
-  return map.is_injective();
+  return threadToPair.is_injective();
 }
 
 /*
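The named intermediates also make the final check self-documenting: `threadToPair` relates each thread to the `[outer iteration -> tensor element]` pairs it touches, and `is_injective` demands that no two distinct threads share such a pair, i.e. the accessing thread is uniquely determined by the pair. A toy illustration with raw isl (made-up spaces, stock C++ bindings assumed):

```cpp
#include <isl/cpp.h>

void injectivityDemo(isl::ctx ctx) {
  // Each thread t touches its own element A[t] at every outer iteration o:
  // distinct threads hit distinct pairs, so the relation is injective.
  auto perThread = isl::union_map(
      ctx, "{ T[t] -> [O[o] -> A[t]] : 0 <= t < 4 and 0 <= o < 4 }");
  auto ok = perThread.is_injective();  // holds: the check passes

  // Every thread touches A[0]: a pair [O[o] -> A[0]] has several preimages,
  // so injectivity fails and the group would not be promoted to registers.
  auto shared = isl::union_map(
      ctx, "{ T[t] -> [O[o] -> A[0]] : 0 <= t < 4 and 0 <= o < 4 }");
  auto bad = shared.is_injective();  // does not hold
  (void)ok;
  (void)bad;
}
```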
@@ -654,15 +664,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
 
   // Pure affine schedule without (mapping) filters.
-  auto partialSchedMupa = partialScheduleMupa(root, scope);
+  auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
   // Schedule with block mapping filter.
   auto partialSched =
       isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
   // The following promotion validity and profitability checks need to be
   // performed with respect to the block mapping, so append the block schedule.
   // If the partial schedule contains it already, it will just end up with
   // identical dimensions without affecting the result of the checks.
-  partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
+  auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
 
   for (auto& tensorGroups : groupMap) {
     auto tensorId = tensorGroups.first;
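A small design consequence of the typed wrappers: `range_product` changes the static range tag, so the product can no longer be assigned back into `partialSchedMupa` the way the old `flat_range_product` line was, hence the fresh `partialSchedBlockMupa`. A self-contained toy of the effect (tag, pair, and class names are placeholders, not TC's spellings):

```cpp
#include <utility>

struct Domain;
struct Scope;
struct Block;

// Toy phantom-typed MUPA: range_product yields a nested pair tag.
template <typename D, typename R>
struct Mupa {
  template <typename R2>
  Mupa<D, std::pair<R, R2>> range_product(const Mupa<D, R2>&) const {
    return {};
  }
};

void example(Mupa<Domain, Scope> partialSchedMupa,
             Mupa<Domain, Block> blockSchedule) {
  // Fresh variable: the range tag is now std::pair<Scope, Block>.
  auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
  // partialSchedMupa = partialSchedBlockMupa;  // error: mismatched types
  (void)partialSchedBlockMupa;
}
```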
@@ -676,11 +686,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
       continue;
     }
     if (!isPromotableToRegistersBelow(
-            *group, root, scope, partialSchedMupa, threadSchedule)) {
+            *group, root, scope, partialSchedBlockMupa, threadSchedule)) {
       continue;
     }
     // Check reuse within threads.
-    auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
+    auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
     if (!hasReuseWithin(*group, schedule)) {
       continue;
     }