Commit 788d8d3

[UR][L0] Fix allocation lookup in EnqueuedPool (#19638)
This PR updates the allocation management algorithm to address cases where it fails to find suitable allocations because of how it performs lower-bound searches. Example:

```
Freelist {
    Allocation(align=64, size=128),
    Allocation(align=64, size=256),
    Allocation(align=4096, size=128),
    Allocation(align=4096, size=1024),
}
```

If we request `align=64`, `size=512`, the current code looks only at `Allocation(align=4096, size=128)` and skips the rest, even though `Allocation(align=4096, size=1024)` would work. This PR therefore groups the allocations by queue and alignment, so that every alignment group is searched for a sufficiently large allocation.
1 parent 1af67b1 commit 788d8d3
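
For context, here is a minimal standalone sketch of the failure mode (illustrative only: the flat `(alignment, size)` set models the old single `Freelist`, and all names are invented for this note). A single `lower_bound` over one flat ordered set lands on the first entry whose alignment is large enough; if that entry's size is too small, the old check-and-give-up logic never reaches the larger entries behind it.

```cpp
#include <cstddef>
#include <iostream>
#include <set>
#include <utility>

int main() {
  // One queue's freelist, ordered by (alignment, size) as in the old scheme.
  std::set<std::pair<std::size_t, std::size_t>> freelist = {
      {64, 128}, {64, 256}, {4096, 128}, {4096, 1024}};

  // Request: alignment >= 64, size >= 512.
  auto it = freelist.lower_bound({64, 512});

  // lower_bound lands on {4096, 128}: alignment fits, size does not, and a
  // single check here misses the suitable {4096, 1024} right behind it.
  std::cout << it->first << ", " << it->second << '\n'; // prints "4096, 128"
  return 0;
}
```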

File tree

3 files changed: +163 -51 lines changed
unified-runtime/source/adapters/level_zero/enqueued_pool.cpp

Lines changed: 88 additions & 40 deletions
```diff
@@ -13,32 +13,71 @@
 
 #include <ur_api.h>
 
-EnqueuedPool::~EnqueuedPool() { cleanup(); }
+namespace {
 
 std::optional<EnqueuedPool::Allocation>
-EnqueuedPool::getBestFit(size_t Size, size_t Alignment, void *Queue) {
-  auto Lock = std::lock_guard(Mutex);
+getBestFitHelper(size_t Size, size_t Alignment, void *Queue,
+                 EnqueuedPool::AllocationGroupMap &Freelist) {
+  // Iterate over the alignments for a given queue.
+  auto GroupIt = Freelist.lower_bound({Queue, Alignment});
+  for (; GroupIt != Freelist.end() && GroupIt->first.Queue == Queue;
+       ++GroupIt) {
+    auto &AllocSet = GroupIt->second;
+    // Find the first allocation that is large enough.
+    auto AllocIt = AllocSet.lower_bound({nullptr, Size, nullptr, nullptr, 0});
+    if (AllocIt != AllocSet.end()) {
+      auto BestFit = *AllocIt;
+      AllocSet.erase(AllocIt);
+      if (AllocSet.empty()) {
+        Freelist.erase(GroupIt);
+      }
+      return BestFit;
+    }
+  }
+  return std::nullopt;
+}
 
-  Allocation Alloc = {nullptr, Size, nullptr, Queue, Alignment};
+void removeFromFreelist(const EnqueuedPool::Allocation &Alloc,
+                        EnqueuedPool::AllocationGroupMap &Freelist,
+                        bool IsGlobal) {
+  const EnqueuedPool::AllocationGroupKey Key = {
+      IsGlobal ? nullptr : Alloc.Queue, Alloc.Alignment};
 
-  auto It = Freelist.lower_bound(Alloc);
-  if (It != Freelist.end() && It->Size >= Size && It->Queue == Queue &&
-      It->Alignment >= Alignment) {
-    Allocation BestFit = *It;
-    Freelist.erase(It);
+  auto GroupIt = Freelist.find(Key);
+  assert(GroupIt != Freelist.end() && "Allocation group not found in freelist");
 
-    return BestFit;
+  auto &AllocSet = GroupIt->second;
+  auto AllocIt = AllocSet.find(Alloc);
+  assert(AllocIt != AllocSet.end() && "Allocation not found in group");
+
+  AllocSet.erase(AllocIt);
+  if (AllocSet.empty()) {
+    Freelist.erase(GroupIt);
   }
+}
 
-  // To make sure there's no match on other queues, we need to reset it to
-  // nullptr and try again.
-  Alloc.Queue = nullptr;
-  It = Freelist.lower_bound(Alloc);
+} // namespace
 
-  if (It != Freelist.end() && It->Size >= Size && It->Alignment >= Alignment) {
-    Allocation BestFit = *It;
-    Freelist.erase(It);
+EnqueuedPool::~EnqueuedPool() { cleanup(); }
 
+std::optional<EnqueuedPool::Allocation>
+EnqueuedPool::getBestFit(size_t Size, size_t Alignment, void *Queue) {
+  auto Lock = std::lock_guard(Mutex);
+
+  // First, try to find the best fit in the queue-specific freelist.
+  auto BestFit = getBestFitHelper(Size, Alignment, Queue, FreelistByQueue);
+  if (BestFit) {
+    // Remove the allocation from the global freelist as well.
+    removeFromFreelist(*BestFit, FreelistGlobal, true);
+    return BestFit;
+  }
+
+  // If no fit was found in the queue-specific freelist, try the global
+  // freelist.
+  BestFit = getBestFitHelper(Size, Alignment, nullptr, FreelistGlobal);
+  if (BestFit) {
+    // Remove the allocation from the queue-specific freelist.
+    removeFromFreelist(*BestFit, FreelistByQueue, false);
     return BestFit;
   }
 
@@ -52,45 +91,54 @@ void EnqueuedPool::insert(void *Ptr, size_t Size, ur_event_handle_t Event,
   uintptr_t Address = (uintptr_t)Ptr;
   size_t Alignment = Address & (~Address + 1);
 
-  Freelist.emplace(Allocation{Ptr, Size, Event, Queue, Alignment});
+  Allocation Alloc = {Ptr, Size, Event, Queue, Alignment};
+  FreelistByQueue[{Queue, Alignment}].emplace(Alloc);
+  FreelistGlobal[{nullptr, Alignment}].emplace(Alloc);
 }
 
 bool EnqueuedPool::cleanup() {
   auto Lock = std::lock_guard(Mutex);
-  auto FreedAllocations = !Freelist.empty();
+  auto FreedAllocations = !FreelistGlobal.empty();
 
   auto Ret [[maybe_unused]] = UR_RESULT_SUCCESS;
-  for (auto It : Freelist) {
-    Ret = MemFreeFn(It.Ptr);
-    assert(Ret == UR_RESULT_SUCCESS);
-
-    if (It.Event)
-      EventReleaseFn(It.Event);
+  for (const auto &[GroupKey, AllocSet] : FreelistGlobal) {
+    for (const auto &Alloc : AllocSet) {
+      Ret = MemFreeFn(Alloc.Ptr);
+      assert(Ret == UR_RESULT_SUCCESS);
+
+      if (Alloc.Event) {
+        EventReleaseFn(Alloc.Event);
+      }
+    }
   }
-  Freelist.clear();
+
+  FreelistGlobal.clear();
+  FreelistByQueue.clear();
 
   return FreedAllocations;
 }
 
 bool EnqueuedPool::cleanupForQueue(void *Queue) {
   auto Lock = std::lock_guard(Mutex);
-
-  Allocation Alloc = {nullptr, 0, nullptr, Queue, 0};
-  // first allocation on the freelist with the specific queue
-  auto It = Freelist.lower_bound(Alloc);
-
   bool FreedAllocations = false;
 
   auto Ret [[maybe_unused]] = UR_RESULT_SUCCESS;
-  while (It != Freelist.end() && It->Queue == Queue) {
-    Ret = MemFreeFn(It->Ptr);
-    assert(Ret == UR_RESULT_SUCCESS);
-
-    if (It->Event)
-      EventReleaseFn(It->Event);
-
-    // Erase the current allocation and move to the next one
-    It = Freelist.erase(It);
+  auto GroupIt = FreelistByQueue.lower_bound({Queue, 0});
+  while (GroupIt != FreelistByQueue.end() && GroupIt->first.Queue == Queue) {
+    auto &AllocSet = GroupIt->second;
+    for (const auto &Alloc : AllocSet) {
+      Ret = MemFreeFn(Alloc.Ptr);
+      assert(Ret == UR_RESULT_SUCCESS);
+
+      if (Alloc.Event) {
+        EventReleaseFn(Alloc.Event);
+      }
+
+      removeFromFreelist(Alloc, FreelistGlobal, true);
+    }
+
+    // Move to the next group.
+    GroupIt = FreelistByQueue.erase(GroupIt);
     FreedAllocations = true;
  }
 
```

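A note on the unchanged context line `size_t Alignment = Address & (~Address + 1);` in `insert()` above: the expression (equivalent to `Address & -Address` in two's complement) isolates the lowest set bit of the address, i.e. the largest power of two dividing it, which is the strongest alignment the pointer is guaranteed to satisfy. A small sketch with an invented address value:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // 0x1040 = 0b1000001000000: the lowest set bit is 0x40.
  std::uintptr_t Address = 0x1040;
  std::uintptr_t Alignment = Address & (~Address + 1);
  std::printf("alignment = %ju\n", (std::uintmax_t)Alignment); // prints 64
  return 0;
}
```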
unified-runtime/source/adapters/level_zero/enqueued_pool.hpp

Lines changed: 31 additions & 11 deletions
```diff
@@ -13,6 +13,7 @@
 
 #include "ur_api.h"
 #include "ur_pool_manager.hpp"
+#include <map>
 #include <set>
 #include <umf_helpers.hpp>
 
@@ -43,25 +44,44 @@ class EnqueuedPool {
   bool cleanup();
   bool cleanupForQueue(void *Queue);
 
-private:
-  struct Comparator {
-    bool operator()(const Allocation &lhs, const Allocation &rhs) const {
+  // Allocations are grouped by queue and alignment.
+  struct AllocationGroupKey {
+    void *Queue;
+    size_t Alignment;
+  };
+
+  struct GroupComparator {
+    bool operator()(const AllocationGroupKey &lhs,
+                    const AllocationGroupKey &rhs) const {
       if (lhs.Queue != rhs.Queue) {
-        return lhs.Queue < rhs.Queue; // Compare by queue handle first
-      }
-      if (lhs.Alignment != rhs.Alignment) {
-        return lhs.Alignment < rhs.Alignment; // Then by alignment
+        return lhs.Queue < rhs.Queue;
       }
+      return lhs.Alignment < rhs.Alignment;
+    }
+  };
+
+  // Then, the allocations are sorted by size.
+  struct SizeComparator {
+    bool operator()(const Allocation &lhs, const Allocation &rhs) const {
       if (lhs.Size != rhs.Size) {
-        return lhs.Size < rhs.Size; // Then by size
+        return lhs.Size < rhs.Size;
       }
-      return lhs.Ptr < rhs.Ptr; // Finally by pointer address
+      return lhs.Ptr < rhs.Ptr;
     }
   };
 
-  using AllocationSet = std::set<Allocation, Comparator>;
+  using AllocationGroup = std::set<Allocation, SizeComparator>;
+  using AllocationGroupMap =
+      std::map<AllocationGroupKey, AllocationGroup, GroupComparator>;
+
+private:
   ur_mutex Mutex;
-  AllocationSet Freelist;
+
+  // Freelist grouped by queue and alignment.
+  AllocationGroupMap FreelistByQueue;
+  // Freelist grouped by alignment only.
+  AllocationGroupMap FreelistGlobal;
+
   event_release_callback_t EventReleaseFn;
   memory_free_callback_t MemFreeFn;
 };
```

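To see how the two-level layout declared above fixes the motivating example, here is a minimal sketch (simplified shapes invented for illustration: alignment-only group keys and bare sizes stand in for `AllocationGroupKey` and `Allocation`). The outer `lower_bound` finds the first group with sufficient alignment, and each group then gets its own size search instead of one global probe:

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <optional>
#include <set>

int main() {
  // Alignment -> sizes available at that alignment, both kept sorted.
  std::map<std::size_t, std::set<std::size_t>> groups = {
      {64, {128, 256}}, {4096, {128, 1024}}};

  const std::size_t Alignment = 64, Size = 512;
  std::optional<std::size_t> best;
  // Walk every group whose alignment is at least the requested one...
  for (auto it = groups.lower_bound(Alignment); it != groups.end(); ++it) {
    // ...and take the first sufficiently large size within that group.
    auto sizeIt = it->second.lower_bound(Size);
    if (sizeIt != it->second.end()) {
      best = *sizeIt;
      break;
    }
  }
  if (best)
    std::cout << "found size " << *best << '\n'; // prints "found size 1024"
  return 0;
}
```

In the real `AllocationGroupMap`, the group key also carries the queue handle, and `getBestFitHelper` stops iterating once the key's queue no longer matches the requested one.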
unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp

Lines changed: 44 additions & 0 deletions
```diff
@@ -758,3 +758,47 @@ TEST_P(urL0EnqueueAllocMultiQueueMultiDeviceTest,
     ASSERT_NE(freeEvent, nullptr);
   }
 }
+
+using urL0EnqueueAllocStandaloneTest = uur::urQueueTest;
+UUR_INSTANTIATE_DEVICE_TEST_SUITE(urL0EnqueueAllocStandaloneTest);
+
+TEST_P(urL0EnqueueAllocStandaloneTest, ReuseFittingAllocation) {
+  ur_usm_pool_handle_t pool = nullptr;
+  ur_usm_pool_desc_t pool_desc = {};
+  ASSERT_SUCCESS(urUSMPoolCreate(context, &pool_desc, &pool));
+
+  auto makeAllocation = [&](uint32_t alignment, size_t size, void **ptr) {
+    const ur_usm_device_desc_t usm_device_desc{
+        UR_STRUCTURE_TYPE_USM_DEVICE_DESC, nullptr,
+        /* device flags */ 0};
+
+    const ur_usm_desc_t usm_desc{UR_STRUCTURE_TYPE_USM_DESC, &usm_device_desc,
+                                 UR_USM_ADVICE_FLAG_DEFAULT, alignment};
+
+    ASSERT_SUCCESS(
+        urUSMDeviceAlloc(context, device, &usm_desc, pool, size, ptr));
+  };
+
+  std::array<void *, 4> allocations = {};
+  makeAllocation(64, 128, &allocations[0]);
+  makeAllocation(64, 256, &allocations[1]);
+  makeAllocation(4096, 512, &allocations[2]);
+  makeAllocation(4096, 8192, &allocations[3]);
+
+  ASSERT_SUCCESS(
+      urEnqueueUSMFreeExp(queue, pool, allocations[0], 0, nullptr, nullptr));
+  ASSERT_SUCCESS(
+      urEnqueueUSMFreeExp(queue, pool, allocations[1], 0, nullptr, nullptr));
+  ASSERT_SUCCESS(
+      urEnqueueUSMFreeExp(queue, pool, allocations[2], 0, nullptr, nullptr));
+  ASSERT_SUCCESS(
+      urEnqueueUSMFreeExp(queue, pool, allocations[3], 0, nullptr, nullptr));
+
+  void *ptr = nullptr;
+  ASSERT_SUCCESS(urEnqueueUSMDeviceAllocExp(queue, pool, 8192, nullptr, 0,
+                                            nullptr, &ptr, nullptr));
+
+  ASSERT_EQ(ptr, allocations[3]); // Fitting allocation should be reused.
+  ASSERT_SUCCESS(urEnqueueUSMFreeExp(queue, pool, ptr, 0, nullptr, nullptr));
+  ASSERT_SUCCESS(urQueueFinish(queue));
+}
```
