Skip to content

Commit f350d2c

Browse files
authored
Add persistent cache implementation based on LMDB cache (#561)
* Add persistent cache implementation based on LMDB cache
* Implement shader cache stats
* Add test
* Add sgl::platform::current_process_id()
* Reuse LMDB environments per process
* Fix DB sharing
* Fix close_db
* Fix bug when committing a transaction failed: it's illegal to call `mdb_txn_abort` after calling `mdb_txn_commit`, even if the commit fails.
1 parent 6439413 commit f350d2c

File tree

14 files changed

+417
-57
lines changed

14 files changed

+417
-57
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
3+
import pytest
4+
from pathlib import Path
5+
6+
import slangpy as spy
7+
from slangpy.testing import helpers
8+
9+
10+
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_shader_cache(device_type: spy.DeviceType, tmpdir: str):
    """Check that shaders cached by one device are reused by the next one.

    Two devices are created back to back with the same ``shader_cache_path``.
    The first one starts from an empty cache and must record misses and new
    entries. The second one must start with zeroed hit/miss counters, find the
    existing entries, and serve the kernel with hits only while leaving the
    entry count unchanged.
    """

    def make_device():
        # Both passes use identical settings, including the cache directory.
        return spy.Device(
            type=device_type,
            enable_print=True,
            shader_cache_path=tmpdir,
            compiler_options={"include_paths": [Path(__file__).parent]},
            label=f"shader-cache-1-{device_type.name}",
        )

    # ---- First pass: cold cache -------------------------------------------
    device = make_device()

    # Nothing has been compiled yet, so the cache is empty and untouched.
    stats = device.shader_cache_stats
    assert (stats.entry_count, stats.hit_count, stats.miss_count) == (0, 0, 0)

    # Compile and run the kernel; the result should be written to the cache.
    program = device.load_program(
        module_name="test_shader_cache", entry_point_names=["compute_main"]
    )
    kernel = device.create_compute_kernel(program)
    kernel.dispatch(thread_count=[1, 1, 1])
    printed = device.flush_print_to_string()
    assert printed.strip() == "Hello shader cache!"

    # At least one entry exists now (pipelines may be cached in addition to
    # the compiled shader binary), and every lookup so far was a miss.
    stats = device.shader_cache_stats
    assert stats.entry_count > 0
    assert stats.hit_count == 0
    assert stats.miss_count > 0

    device.close()

    # ---- Second pass: warm cache ------------------------------------------
    device = make_device()

    # Entries survive on disk, while hit/miss counters start over.
    stats = device.shader_cache_stats
    assert stats.entry_count > 0
    assert (stats.hit_count, stats.miss_count) == (0, 0)
    entry_count_before = stats.entry_count

    # Same kernel again; this time it should be loaded from the cache.
    program = device.load_program(
        module_name="test_shader_cache", entry_point_names=["compute_main"]
    )
    kernel = device.create_compute_kernel(program)
    kernel.dispatch(thread_count=[1, 1, 1])
    printed = device.flush_print_to_string()
    assert printed.strip() == "Hello shader cache!"

    # No new entries were added, and the lookups were hits rather than misses.
    stats = device.shader_cache_stats
    assert stats.entry_count == entry_count_before
    assert stats.hit_count > 0
    assert stats.miss_count == 0

    device.close()
71+
72+
73+
if __name__ == "__main__":
74+
pytest.main([__file__, "-v", "-s"])
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
3+
import sgl.device.print;
4+
5+
[shader("compute")]
6+
[numthreads(1, 1, 1)]
7+
void compute_main()
8+
{
9+
print("Hello shader cache!");
10+
}

src/sgl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ target_sources(sgl PRIVATE
104104
device/native_formats.h
105105
device/nvapi.slang
106106
device/nvapi.slangh
107+
device/persistent_cache.cpp
108+
device/persistent_cache.h
107109
device/pipeline.cpp
108110
device/pipeline.h
109111
device/print.cpp

src/sgl/core/lmdb_cache.cpp

Lines changed: 110 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "lmdb_cache.h"
44

55
#include "sgl/core/error.h"
6+
#include "sgl/core/platform.h"
67

78
#include <chrono>
89

@@ -50,9 +51,11 @@ class ScopedTransaction {
5051
void commit()
5152
{
5253
SGL_CHECK(m_txn != nullptr, "Transaction is already committed or aborted");
53-
if (int result = mdb_txn_commit(m_txn); result != MDB_SUCCESS)
54-
LMDB_THROW("Failed to commit transaction", result);
54+
// It's an error to call `mdb_txn_abort` after `mdb_txn_commit`, even if it fails.
55+
MDB_txn* txn = m_txn;
5556
m_txn = nullptr;
57+
if (int result = mdb_txn_commit(txn); result != MDB_SUCCESS)
58+
LMDB_THROW("Failed to commit transaction", result);
5659
}
5760

5861
operator MDB_txn*() { return m_txn; }
@@ -108,39 +111,17 @@ LMDBCache::LMDBCache(const std::filesystem::path& path, std::optional<Options> o
108111
if (!std::filesystem::create_directories(path, ec) && ec)
109112
SGL_THROW("Failed to create cache directory ({})", ec.message());
110113

111-
if (int result = mdb_env_create(&m_env); result != MDB_SUCCESS)
112-
LMDB_THROW("Failed to create environment", result);
113-
if (int result = mdb_env_set_maxreaders(m_env, 126); result != MDB_SUCCESS)
114-
LMDB_THROW("Failed to set max readers", result);
115-
if (int result = mdb_env_set_maxdbs(m_env, 2); result != MDB_SUCCESS)
116-
LMDB_THROW("Failed to set max DBs", result);
117-
if (int result = mdb_env_set_mapsize(m_env, options.max_size); result != MDB_SUCCESS)
118-
LMDB_THROW("Failed to set map size", result);
119-
120-
int flags = options.nosync ? MDB_NOSYNC : 0;
121-
if (int result = mdb_env_open(m_env, path.string().c_str(), flags, 0664); result != MDB_SUCCESS)
122-
LMDB_THROW("Failed to open environment", result);
114+
m_db = open_db(path, options);
123115

124-
m_max_key_size = mdb_env_get_maxkeysize(m_env);
116+
m_max_key_size = mdb_env_get_maxkeysize(m_db.env);
125117

126118
m_eviction_threshold_size = (options.eviction_threshold * options.max_size) / 100;
127119
m_eviction_target_size = (options.eviction_target * options.max_size) / 100;
128-
129-
ScopedTransaction txn(m_env);
130-
131-
if (int result = mdb_dbi_open(txn, "data", MDB_CREATE, &m_dbi_data); result != MDB_SUCCESS)
132-
LMDB_THROW("Failed to open data DB", result);
133-
if (int result = mdb_dbi_open(txn, "meta", MDB_CREATE, &m_dbi_meta); result != MDB_SUCCESS)
134-
LMDB_THROW("Failed to open meta DB", result);
135-
136-
txn.commit();
137120
}
138121

139122
LMDBCache::~LMDBCache()
140123
{
141-
mdb_dbi_close(m_env, m_dbi_data);
142-
mdb_dbi_close(m_env, m_dbi_meta);
143-
mdb_env_close(m_env);
124+
close_db(m_db);
144125
}
145126

146127
void LMDBCache::set(const void* key_data, size_t key_size, const void* value_data, size_t value_size)
@@ -152,16 +133,16 @@ void LMDBCache::set(const void* key_data, size_t key_size, const void* value_dat
152133
if (usage().used_size > m_eviction_threshold_size)
153134
evict();
154135

155-
ScopedTransaction txn(m_env);
136+
ScopedTransaction txn(m_db.env);
156137

157138
MDB_val mdb_key = {key_size, const_cast<void*>(key_data)};
158139
MDB_val mdb_val = {value_size, const_cast<void*>(value_data)};
159-
if (int result = mdb_put(txn, m_dbi_data, &mdb_key, &mdb_val, 0); result != MDB_SUCCESS)
140+
if (int result = mdb_put(txn, m_db.dbi_data, &mdb_key, &mdb_val, 0); result != MDB_SUCCESS)
160141
LMDB_THROW("Failed to write data", result);
161142

162143
MetaData meta_data{.last_access = get_current_time_ns()};
163144
MDB_val mdb_val_meta = {sizeof(MetaData), &meta_data};
164-
if (int result = mdb_put(txn, m_dbi_meta, &mdb_key, &mdb_val_meta, 0); result != MDB_SUCCESS)
145+
if (int result = mdb_put(txn, m_db.dbi_meta, &mdb_key, &mdb_val_meta, 0); result != MDB_SUCCESS)
165146
LMDB_THROW("Failed to write metadata", result);
166147

167148
txn.commit();
@@ -172,20 +153,20 @@ bool LMDBCache::get(const void* key_data, size_t key_size, WriteValueFunc write_
172153
SGL_CHECK(key_size > 0, "Key size must be greater than 0");
173154
SGL_CHECK(key_size <= m_max_key_size, "Key size exceeds maximum allowed size");
174155

175-
ScopedTransaction txn(m_env);
156+
ScopedTransaction txn(m_db.env);
176157

177158
MDB_val mdb_key = {key_size, const_cast<void*>(key_data)};
178159
MDB_val mdb_val;
179160

180-
int result = mdb_get(txn, m_dbi_data, &mdb_key, &mdb_val);
161+
int result = mdb_get(txn, m_db.dbi_data, &mdb_key, &mdb_val);
181162
if (result == MDB_NOTFOUND)
182163
return false;
183164
if (result != MDB_SUCCESS)
184165
LMDB_THROW("Failed to read data", result);
185166

186167
MetaData meta_data{.last_access = get_current_time_ns()};
187168
MDB_val mdb_val_meta = {sizeof(MetaData), &meta_data};
188-
result = mdb_put(txn, m_dbi_meta, &mdb_key, &mdb_val_meta, 0);
169+
result = mdb_put(txn, m_db.dbi_meta, &mdb_key, &mdb_val_meta, 0);
189170
if (result != MDB_SUCCESS)
190171
LMDB_THROW("Failed to write metadata", result);
191172

@@ -201,17 +182,17 @@ bool LMDBCache::del(const void* key_data, size_t key_size)
201182
SGL_CHECK(key_size > 0, "Key size must be greater than 0");
202183
SGL_CHECK(key_size <= m_max_key_size, "Key size exceeds maximum allowed size");
203184

204-
ScopedTransaction txn(m_env);
185+
ScopedTransaction txn(m_db.env);
205186

206187
MDB_val mdb_key = {key_size, const_cast<void*>(key_data)};
207188

208-
int result = mdb_del(txn, m_dbi_data, &mdb_key, nullptr);
189+
int result = mdb_del(txn, m_db.dbi_data, &mdb_key, nullptr);
209190
if (result == MDB_NOTFOUND)
210191
return false;
211192
if (result != MDB_SUCCESS)
212193
LMDB_THROW("Failed to delete data", result);
213194

214-
result = mdb_del(txn, m_dbi_meta, &mdb_key, nullptr);
195+
result = mdb_del(txn, m_db.dbi_meta, &mdb_key, nullptr);
215196
if (result != MDB_SUCCESS && result != MDB_NOTFOUND)
216197
LMDB_THROW("Failed to delete metadata", result);
217198

@@ -224,11 +205,11 @@ LMDBCache::Usage LMDBCache::usage() const
224205
{
225206
Usage usage;
226207

227-
ScopedTransaction txn(m_env, MDB_RDONLY);
208+
ScopedTransaction txn(m_db.env, MDB_RDONLY);
228209

229210
uint64_t used_pages = 0;
230211
uint64_t page_size = 0;
231-
for (MDB_dbi dbi : {MDB_dbi(0) /* FREE_DBI */, MDB_dbi(1) /* MAIN_DBI */, m_dbi_data, m_dbi_meta}) {
212+
for (MDB_dbi dbi : {MDB_dbi(0) /* FREE_DBI */, MDB_dbi(1) /* MAIN_DBI */, m_db.dbi_data, m_db.dbi_meta}) {
232213
MDB_stat stat = {};
233214
if (int result = mdb_stat(txn, dbi, &stat); result != MDB_SUCCESS)
234215
LMDB_THROW("Failed to get DB stats", result);
@@ -237,7 +218,7 @@ LMDBCache::Usage LMDBCache::usage() const
237218
}
238219

239220
MDB_envinfo info = {};
240-
if (int result = mdb_env_info(m_env, &info); result != MDB_SUCCESS)
221+
if (int result = mdb_env_info(m_db.env, &info); result != MDB_SUCCESS)
241222
LMDB_THROW("Failed to get environment info", result);
242223

243224
usage.reserved_size = info.me_mapsize;
@@ -253,8 +234,8 @@ LMDBCache::Stats LMDBCache::stats() const
253234

254235
stats.evictions = m_evictions.load();
255236

256-
ScopedTransaction txn(m_env, MDB_RDONLY);
257-
ScopedCursor cursor(txn, m_dbi_data);
237+
ScopedTransaction txn(m_db.env, MDB_RDONLY);
238+
ScopedCursor cursor(txn, m_db.dbi_data);
258239

259240
MDB_val key, val;
260241
while (mdb_cursor_get(cursor, &key, &val, MDB_NEXT) == MDB_SUCCESS) {
@@ -267,8 +248,6 @@ LMDBCache::Stats LMDBCache::stats() const
267248

268249
void LMDBCache::evict()
269250
{
270-
SGL_ASSERT(m_env != nullptr);
271-
272251
struct Entry {
273252
uint64_t last_access;
274253
MDB_val key;
@@ -281,8 +260,8 @@ void LMDBCache::evict()
281260
return;
282261
size_t required_free_size = used_size - m_eviction_target_size;
283262

284-
ScopedTransaction txn(m_env);
285-
ScopedCursor cursor(txn, m_dbi_meta);
263+
ScopedTransaction txn(m_db.env);
264+
ScopedCursor cursor(txn, m_db.dbi_meta);
286265

287266
// Scan all entries.
288267
MDB_val key, val;
@@ -302,12 +281,12 @@ void LMDBCache::evict()
302281
while (required_free_size > 0 && !entries.empty()) {
303282
std::pop_heap(entries.begin(), entries.end(), cmp);
304283
Entry& entry = entries.back();
305-
if (int result = mdb_get(txn, m_dbi_data, &entry.key, &val); result != MDB_SUCCESS)
284+
if (int result = mdb_get(txn, m_db.dbi_data, &entry.key, &val); result != MDB_SUCCESS)
306285
LMDB_THROW("Failed to get data during eviction", result);
307286
required_free_size -= std::min(required_free_size, val.mv_size);
308-
if (int result = mdb_del(txn, m_dbi_data, &entry.key, nullptr); result != MDB_SUCCESS)
287+
if (int result = mdb_del(txn, m_db.dbi_data, &entry.key, nullptr); result != MDB_SUCCESS)
309288
LMDB_THROW("Failed to delete data during eviction", result);
310-
if (int result = mdb_del(txn, m_dbi_meta, &entry.key, nullptr); result != MDB_SUCCESS)
289+
if (int result = mdb_del(txn, m_db.dbi_meta, &entry.key, nullptr); result != MDB_SUCCESS)
311290
LMDB_THROW("Failed to delete metadata during eviction", result);
312291
entries.pop_back();
313292
evictions++;
@@ -321,4 +300,87 @@ void LMDBCache::evict()
321300
m_evictions.fetch_add(evictions);
322301
}
323302

303+
// LMDB doesn't support opening the same DB environment multiple times in the same process.
304+
// To work around this, we keep a global list of open environments to reuse them if opened multiple times.
305+
306+
struct DBCacheItem {
307+
uint64_t ref_count;
308+
ProcessID pid;
309+
std::filesystem::path path;
310+
LMDBCache::DB db;
311+
};
312+
313+
std::vector<DBCacheItem> s_db_cache;
314+
std::mutex s_db_cache_mutex;
315+
316+
/// Open the LMDB environment at \p path, or reuse it if this process already has it open.
///
/// Environments are tracked in a process-wide list keyed by (process ID, absolute path).
/// If a matching entry exists, its reference count is incremented and the existing
/// handles are returned; in that case \p options is NOT applied - the environment keeps
/// the settings it was first opened with.
///
/// Otherwise a new environment is created and configured from \p options, the "data"
/// and "meta" databases are opened (created if missing), and the result is registered
/// with a reference count of one.
///
/// Every successful call must be balanced by a call to close_db().
///
/// \param path Directory of the LMDB environment.
/// \param options Map size and sync settings used when the environment is first opened.
/// \return Environment and database handles.
/// \throws If any LMDB call fails (reported through LMDB_THROW). The partially set up
///     environment is closed before the error propagates.
LMDBCache::DB LMDBCache::open_db(const std::filesystem::path& path, const Options& options)
{
    ProcessID pid = platform::current_process_id();
    std::filesystem::path abs_path = std::filesystem::absolute(path);
    std::lock_guard lock(s_db_cache_mutex);
    auto it = std::find_if(
        s_db_cache.begin(),
        s_db_cache.end(),
        [pid, &abs_path](const DBCacheItem& e) { return e.pid == pid && e.path == abs_path; }
    );
    if (it != s_db_cache.end()) {
        it->ref_count++;
        return it->db;
    }

    DB db = {};

    if (int result = mdb_env_create(&db.env); result != MDB_SUCCESS)
        LMDB_THROW("Failed to create environment", result);

    // From here on the environment handle exists and must be released if setup fails.
    // LMDB requires mdb_env_close() even when mdb_env_open() itself fails.
    // The guard is declared before the transaction below so that, during unwinding,
    // the transaction is aborted first and the environment is closed afterwards.
    struct EnvGuard {
        MDB_env* env;
        ~EnvGuard()
        {
            if (env)
                mdb_env_close(env);
        }
    } env_guard{db.env};

    if (int result = mdb_env_set_maxreaders(db.env, 126); result != MDB_SUCCESS)
        LMDB_THROW("Failed to set max readers", result);
    if (int result = mdb_env_set_maxdbs(db.env, 2); result != MDB_SUCCESS)
        LMDB_THROW("Failed to set max DBs", result);
    if (int result = mdb_env_set_mapsize(db.env, options.max_size); result != MDB_SUCCESS)
        LMDB_THROW("Failed to set map size", result);

    int flags = options.nosync ? MDB_NOSYNC : 0;
    if (int result = mdb_env_open(db.env, abs_path.string().c_str(), flags, 0664); result != MDB_SUCCESS)
        LMDB_THROW("Failed to open environment", result);

    ScopedTransaction txn(db.env);

    if (int result = mdb_dbi_open(txn, "data", MDB_CREATE, &db.dbi_data); result != MDB_SUCCESS)
        LMDB_THROW("Failed to open data DB", result);
    if (int result = mdb_dbi_open(txn, "meta", MDB_CREATE, &db.dbi_meta); result != MDB_SUCCESS)
        LMDB_THROW("Failed to open meta DB", result);

    txn.commit();

    s_db_cache.push_back(
        DBCacheItem{
            .ref_count = 1,
            .pid = pid,
            .path = abs_path,
            .db = db,
        }
    );

    // The environment is now owned by the cache entry; close_db() releases it.
    env_guard.env = nullptr;

    return db;
}
366+
367+
/// Release one reference to an environment previously returned by open_db().
///
/// The matching cache entry is looked up by environment handle. When the last
/// reference is dropped, both database handles and the environment are closed
/// and the entry is removed from the process-wide list.
///
/// \param db Handles returned by open_db().
void LMDBCache::close_db(DB db)
{
    ProcessID pid = platform::current_process_id();
    std::lock_guard lock(s_db_cache_mutex);

    // Locate the entry that owns this environment handle.
    auto it = s_db_cache.begin();
    while (it != s_db_cache.end() && it->db.env != db.env)
        ++it;
    SGL_ASSERT(it != s_db_cache.end());
    SGL_ASSERT(it->pid == pid);

    // Other users still hold the environment: only drop our reference.
    it->ref_count -= 1;
    if (it->ref_count != 0)
        return;

    // Last reference gone: close the database handles first, then the environment.
    mdb_dbi_close(db.env, db.dbi_data);
    mdb_dbi_close(db.env, db.dbi_meta);
    mdb_env_close(db.env);
    s_db_cache.erase(it);
}
385+
324386
} // namespace sgl

0 commit comments

Comments
 (0)