diff --git a/CMakeLists.txt b/CMakeLists.txt index 242129d..0de9154 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,14 +8,15 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) include_directories(include third_party) set(PA_OVERRIDE ON) +set(PA_BUILD_TESTS OFF CACHE BOOL "Disable palloc tests for embedded pomai_cache" FORCE) add_subdirectory(third_party/palloc) -add_library(pomai_cache_core +add_library(pomaicache SHARED src/engine/engine.cpp src/engine/ssd_store.cpp src/policy/policies.cpp - src/server/http.cpp src/server/ai_cache.cpp + src/pomai_embedded.cc src/metrics/info_metrics.cpp src/util/time.cpp src/ds/bloom_filter.cpp @@ -23,54 +24,56 @@ add_library(pomai_cache_core src/ds/vector_index.cpp src/ds/dep_graph.cpp src/ds/compression.cpp + src/engine/prompt_cache.cpp + src/bindings/c_api.cc ) -target_link_libraries(pomai_cache_core PUBLIC palloc) - -if(NOT WIN32) - add_executable(pomai_cache_server src/server/server_main.cpp) - target_link_libraries(pomai_cache_server PRIVATE pomai_cache_core uring) - - add_executable(pomai_cache_cli apps/cli/main.cpp) - - add_executable(pomai_cache_netbench bench/pomai_cache_netbench.cpp) - target_link_libraries(pomai_cache_netbench PRIVATE pomai_cache_core) -endif() - -add_executable(pomai_cache_bench bench/pomai_cache_bench.cpp) -target_link_libraries(pomai_cache_bench PRIVATE pomai_cache_core) - -add_executable(pomai_cache_ai_bench bench/ai_artifact_bench.cpp) -target_link_libraries(pomai_cache_ai_bench PRIVATE pomai_cache_core) - -add_executable(pomai_cache_vector_bench bench/vector_cache_bench.cpp) -target_link_libraries(pomai_cache_vector_bench PRIVATE pomai_cache_core) +target_link_libraries(pomaicache PUBLIC palloc) include(CTest) if(BUILD_TESTING) add_library(mini_catch_main third_party/catch2/catch_main.cpp) add_executable(test_engine tests/test_engine.cpp) - target_link_libraries(test_engine PRIVATE pomai_cache_core mini_catch_main) - - add_executable(test_http tests/test_http.cpp) - target_link_libraries(test_http PRIVATE pomai_cache_core mini_catch_main) + target_link_libraries(test_engine PRIVATE pomaicache mini_catch_main) add_executable(test_ai_cache tests/test_ai_cache.cpp) - target_link_libraries(test_ai_cache PRIVATE pomai_cache_core mini_catch_main) + target_link_libraries(test_ai_cache PRIVATE pomaicache mini_catch_main) add_test(NAME test_engine COMMAND test_engine) - add_test(NAME test_http COMMAND test_http) add_test(NAME test_ai_cache COMMAND test_ai_cache) if(NOT WIN32) add_executable(test_integration tests/test_integration.cpp) - target_link_libraries(test_integration PRIVATE pomai_cache_core mini_catch_main) + target_link_libraries(test_integration PRIVATE pomaicache mini_catch_main) add_test(NAME test_integration COMMAND test_integration) endif() endif() if(NOT WIN32) add_executable(pomai_cache_crash_harness tests/crash_harness.cpp) - target_link_libraries(pomai_cache_crash_harness PRIVATE pomai_cache_core) + target_link_libraries(pomai_cache_crash_harness PRIVATE pomaicache) + + add_executable(pomai_cache_bench bench/pomai_cache_bench.cpp) + target_link_libraries(pomai_cache_bench PRIVATE pomaicache) + + add_executable(ai_artifact_bench bench/ai_artifact_bench.cpp) + target_link_libraries(ai_artifact_bench PRIVATE pomaicache) + + add_executable(vector_cache_bench bench/vector_cache_bench.cpp) + target_link_libraries(vector_cache_bench PRIVATE pomaicache) + + add_executable(prompt_cache_bench bench/prompt_cache_bench.cpp) + target_link_libraries(prompt_cache_bench PRIVATE pomaicache) +endif() + +option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON) + +if(BUILD_PYTHON_BINDINGS) + find_package(pybind11 CONFIG REQUIRED) + add_library(pomaicache_python MODULE src/bindings/python_bindings.cc) + target_link_libraries(pomaicache_python PRIVATE pomaicache pybind11::module) + set_target_properties(pomaicache_python PROPERTIES + OUTPUT_NAME "pomaicache" + ) endif() diff --git a/Makefile b/Makefile index 3aac6c4..d092e09 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,6 @@ BUILD_DIR ?= build dev: cmake -S . -B $(BUILD_DIR) -DCMAKE_BUILD_TYPE=Debug cmake --build $(BUILD_DIR) -j - ./$(BUILD_DIR)/pomai_cache_server --port 6379 --policy pomai_cost --params config/policy_params.json release: cmake -S . -B $(BUILD_DIR)-release -DCMAKE_BUILD_TYPE=Release @@ -25,11 +24,9 @@ bench: cmake -S . -B $(BUILD_DIR)-release -DCMAKE_BUILD_TYPE=Release cmake --build $(BUILD_DIR)-release -j ./$(BUILD_DIR)-release/pomai_cache_bench - -netbench: - cmake -S . -B $(BUILD_DIR)-release -DCMAKE_BUILD_TYPE=Release - cmake --build $(BUILD_DIR)-release -j - ./$(BUILD_DIR)-release/pomai_cache_netbench + ./$(BUILD_DIR)-release/ai_artifact_bench + ./$(BUILD_DIR)-release/vector_cache_bench + ./$(BUILD_DIR)-release/prompt_cache_bench bench-all: ./scripts/bench_run.sh diff --git a/README.md b/README.md index adc7071..650a374 100644 --- a/README.md +++ b/README.md @@ -2,87 +2,83 @@ -Redis-compatible (subset) local cache core with RAM+SSD tiering, bounded TTL cleanup, crash-safe append-only SSD segments, selectable eviction policy (`lru`, `lfu`, `pomai_cost`), and an AI artifact cache layer for embeddings/prompts/RAG/rerank/response reuse. +Embedded local cache core with RAM+SSD tiering, bounded TTL cleanup, crash-safe append-only SSD segments, selectable eviction policy (`lru`, `lfu`, `pomai_cost`), and an AI artifact cache layer for embeddings/prompts/RAG/rerank/response reuse. ## Repo structure -- `src/server/` RESP parser + connection loop - `src/engine/` KV store, TTL heap, memory limit enforcement - `src/policy/` LRU, LFU, PomaiCostPolicy - `src/metrics/` INFO metrics module -- `apps/cli/` simple CLI helper -- `bench/` benchmark tool +- `bindings/` C and Python embedded bindings +- `bench/` embedded benchmarks - `tests/` correctness tests - `tuner/` offline python tuner -- `docker/` container artifacts -## Quickstart +## Quickstart (embedded library) -### Build + run locally +### Build library ```bash cmake -S . -B build -DCMAKE_BUILD_TYPE=Debug cmake --build build -j -./build/pomai_cache_server --port 6379 --policy pomai_cost --params config/policy_params.json --ssd-enabled --data-dir ./data --ssd-value-min-bytes 2048 --fsync everysec ``` -or: +This produces the shared library `libpomaicache` and optional Python module. -```bash -make dev -``` +### C++ usage -### Run with Docker +```cpp +#include -```bash -make docker-build -make docker-run -``` +int main() { + pomaicache::Config cfg; + cfg.memory_limit_bytes = 128 * 1024 * 1024; + cfg.data_dir = "./data"; -### Example redis-cli session + pomaicache::PomaiCache cache(cfg); + const std::string key = "demo"; + const std::string value = "hello"; -```bash -redis-cli -p 6379 SET demo hello EX 30 -redis-cli -p 6379 GET demo -redis-cli -p 6379 INFO -redis-cli -p 6379 CONFIG GET POLICY -redis-cli -p 6379 CONFIG SET POLICY lru -redis-cli -p 6379 CONFIG SET PARAMS /app/config/policy_params.json + cache.Set(key, std::as_bytes(std::span(value.data(), value.size())), + pomaicache::Ttl{300000}); + + auto got = cache.Get(key); + if (got) { + // use *got + } +} ``` +### C API usage + +```c +#include + +int main() { + pomai_config_t cfg = { .memory_limit_bytes = 128 * 1024 * 1024, + .data_dir = "./data" }; + pomai_t* db = pomai_create(&cfg); + const char* key = "demo"; + const char* val = "hello"; + pomai_set(db, key, strlen(key), val, strlen(val), 300000); + void* out = NULL; + size_t out_len = 0; + if (pomai_get(db, key, strlen(key), &out, &out_len) == 0) { + // use out / out_len + pomai_free(out); + } + pomai_destroy(db); +} +``` -### AI artifact quickstart (redis-cli) +### Python usage -```bash -redis-cli -p 6379 AI.PUT embedding emb:modelA:hashA:768:float16 '{"artifact_type":"embedding","owner":"vector","schema_version":"v1","model_id":"modelA","snapshot_epoch":"ix1"}' "abc" -redis-cli -p 6379 AI.GET emb:modelA:hashA:768:float16 -redis-cli -p 6379 AI.STATS -redis-cli -p 6379 AI.INVALIDATE EPOCH ix1 -``` +```python +import pomaicache -## Make targets - -- `make dev` debug build + run server -- `make release` release build -- `make test` tests -- `make bench` benchmarks -- `make crash-suite` short crash/recovery harness -- `make fmt` clang-format -- `make docker-build` -- `make docker-run` - -## Supported commands - -- `GET` -- `SET key value [EX seconds] [OWNER owner_name]` -- `DEL key [key ...]` -- `EXPIRE key seconds` -- `TTL key` -- `MGET key [key ...]` -- `INFO` -- `CONFIG GET POLICY` -- `CONFIG SET POLICY ` -- `CONFIG SET PARAMS ` +cache = pomaicache.Cache(data_dir="./data", memory_limit_bytes=128*1024*1024) +# prompt_put / prompt_get APIs are available for prompt prefix caching +``` ## Policy tuning @@ -92,35 +88,16 @@ Generate params from offline stats snapshot: python3 tuner/tune_policy.py --input stats_snapshot.json --output config/policy_params.json ``` -The server loads params on startup and can reload at runtime via: - -```bash -redis-cli -p 6379 CONFIG SET PARAMS /app/config/policy_params.json -``` - -Invalid/missing param files are handled safely with existing/default values. +Your application is responsible for loading updated params and calling the appropriate reload functions in the embedded API. ## Benchmarks -Run: - -```bash -make bench -``` - -Bench reports per workload and policy: - -- ops/s -- p50/p95/p99 latency -- hit rate -- memory used +Benchmarks under `bench/` exercise the embedded library in-process (no network). ## Security/stability constraints - max key length enforced - max value size enforced -- max concurrent connections enforced -- slow-client protection via bounded output buffer - bounded per-tick TTL cleanup ## SSD tier defaults (laptop-safe) diff --git a/apps/cli/main.cpp b/apps/cli/main.cpp deleted file mode 100644 index f5820cb..0000000 --- a/apps/cli/main.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include -#include -#include -#include - -int main(int argc, char **argv) { - std::string host = "127.0.0.1"; - int port = 6379; - if (argc > 1) - port = std::stoi(argv[1]); - - int fd = socket(AF_INET, SOCK_STREAM, 0); - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - inet_pton(AF_INET, host.c_str(), &addr.sin_addr); - if (connect(fd, reinterpret_cast(&addr), sizeof(addr)) != 0) { - std::cerr << "connect failed\n"; - return 1; - } - std::string line; - while (std::getline(std::cin, line)) { - if (line == "quit") - break; - std::string payload = - "*1\r\n$" + std::to_string(line.size()) + "\r\n" + line + "\r\n"; - send(fd, payload.data(), payload.size(), 0); - char buf[4096]; - auto n = recv(fd, buf, sizeof(buf), 0); - if (n <= 0) - break; - std::cout.write(buf, n); - std::cout << std::endl; - } - close(fd); - return 0; -} diff --git a/bench/ai_artifact_bench.cpp b/bench/ai_artifact_bench.cpp index 0b79264..4c31da4 100644 --- a/bench/ai_artifact_bench.cpp +++ b/bench/ai_artifact_bench.cpp @@ -62,6 +62,8 @@ int main(int argc, char **argv) { if (argc > 1) out = argv[1]; + std::cout << "Embedded AI artifact cache benchmark (in-process, no network)\n"; + Engine e({64 * 1024 * 1024, 256, 4 * 1024 * 1024}, make_policy_by_name("pomai_cost")); AiArtifactCache ai(e); @@ -207,6 +209,19 @@ int main(int argc, char **argv) { std::chrono::steady_clock::now() - r0) .count(); + // Human-readable summary. + std::cout << std::fixed << std::setprecision(2); + std::cout << "\nworkload ops/s p50_us p99_us hit_rate\n"; + std::cout << "------------------------------------------------------------------\n"; + for (const auto &r : results) { + std::cout << std::left << std::setw(22) << r.name << std::right + << std::setw(10) << r.ops_s + << std::setw(12) << r.p50 + << std::setw(12) << r.p99 + << std::setw(11) << r.hit_rate << "\n"; + } + std::cout << "\nEngine constructor time (ms): " << warm_ms << "\n"; + std::ofstream os(out); os << "{\n \"workloads\": [\n"; for (std::size_t i = 0; i < results.size(); ++i) { @@ -221,7 +236,7 @@ int main(int argc, char **argv) { } os << " ],\n"; os << " \"ssd_mb_s\": 0.0,\n"; - os << " \"warm_restart_ms\": " << warm_ms << ",\n"; + os << " \"engine_ctor_ms\": " << warm_ms << ",\n"; os << " \"dedup_ratio\": 0.0\n"; os << "}\n"; diff --git a/bench/pomai_cache_bench.cpp b/bench/pomai_cache_bench.cpp index d1c0638..36fd6d0 100644 --- a/bench/pomai_cache_bench.cpp +++ b/bench/pomai_cache_bench.cpp @@ -1,23 +1,47 @@ #include "pomai_cache/engine.hpp" +#include "pomaicache.h" #include #include #include #include #include +#include +#include using namespace pomai_cache; +struct LatencyStats { + double p50_us{0}; + double p95_us{0}; + double p99_us{0}; +}; + +static LatencyStats compute_stats(std::vector &samples) { + if (samples.empty()) + return {}; + std::sort(samples.begin(), samples.end()); + auto at = [&](double p) { + return samples[static_cast(p * (samples.size() - 1))]; + }; + return {at(0.50), at(0.95), at(0.99)}; +} + int main() { const std::vector policies = {"lru", "lfu", "pomai_cost"}; const std::vector presets = {"hotset", "uniform", "writeheavy", "mixed"}; constexpr std::uint64_t seed = 424242; + + std::cout << "Embedded cache benchmark (in-process, no network)\n"; std::cout << "seed=" << seed << "\n"; - std::cout << "|workload|policy|ops/s|hit_rate|evictions|\n"; + std::cout << "\n|workload|policy|ops/s|hit_rate|evictions|\n"; std::cout << "|---|---:|---:|---:|---:|\n"; + std::cout << std::fixed << std::setprecision(2); + + // Engine-only throughput and hit rate. for (const auto &preset : presets) { for (const auto &pname : policies) { Engine engine({8 * 1024 * 1024, 256, 4 * 1024, 256}, @@ -49,10 +73,98 @@ int main() { const double hit_rate = gets > 0 ? static_cast(hits) / static_cast(gets) : 0.0; - std::cout << "|" << preset << "|" << pname << "|" << std::fixed - << std::setprecision(2) << (ops / seconds) << "|" << hit_rate - << "|" << engine.stats().evictions << "|\n"; + std::cout << "|" << preset << "|" << pname << "|" << (ops / seconds) + << "|" << hit_rate << "|" << engine.stats().evictions << "|\n"; + } + } + + // Engine get-latency percentiles for pomai_cost. + std::cout + << "\nEngine GET latency (pomai_cost, microseconds, p50/p95/p99):\n"; + for (const auto &preset : presets) { + Engine engine({8 * 1024 * 1024, 256, 4 * 1024, 256}, + make_policy_by_name("pomai_cost")); + std::mt19937_64 rng(seed); + std::uniform_int_distribution u(0, 999); + const int ops = 30000; + std::vector get_lat; + int gets = 0; + int hits = 0; + + for (int i = 0; i < ops; ++i) { + int k = u(rng); + if (preset == "hotset") + k = static_cast(std::pow((u(rng) % 100) + 1, 1.4)); + std::string key = "k" + std::to_string(k % 1000); + const bool do_write = + preset == "writeheavy" ? (i % 2 == 0) : (i % 5 == 0); + if (do_write) { + std::vector v(64, static_cast(i % 255)); + engine.set(key, v, std::nullopt, "default"); + } else { + ++gets; + auto t0 = std::chrono::steady_clock::now(); + if (engine.get(key).has_value()) + ++hits; + auto t1 = std::chrono::steady_clock::now(); + get_lat.push_back( + std::chrono::duration(t1 - t0).count()); + } + } + auto stats = compute_stats(get_lat); + double hit_rate = + gets > 0 ? static_cast(hits) / static_cast(gets) : 0.0; + std::cout << " " << std::setw(10) << preset << " p50=" << stats.p50_us + << " p95=" << stats.p95_us << " p99=" << stats.p99_us + << " hit_rate=" << hit_rate << "\n"; + } + + // PomaiCache (embedded API) benchmark with same workloads. + std::cout << "\nPomaiCache embedded API (Set/Get) latency, pomai_cost:\n"; + for (const auto &preset : presets) { + pomaicache::Config cfg; + cfg.memory_limit_bytes = 8 * 1024 * 1024; + cfg.data_dir = "./data_pomai_embedded"; + pomaicache::PomaiCache cache(cfg); + + std::mt19937_64 rng(seed); + std::uniform_int_distribution u(0, 999); + const int ops = 30000; + std::vector get_lat; + int gets = 0; + int hits = 0; + + for (int i = 0; i < ops; ++i) { + int k = u(rng); + if (preset == "hotset") + k = static_cast(std::pow((u(rng) % 100) + 1, 1.4)); + std::string key = "k" + std::to_string(k % 1000); + const bool do_write = + preset == "writeheavy" ? (i % 2 == 0) : (i % 5 == 0); + if (do_write) { + std::vector v(64, static_cast(i % 255)); + std::span val( + reinterpret_cast(v.data()), v.size()); + cache.Set(key, val, pomaicache::Ttl{0}); + } else { + ++gets; + auto t0 = std::chrono::steady_clock::now(); + auto got = cache.Get(key); + if (got.has_value()) + ++hits; + auto t1 = std::chrono::steady_clock::now(); + get_lat.push_back( + std::chrono::duration(t1 - t0).count()); + } } + + auto stats = compute_stats(get_lat); + double hit_rate = + gets > 0 ? static_cast(hits) / static_cast(gets) : 0.0; + std::cout << " " << std::setw(10) << preset << " p50=" << stats.p50_us + << " p95=" << stats.p95_us << " p99=" << stats.p99_us + << " hit_rate=" << hit_rate << "\n"; } + return 0; } diff --git a/bench/pomai_cache_netbench.cpp b/bench/pomai_cache_netbench.cpp deleted file mode 100644 index d53cfef..0000000 --- a/bench/pomai_cache_netbench.cpp +++ /dev/null @@ -1,333 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -struct Options { - std::string host{"127.0.0.1"}; - int port{6379}; - std::string workload{"mixed"}; - int threads{4}; - int clients{16}; - int pipeline{1}; - int duration_s{10}; - int warmup_s{2}; - int key_size{16}; - int value_size{128}; - int keyspace{10000}; - std::uint64_t seed{1337}; - std::string json_out{"netbench_summary.json"}; -}; - -struct SharedStats { - std::mutex mu; - std::vector latencies_us; - std::uint64_t ops{0}; - std::uint64_t get_ops{0}; - std::uint64_t get_hits{0}; - std::uint64_t set_ops{0}; -}; - -std::string make_cmd(const std::vector &args) { - if (args.empty()) return ""; - std::string cmd = args[0]; - if (cmd == "GET") { - return "GET /key/" + args[1] + " HTTP/1.1\r\n\r\n"; - } else if (cmd == "SET") { - std::string req = "POST /key/" + args[1]; - if (args.size() > 3 && args[3] == "PX") { - req += "?px=" + args[4]; - } - req += " HTTP/1.1\r\nContent-Length: " + std::to_string(args[2].size()) + "\r\n\r\n" + args[2]; - return req; - } else if (cmd == "INFO") { - return "GET /info HTTP/1.1\r\n\r\n"; - } - return ""; -} - -std::optional read_reply(int fd) { - std::string out; - char buf[4096]; - while (true) { - int r = recv(fd, buf, 4096, 0); - if (r <= 0) return std::nullopt; - out.append(buf, r); - if (out.find("\r\n\r\n") != std::string::npos) { - auto pos = out.find("Content-Length: "); - if (pos != std::string::npos) { - auto end = out.find("\r\n", pos); - int len = std::stoi(out.substr(pos + 16, end - pos - 16)); - auto header_end = out.find("\r\n\r\n") + 4; - if (out.size() >= header_end + len) { - return out; - } - } else { - return out; - } - } - } -} - -std::string fixed_key(int k, int key_size) { - std::string s = "k" + std::to_string(k); - if (static_cast(s.size()) < key_size) - s += std::string( - static_cast(key_size - static_cast(s.size())), 'x'); - return s; -} - -void parse_info(const std::string &info, std::uint64_t &memory_used, - std::uint64_t &evictions, std::uint64_t &admissions, - std::uint64_t &ram_hits, std::uint64_t &ssd_hits, - double &ssd_read_mb, double &ssd_write_mb, - std::uint64_t &ssd_bytes, double &fragmentation, - std::uint64_t &index_rebuild_ms) { - auto value_of = [&](const std::string &k) { - auto p = info.find(k + ":"); - if (p == std::string::npos) - return std::uint64_t{0}; - auto e = info.find('\n', p); - return static_cast( - std::stoull(info.substr(p + k.size() + 1, e - p - k.size() - 1))); - }; - memory_used = value_of("memory_used_bytes"); - evictions = value_of("evictions"); - admissions = value_of("admissions_rejected"); - ram_hits = value_of("hits"); - ssd_hits = value_of("ssd_hits"); - ssd_bytes = value_of("ssd_bytes"); - index_rebuild_ms = value_of("ssd_index_rebuild_ms"); - auto value_of_d = [&](const std::string &k) { - auto p = info.find(k + ":"); - if (p == std::string::npos) - return 0.0; - auto e = info.find('\n', p); - return std::stod(info.substr(p + k.size() + 1, e - p - k.size() - 1)); - }; - ssd_read_mb = value_of_d("ssd_read_mb"); - ssd_write_mb = value_of_d("ssd_write_mb"); - fragmentation = value_of_d("fragmentation_estimate"); -} - -int connect_server(const Options &opt) { - int fd = socket(AF_INET, SOCK_STREAM, 0); - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(opt.port); - inet_pton(AF_INET, opt.host.c_str(), &addr.sin_addr); - if (connect(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) - return -1; - return fd; -} - -int main(int argc, char **argv) { - Options opt; - for (int i = 1; i < argc; ++i) { - std::string a = argv[i]; - auto take = [&](int &v) { - if (i + 1 < argc) - v = std::stoi(argv[++i]); - }; - if (a == "--port") - take(opt.port); - else if (a == "--threads") - take(opt.threads); - else if (a == "--clients") - take(opt.clients); - else if (a == "--pipeline") - take(opt.pipeline); - else if (a == "--duration") - take(opt.duration_s); - else if (a == "--warmup") - take(opt.warmup_s); - else if (a == "--key-size") - take(opt.key_size); - else if (a == "--value-size") - take(opt.value_size); - else if (a == "--keyspace") - take(opt.keyspace); - else if (a == "--workload" && i + 1 < argc) - opt.workload = argv[++i]; - else if (a == "--json" && i + 1 < argc) - opt.json_out = argv[++i]; - } - - SharedStats shared; - std::atomic running{true}; - auto end_time = std::chrono::steady_clock::now() + - std::chrono::seconds(opt.duration_s + opt.warmup_s); - auto warmup_end = - std::chrono::steady_clock::now() + std::chrono::seconds(opt.warmup_s); - - std::vector workers; - for (int t = 0; t < opt.clients; ++t) { - workers.emplace_back([&, t] { - int fd = connect_server(opt); - if (fd < 0) - return; - std::mt19937_64 rng(opt.seed + static_cast(t)); - std::uniform_int_distribution uniform(0, opt.keyspace - 1); - std::uniform_real_distribution real(0.0, 1.0); - int value_size = opt.value_size; - if (opt.workload == "tier_on_large_values") - value_size = std::max(value_size, 64 * 1024); - std::string value(static_cast(value_size), 'v'); - - while (std::chrono::steady_clock::now() < end_time) { - std::vector batch; - std::vector expect_get; - batch.reserve(static_cast(opt.pipeline)); - for (int i = 0; i < opt.pipeline; ++i) { - int k = uniform(rng); - if (opt.workload == "hotset" || - opt.workload == "tier_on_large_values") { - const double x = std::pow(real(rng), 2.0); - k = static_cast(x * std::max(1, opt.keyspace / 10)); - } - bool do_set = false; - if (opt.workload == "writeheavy" || - opt.workload == "tier_on_pressure_demotion") - do_set = (real(rng) < 0.8); - else if (opt.workload == "mixed" || - opt.workload == "tier_off_ram_only") - do_set = (real(rng) < 0.35); - else if (opt.workload == "ttlheavy" || - opt.workload == "ttl_storm_with_tier") - do_set = true; - else if (opt.workload == "pipeline") - do_set = (i % 2 == 0); - std::string key = fixed_key(k, opt.key_size); - if (do_set) { - if (opt.workload == "ttlheavy") - batch.push_back(make_cmd({"SET", key, value, "PX", "200"})); - else - batch.push_back(make_cmd({"SET", key, value})); - expect_get.push_back(false); - } else { - batch.push_back(make_cmd({"GET", key})); - expect_get.push_back(true); - } - } - - auto t0 = std::chrono::steady_clock::now(); - for (const auto &cmd : batch) - send(fd, cmd.data(), cmd.size(), 0); - for (std::size_t i = 0; i < batch.size(); ++i) { - auto rep = read_reply(fd); - if (!rep) { - close(fd); - return; - } - auto t1 = std::chrono::steady_clock::now(); - if (std::chrono::steady_clock::now() >= warmup_end) { - std::lock_guard lk(shared.mu); - shared.latencies_us.push_back( - std::chrono::duration(t1 - t0).count()); - ++shared.ops; - if (expect_get[i]) { - ++shared.get_ops; - if (rep->find("200 OK") != std::string::npos) - ++shared.get_hits; - } else { - ++shared.set_ops; - } - } - } - } - close(fd); - }); - } - - for (auto &th : workers) - th.join(); - - int infofd = connect_server(opt); - std::uint64_t mem = 0, evictions = 0, admissions = 0, ram_hits = 0, - ssd_hits = 0, ssd_bytes = 0, index_rebuild_ms = 0; - double ssd_read_mb = 0.0, ssd_write_mb = 0.0, fragmentation = 0.0; - if (infofd >= 0) { - auto cmd = make_cmd({"INFO"}); - send(infofd, cmd.data(), cmd.size(), 0); - auto rep = read_reply(infofd); - if (rep) { - auto header_end = rep->find("\r\n\r\n"); - if (header_end != std::string::npos) { - std::string body = rep->substr(header_end + 4); - parse_info(body, mem, evictions, admissions, ram_hits, ssd_hits, - ssd_read_mb, ssd_write_mb, ssd_bytes, fragmentation, - index_rebuild_ms); - } else { - parse_info(*rep, mem, evictions, admissions, ram_hits, ssd_hits, - ssd_read_mb, ssd_write_mb, ssd_bytes, fragmentation, - index_rebuild_ms); - } - } - close(infofd); - } - - std::sort(shared.latencies_us.begin(), shared.latencies_us.end()); - auto pct = [&](double p) { - if (shared.latencies_us.empty()) - return 0.0; - return shared.latencies_us[static_cast( - p * (shared.latencies_us.size() - 1))]; - }; - const double run_secs = static_cast(opt.duration_s); - const double ops_s = - run_secs > 0 ? static_cast(shared.ops) / run_secs : 0.0; - const double hit_rate = shared.get_ops > 0 - ? static_cast(shared.get_hits) / - static_cast(shared.get_ops) - : 0.0; - - std::cout << std::fixed << std::setprecision(2) << "ops/s=" << ops_s - << " p50_us=" << pct(0.50) << " p95_us=" << pct(0.95) - << " p99_us=" << pct(0.99) << " p999_us=" << pct(0.999) - << " hit_rate=" << hit_rate << " ram_hits=" << ram_hits - << " ssd_hits=" << ssd_hits << " ssd_bytes=" << ssd_bytes - << " memory_used=" << mem << " evictions=" << evictions - << " admissions_rejected=" << admissions << "\n"; - - std::ofstream out(opt.json_out); - out << "{\n" - << " \"workload\": \"" << opt.workload << "\",\n" - << " \"ops_per_sec\": " << ops_s << ",\n" - << " \"p50_us\": " << pct(0.50) << ",\n" - << " \"p95_us\": " << pct(0.95) << ",\n" - << " \"p99_us\": " << pct(0.99) << ",\n" - << " \"p999_us\": " << pct(0.999) << ",\n" - << " \"hit_rate\": " << hit_rate << ",\n" - << " \"ram_hits\": " << ram_hits << ",\n" - << " \"ssd_hits\": " << ssd_hits << ",\n" - << " \"ssd_bytes\": " << ssd_bytes << ",\n" - << " \"ssd_read_mb\": " << ssd_read_mb << ",\n" - << " \"ssd_write_mb\": " << ssd_write_mb << ",\n" - << " \"ssd_index_rebuild_ms\": " << index_rebuild_ms << ",\n" - << " \"fragmentation_estimate\": " << fragmentation << ",\n" - << " \"memory_used_bytes\": " << mem << ",\n" - << " \"evictions_per_sec\": " - << (run_secs > 0 ? static_cast(evictions) / run_secs : 0.0) - << ",\n" - << " \"admissions_rejected_per_sec\": " - << (run_secs > 0 ? static_cast(admissions) / run_secs : 0.0) - << "\n" - << "}\n"; - return 0; -} diff --git a/bench/prompt_cache_bench.cpp b/bench/prompt_cache_bench.cpp new file mode 100644 index 0000000..fa9a5bd --- /dev/null +++ b/bench/prompt_cache_bench.cpp @@ -0,0 +1,305 @@ +#include "pomai_cache/ai_cache.hpp" +#include "pomai_cache/engine.hpp" +#include "pomai_cache/prompt_cache.hpp" +#include "pomaicache.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace pomai_cache; +using hrc = std::chrono::high_resolution_clock; + +struct LatencyStats { + double p50_us{0}; + double p95_us{0}; + double p99_us{0}; + double p999_us{0}; + double mean_us{0}; + double min_us{0}; + double max_us{0}; +}; + +static LatencyStats compute_stats(std::vector &samples) { + if (samples.empty()) + return {}; + std::sort(samples.begin(), samples.end()); + auto at = [&](double p) { + return samples[static_cast(p * (samples.size() - 1))]; + }; + double sum = std::accumulate(samples.begin(), samples.end(), 0.0); + return {at(0.50), at(0.95), at(0.99), at(0.999), + sum / samples.size(), samples.front(), samples.back()}; +} + +struct PromptBenchResult { + std::string name; + double ops_s{0}; + LatencyStats lat; + double hit_rate{0}; + double avg_savings_ratio{0}; +}; + +// Serialize a token sequence into a byte vector (little-endian u64s). +static std::vector +tokens_to_bytes(const std::vector &tokens) { + std::vector bytes(tokens.size() * sizeof(std::uint64_t)); + std::memcpy(bytes.data(), tokens.data(), bytes.size()); + return bytes; +} + +static PromptBenchResult +run_prompt_workload(const std::string &name, PromptCacheManager &pcm, + int ops, int hot_prefixes, int max_tokens, + double write_fraction) { + std::mt19937_64 rng(42); + std::uniform_int_distribution hot_dist(0, hot_prefixes - 1); + std::uniform_int_distribution len_dist(64, max_tokens); + std::bernoulli_distribution write_dist(write_fraction); + + std::vector latencies; + latencies.reserve(static_cast(ops)); + + std::uint64_t gets = 0; + std::uint64_t hits = 0; + double savings_sum = 0.0; + + const std::string tokenizer_id = "tok"; + + auto t_start = hrc::now(); + for (int i = 0; i < ops; ++i) { + const int prefix_id = hot_dist(rng); + const int total_tokens = len_dist(rng); + const int prefix_tokens = std::max(32, total_tokens / 2); + + // Build a deterministic "token" sequence so that all prompts sharing a + // prefix_id have a common prefix of prefix_tokens. + std::vector full_tokens(static_cast(total_tokens)); + for (int t = 0; t < total_tokens; ++t) { + if (t < prefix_tokens) + full_tokens[static_cast(t)] = + static_cast(prefix_id * 1000 + t); + else + full_tokens[static_cast(t)] = + static_cast(prefix_id * 1000 + 100 + t); + } + + std::vector prefix_tokens_vec( + full_tokens.begin(), + full_tokens.begin() + static_cast(prefix_tokens)); + + std::vector prefix_bytes = tokens_to_bytes(prefix_tokens_vec); + std::vector full_bytes = tokens_to_bytes(full_tokens); + + const std::string prefix_hash = "p" + std::to_string(prefix_id); + + auto t0 = hrc::now(); + + if (write_dist(rng)) { + // Store / refresh the prefix. + pcm.put_prefix(tokenizer_id, prefix_hash, prefix_bytes, + static_cast(prefix_tokens)); + } else { + // Attempt reuse for a full query that shares the same prefix bytes. + PromptReuseResult reuse = + pcm.reuse_for_query(tokenizer_id, prefix_hash, full_bytes); + ++gets; + if (reuse.hit) { + ++hits; + savings_sum += reuse.savings_ratio; + } + } + + auto t1 = hrc::now(); + latencies.push_back( + std::chrono::duration(t1 - t0).count()); + + // Periodic maintenance. + if ((i % 256) == 0) + pcm.tick(); + } + + double elapsed_sec = + std::chrono::duration(hrc::now() - t_start).count(); + + auto lat = compute_stats(latencies); + double hit_rate = gets ? static_cast(hits) / static_cast(gets) + : 0.0; + double avg_savings = gets ? (savings_sum / static_cast(gets)) : 0.0; + + return {name, + static_cast(ops) / elapsed_sec, + lat, + hit_rate, + avg_savings}; +} + +static PromptBenchResult +run_prompt_workload_embedded(const std::string &name, + pomaicache::PomaiCache &cache, + int ops, int hot_prefixes, int max_tokens, + double write_fraction) { + std::mt19937_64 rng(1337); + std::uniform_int_distribution hot_dist(0, hot_prefixes - 1); + std::uniform_int_distribution len_dist(32, max_tokens); + std::bernoulli_distribution write_dist(write_fraction); + + std::vector latencies; + latencies.reserve(static_cast(ops)); + + std::uint64_t gets = 0; + std::uint64_t hits = 0; + double savings_sum = 0.0; + + auto t_start = hrc::now(); + for (int i = 0; i < ops; ++i) { + const int prefix_id = hot_dist(rng); + const int total_tokens = len_dist(rng); + + std::vector tokens(static_cast(total_tokens)); + for (int t = 0; t < total_tokens; ++t) { + tokens[static_cast(t)] = + static_cast(prefix_id * 1000 + t); + } + + std::vector artifact_bytes( + static_cast(std::min(total_tokens * 2, 512))); + std::fill(artifact_bytes.begin(), artifact_bytes.end(), + static_cast(prefix_id)); + + auto t0 = hrc::now(); + + if (write_dist(rng)) { + std::span tok_span(tokens.data(), tokens.size()); + std::span art_span( + reinterpret_cast(artifact_bytes.data()), + artifact_bytes.size()); + cache.PromptPut(tok_span, art_span, pomaicache::Ttl{300000}); + } else { + std::span tok_span(tokens.data(), tokens.size()); + auto r = cache.PromptGet(tok_span); + ++gets; + if (r.hit) { + ++hits; + savings_sum += r.savings_ratio; + } + } + + auto t1 = hrc::now(); + latencies.push_back( + std::chrono::duration(t1 - t0).count()); + } + + double elapsed_sec = + std::chrono::duration(hrc::now() - t_start).count(); + + auto lat = compute_stats(latencies); + double hit_rate = gets ? static_cast(hits) / static_cast(gets) + : 0.0; + double avg_savings = gets ? (savings_sum / static_cast(gets)) : 0.0; + + return {name, + static_cast(ops) / elapsed_sec, + lat, + hit_rate, + avg_savings}; +} + +int main(int argc, char **argv) { + std::string json_out = "prompt_cache_bench.json"; + if (argc > 1) + json_out = argv[1]; + + EngineConfig cfg; + cfg.memory_limit_bytes = 64 * 1024 * 1024; + cfg.data_dir = "./data_prompt_bench"; + + auto policy = make_policy_by_name("pomai_cost"); + Engine engine(cfg, std::move(policy)); + AiArtifactCache ai(engine); + + PromptCacheConfig pcfg; + pcfg.enabled = true; + pcfg.default_ttl_ms = 5 * 60 * 1000; + pcfg.prefix_min_tokens = 32; + pcfg.max_cached_prefix_bytes = 16u * 1024u * 1024u; + + PromptCacheManager pcm(engine, ai, pcfg); + + std::cout << std::fixed << std::setprecision(2); + std::cout << "Embedded token/prompt cache benchmark (in-process, no network)\n"; + + std::vector results; + results.push_back(run_prompt_workload("chatty_short_sessions", + pcm, + 5000, + 200, // hot prefixes + 128, // max tokens per prompt + 0.30 // writes + )); + results.push_back(run_prompt_workload("long_lived_system_prompts", + pcm, + 5000, + 50, // fewer prefixes, more reuse + 256, // max tokens per prompt + 0.10 // mostly reads + )); + + // Embedded API workload using PomaiCache PromptPut/PromptGet. + pomaicache::Config embedded_cfg; + embedded_cfg.memory_limit_bytes = 64 * 1024 * 1024; + embedded_cfg.data_dir = "./data_prompt_bench_embedded"; + pomaicache::PomaiCache cache(embedded_cfg); + results.push_back(run_prompt_workload_embedded("embedded_api_hot_prompts", + cache, + 5000, + 100, + 256, + 0.25)); + + std::cout << "\n" + << std::left << std::setw(28) << "workload" + << std::right << std::setw(12) << "ops/s" + << std::setw(12) << "p50_us" + << std::setw(12) << "p95_us" + << std::setw(12) << "hit_rate" + << std::setw(16) << "avg_savings\n"; + std::cout << std::string(80, '-') << "\n"; + for (const auto &r : results) { + std::cout << std::left << std::setw(28) << r.name << std::right + << std::setw(12) << r.ops_s + << std::setw(12) << r.lat.p50_us + << std::setw(12) << r.lat.p95_us + << std::setw(12) << r.hit_rate + << std::setw(16) << r.avg_savings_ratio << "\n"; + } + + // JSON summary + std::ofstream jf(json_out); + jf << "{\n \"prompt_cache_workloads\": [\n"; + for (std::size_t i = 0; i < results.size(); ++i) { + const auto &r = results[i]; + jf << " {\"name\":\"" << r.name << "\"," + << "\"ops_s\":" << r.ops_s << "," + << "\"p50_us\":" << r.lat.p50_us << "," + << "\"p95_us\":" << r.lat.p95_us << "," + << "\"p99_us\":" << r.lat.p99_us << "," + << "\"hit_rate\":" << r.hit_rate << "," + << "\"avg_savings_ratio\":" << r.avg_savings_ratio << "}"; + if (i + 1 < results.size()) + jf << ","; + jf << "\n"; + } + jf << " ]\n}\n"; + + std::cout << "\nResults written to: " << json_out << "\n"; + return 0; +} + diff --git a/bench/vector_cache_bench.cpp b/bench/vector_cache_bench.cpp index 5cf5251..e7ed221 100644 --- a/bench/vector_cache_bench.cpp +++ b/bench/vector_cache_bench.cpp @@ -311,11 +311,10 @@ int main(int argc, char **argv) { std::cout << std::fixed << std::setprecision(2); // ======================================== - // SECTION 1: Raw Vector Index Performance + // SECTION 1: Raw Vector Index Performance (Embedded) // ======================================== - print_header("POMAI CACHE VECTOR BENCHMARK SUITE"); - std::cout << " Comparing against: Redis+RediSearch, Milvus, Qdrant, Weaviate, Pinecone\n"; - std::cout << " All measurements: single-thread, in-process (no network overhead)\n"; + print_header("POMAI CACHE EMBEDDED VECTOR BENCHMARK"); + std::cout << " All measurements: single-thread, in-process (no network, no RPC)\n"; print_separator(); std::vector configs = { @@ -349,7 +348,7 @@ int main(int argc, char **argv) { << r.bytes_per_vector << " bytes/vec)\n"; } - // Summary table + // Summary table (with reference-only estimates for networked systems) print_header("INSERT THROUGHPUT (vectors/sec)"); std::cout << std::left << std::setw(20) << "Config" << std::right << std::setw(14) << "Pomai Cache" << std::setw(14) << "Redis*" @@ -366,8 +365,9 @@ int main(int argc, char **argv) { << redis_est << std::setw(14) << milvus_est << std::setw(14) << qdrant_est << std::setw(14) << weaviate_est << "\n"; } - std::cout << "\n * Estimated from published benchmarks (network + indexing overhead).\n" - << " Pomai Cache: in-process flat index, zero network, zero serialization.\n"; + std::cout << "\n * Non-Pomai numbers are rough reference estimates for network-based systems\n" + << " (including network and indexing overhead). Pomai Cache runs embedded,\n" + << " co-located with the application (zero network, zero serialization).\n"; print_header("SEARCH LATENCY p50 (microseconds)"); std::cout << std::left << std::setw(20) << "Config" << std::right @@ -386,7 +386,7 @@ int main(int argc, char **argv) { } std::cout << "\n * Network-based systems add 200-5000us of network + serialization overhead.\n" << " Pinecone: managed cloud service, includes network RTT.\n" - << " Pomai Cache: co-located with application, sub-millisecond.\n"; + << " Pomai Cache: embedded in your process, typically sub-millisecond.\n"; // Memory efficiency print_header("MEMORY EFFICIENCY"); diff --git a/docs/AI_CACHE.md b/docs/AI_CACHE.md index 05808b1..3f118f4 100644 --- a/docs/AI_CACHE.md +++ b/docs/AI_CACHE.md @@ -44,7 +44,7 @@ Per-item TTL from metadata overrides defaults. SSD remains a warm cache tier: - async write-behind behavior in SSD store -- fsync default for server is `never` +- fsync default is `never` - queue pressure may drop writes - restart rebuild is best effort; corrupted tails are skipped diff --git a/docs/AI_COMMANDS.md b/docs/AI_COMMANDS.md index e781dab..fb14d86 100644 --- a/docs/AI_COMMANDS.md +++ b/docs/AI_COMMANDS.md @@ -1,35 +1,47 @@ -# AI Commands +# AI Commands (embedded) -All commands are RESP-compatible and usable with `redis-cli`. +The AI artifact cache is exposed as a C++ library (`pomai_cache::AiArtifactCache`) and via the higher-level `pomaicache::PomaiCache` / C / Python bindings. + +Below are examples using the C++ API directly. ## Store / fetch -```bash -redis-cli -p 6379 AI.PUT embedding emb:modelX:ih:768:float16 '{"artifact_type":"embedding","owner":"vector","schema_version":"v1","model_id":"modelX","snapshot_epoch":"ix42"}' "" -redis-cli -p 6379 AI.GET emb:modelX:ih:768:float16 -redis-cli -p 6379 AI.MGET emb:k1 emb:k2 emb:k3 -``` +```cpp +#include +#include + +pomai_cache::EngineConfig cfg; +cfg.memory_limit_bytes = 128 * 1024 * 1024; +cfg.data_dir = "./data"; -## Embedding helpers +pomai_cache::Engine engine(cfg, pomai_cache::make_policy_by_name("pomai_cost")); +pomai_cache::AiArtifactCache ai(engine); -```bash -redis-cli -p 6379 AI.EMB.PUT emb:modelX:ih:768:float16 modelX 768 float16 3600 "" -redis-cli -p 6379 AI.EMB.GET emb:modelX:ih:768:float16 +std::vector payload{/* bytes */}; +std::string meta = R"({"artifact_type":"embedding","owner":"vector","schema_version":"v1","model_id":"modelX","snapshot_epoch":"ix42"})"; + +ai.put("embedding", "emb:modelX:ih:768:float16", meta, payload); + +auto got = ai.get("emb:modelX:ih:768:float16"); +if (got) { + // use got->meta / got->payload +} ``` ## Invalidation -```bash -redis-cli -p 6379 AI.INVALIDATE EPOCH ix42 -redis-cli -p 6379 AI.INVALIDATE MODEL modelX -redis-cli -p 6379 AI.INVALIDATE PREFIX emb:modelX: +```cpp +ai.invalidate_epoch("ix42"); +ai.invalidate_model("modelX"); +ai.invalidate_prefix("emb:modelX:"); ``` ## Introspection -```bash -redis-cli -p 6379 AI.STATS -redis-cli -p 6379 AI.TOP HOT 20 -redis-cli -p 6379 AI.TOP COSTLY 20 -redis-cli -p 6379 AI.EXPLAIN emb:modelX:ih:768:float16 +```cpp +std::string stats = ai.stats(); +std::string hot = ai.top_hot(20); +std::string costly = ai.top_costly(20); +std::string explain = ai.explain("emb:modelX:ih:768:float16"); ``` + diff --git a/include/pomai_cache/engine_shard.hpp b/include/pomai_cache/engine_shard.hpp index eb93c2b..57d7624 100644 --- a/include/pomai_cache/engine_shard.hpp +++ b/include/pomai_cache/engine_shard.hpp @@ -3,6 +3,7 @@ #include "pomai_cache/ai_cache.hpp" #include "pomai_cache/engine.hpp" #include "pomai_cache/journal.hpp" +#include "pomai_cache/prompt_cache.hpp" #include #include @@ -24,8 +25,11 @@ namespace pomai_cache { class alignas(64) EngineShard { public: - EngineShard(std::uint32_t id, EngineConfig cfg, std::unique_ptr policy) - : id_(id), engine_(std::move(cfg), std::move(policy)) {} + EngineShard(std::uint32_t id, EngineConfig cfg, + std::unique_ptr policy, + PromptCacheConfig prompt_cfg = {}) + : id_(id), engine_(std::move(cfg), std::move(policy)), + ai_cache_(engine_), prompt_cache_(engine_, ai_cache_, prompt_cfg) {} // Forbidden copy/assignment to ensure memory stability EngineShard(const EngineShard&) = delete; @@ -35,9 +39,13 @@ class alignas(64) EngineShard { Engine& engine() { return engine_; } Journal& journal() { return journal_; } AiArtifactCache& ai_cache() { return ai_cache_; } + PromptCacheManager& prompt_cache() { return prompt_cache_; } - static void InitThreadLocal(std::uint32_t id, EngineConfig cfg, std::unique_ptr policy) { - tlocal_shard_ = new EngineShard(id, std::move(cfg), std::move(policy)); + static void InitThreadLocal(std::uint32_t id, EngineConfig cfg, + std::unique_ptr policy, + PromptCacheConfig prompt_cfg = {}) { + tlocal_shard_ = + new EngineShard(id, std::move(cfg), std::move(policy), prompt_cfg); } static void DestroyThreadLocal() { @@ -50,7 +58,8 @@ class alignas(64) EngineShard { private: std::uint32_t id_; Engine engine_; - AiArtifactCache ai_cache_{engine_}; + AiArtifactCache ai_cache_; + PromptCacheManager prompt_cache_; Journal journal_; static inline thread_local EngineShard* tlocal_shard_{nullptr}; diff --git a/include/pomai_cache/policy.hpp b/include/pomai_cache/policy.hpp index 48ab36e..f033f54 100644 --- a/include/pomai_cache/policy.hpp +++ b/include/pomai_cache/policy.hpp @@ -15,6 +15,7 @@ struct PolicyParams { double w_reuse{1.0}; double w_mem{1.0}; double w_risk{1.0}; + double prompt_reuse_weight{0.0}; double admit_threshold{0.0}; double evict_pressure{0.8}; std::uint64_t max_evictions_per_second{10000}; diff --git a/include/pomai_cache/prompt_cache.hpp b/include/pomai_cache/prompt_cache.hpp new file mode 100644 index 0000000..d1ac36b --- /dev/null +++ b/include/pomai_cache/prompt_cache.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include "pomai_cache/ai_cache.hpp" +#include "pomai_cache/engine.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace pomai_cache { + +// Configuration for prompt prefix caching. Tuned for edge / AI workloads. +struct PromptCacheConfig { + bool enabled{true}; + std::uint64_t default_ttl_ms{300'000}; // 5 minutes + std::size_t prefix_min_tokens{50}; // minimum tokens for reuse + std::size_t max_cached_prefix_bytes{16u * 1024u * 1024u}; // hard cap for RAM index +}; + +struct PromptReuseResult { + bool hit{false}; + std::string prompt_prefix_hash; + std::size_t cached_tokens{0}; + std::size_t suffix_tokens{0}; + double savings_ratio{0.0}; +}; + +struct PromptCacheStats { + std::uint64_t hits{0}; + std::uint64_t misses{0}; + std::uint64_t total_queries{0}; + std::uint64_t cached_prefix_bytes{0}; + double average_savings_ratio{0.0}; + std::uint64_t entry_count{0}; +}; + +// Manages prefix-based prompt caching for pre-tokenized prompts. +// +// Prompts are modeled as opaque byte sequences (e.g., serialized token ID +// arrays or partial embeddings). Callers are responsible for ensuring that +// the serialized representation obeys a prefix property: if a prefix P of a +// prompt Q is cached, then P's byte sequence is a prefix of Q's byte sequence. +// +// This manager uses AiArtifactCache + Engine for durable storage and SSD +// demotion, and keeps a compact in-memory index for fast longest-prefix +// lookup. It assumes single-threaded access per EngineShard. +class PromptCacheManager { +public: + PromptCacheManager(Engine &engine, AiArtifactCache &ai_cache, + PromptCacheConfig config); + + // Store a prompt prefix identified by (tokenizer_id, prompt_prefix_hash). + // `serialized_tokens` is the pre-tokenized representation, and + // `cached_tokens` is the logical token count (for metrics and thresholds). + // + // When ttl_ms is not provided, default_ttl_ms from config is used. + bool put_prefix(const std::string &tokenizer_id, + const std::string &prompt_prefix_hash, + const std::vector &serialized_tokens, + std::uint64_t cached_tokens, + std::optional ttl_ms = std::nullopt, + std::string *err = nullptr); + + // Attempt to reuse a cached prefix for a new query prompt identified by + // (tokenizer_id, prompt_full_hash) and its serialized token sequence. + // + // The manager scans cached prefixes for the same tokenizer and selects the + // longest prefix whose serialized bytes are a prefix of serialized_query. + // The effective minimum length is max(config.prefix_min_tokens, + // prefix_min_tokens_override.value_or(0)), compared against cached_tokens. + PromptReuseResult + reuse_for_query(const std::string &tokenizer_id, + const std::string &prompt_full_hash, + const std::vector &serialized_query, + std::optional prefix_min_tokens_override = + std::nullopt); + + // Invalidate a single cached prefix by its tokenizer_id + prefix hash. + std::size_t invalidate_prefix(const std::string &tokenizer_id, + const std::string &prompt_prefix_hash); + + // Periodic maintenance for TTL expiration and resource caps. + void tick(); + + PromptCacheStats stats() const; + +private: + struct PrefixEntry { + std::string canonical_key; + std::string tokenizer_id; + std::string prompt_prefix_hash; + std::uint64_t cached_tokens{0}; + std::uint64_t reuse_count{0}; + std::uint64_t size_bytes{0}; + std::uint64_t expiry_epoch_ms{0}; + }; + + struct ExpiryNode { + std::uint64_t expiry_epoch_ms; + std::string canonical_key; + std::uint64_t generation; + bool operator>(const ExpiryNode &other) const { + return expiry_epoch_ms > other.expiry_epoch_ms; + } + }; + + void maybe_expire(); + std::uint64_t now_ms() const; + std::size_t key_count_for_tokenizer(const std::string &tokenizer_id) const; + + Engine &engine_; + AiArtifactCache &ai_cache_; + PromptCacheConfig cfg_; + + std::unordered_map entries_; + std::unordered_map> by_tokenizer_; + std::priority_queue, std::greater> + expiry_heap_; + std::unordered_map expiry_generation_; + + PromptCacheStats stats_{}; +}; + +} // namespace pomai_cache + diff --git a/include/pomaicache.h b/include/pomaicache.h new file mode 100644 index 0000000..86e939c --- /dev/null +++ b/include/pomaicache.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace pomaicache { + +struct Config { + std::size_t memory_limit_bytes{128 * 1024 * 1024}; + std::string data_dir{"./data"}; +}; + +struct Ttl { + std::uint64_t ms{0}; +}; + +struct PromptResult { + bool hit{false}; + std::uint64_t cached_tokens{0}; + std::uint64_t suffix_tokens{0}; + double savings_ratio{0.0}; +}; + +class PomaiCacheImpl; + +class PomaiCache { +public: + explicit PomaiCache(const Config &cfg); + ~PomaiCache(); + + PomaiCache(const PomaiCache &) = delete; + PomaiCache &operator=(const PomaiCache &) = delete; + PomaiCache(PomaiCache &&) noexcept; + PomaiCache &operator=(PomaiCache &&) noexcept; + + // Core K/V + bool Set(std::string_view key, + std::span value, + Ttl ttl); + + std::optional> Get(std::string_view key); + + // AI Prompt Caching + bool PromptPut(std::span tokens, + std::span artifact, + Ttl ttl); + + PromptResult PromptGet(std::span tokens); + +private: + std::unique_ptr impl_; +}; + +} // namespace pomaicache + diff --git a/include/pomaicache_c.h b/include/pomaicache_c.h new file mode 100644 index 0000000..a5a6138 --- /dev/null +++ b/include/pomaicache_c.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct pomai_t pomai_t; + +typedef struct { + uint64_t memory_limit_bytes; + const char *data_dir; +} pomai_config_t; + +typedef struct { + uint8_t hit; + uint64_t cached_tokens; + uint64_t suffix_tokens; + double savings_ratio; +} pomai_prompt_result_t; + +pomai_t *pomai_create(const pomai_config_t *cfg); +void pomai_destroy(pomai_t *db); + +int pomai_set(pomai_t *db, + const char *key, size_t key_len, + const void *value, size_t value_len, + uint64_t ttl_ms); + +int pomai_get(pomai_t *db, + const char *key, size_t key_len, + void **out_value, size_t *out_len); + +void pomai_free(void *ptr); + +int pomai_prompt_put(pomai_t *db, + const uint64_t *tokens, size_t len, + const void *artifact, size_t artifact_len, + uint64_t ttl_ms); + +int pomai_prompt_get(pomai_t *db, + const uint64_t *tokens, size_t len, + pomai_prompt_result_t *out); + +#ifdef __cplusplus +} +#endif + diff --git a/sdk/python/pomai_cache/__init__.py b/sdk/python/pomai_cache/__init__.py deleted file mode 100644 index 05dad8b..0000000 --- a/sdk/python/pomai_cache/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Pomai Cache Python SDK — AI-first cache client.""" - -from pomai_cache.client import PomaiCache, AsyncPomaiCache -from pomai_cache.decorators import memoize - -__version__ = "0.1.0" -__all__ = ["PomaiCache", "AsyncPomaiCache", "memoize"] diff --git a/sdk/python/pomai_cache/client.py b/sdk/python/pomai_cache/client.py deleted file mode 100644 index 36b0a1c..0000000 --- a/sdk/python/pomai_cache/client.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Pomai Cache client with AI-first semantics.""" - -from __future__ import annotations - -import json -import socket -import struct -from typing import Any, Dict, List, Optional, Sequence, Tuple - -from pomai_cache.resp import encode_command, read_reply - -try: - import numpy as np - - HAS_NUMPY = True -except ImportError: - HAS_NUMPY = False - -try: - from opentelemetry import trace - - HAS_OTEL = True -except ImportError: - HAS_OTEL = False - - -def _vector_to_str(vector) -> str: - """Convert a vector (list, numpy array, or torch tensor) to comma-separated string.""" - if HAS_NUMPY and isinstance(vector, np.ndarray): - vector = vector.astype(float).flatten().tolist() - elif hasattr(vector, "detach"): - vector = vector.detach().cpu().numpy().astype(float).flatten().tolist() - return ",".join(str(float(v)) for v in vector) - - -def _bytes_to_payload(data: bytes | str) -> str: - if isinstance(data, bytes): - return data.decode("utf-8", errors="replace") - return data - - -class PomaiCache: - """Synchronous Pomai Cache client.""" - - def __init__(self, host: str = "127.0.0.1", port: int = 6379, - timeout: float = 5.0, password: Optional[str] = None): - self._host = host - self._port = port - self._timeout = timeout - self._sock: Optional[socket.socket] = None - self._password = password - self._tracer = None - if HAS_OTEL: - self._tracer = trace.get_tracer("pomai_cache") - - def connect(self) -> "PomaiCache": - self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._sock.settimeout(self._timeout) - self._sock.connect((self._host, self._port)) - if self._password: - self._execute("AUTH", self._password) - return self - - def close(self): - if self._sock: - self._sock.close() - self._sock = None - - def __enter__(self): - return self.connect() - - def __exit__(self, *args): - self.close() - - def _execute(self, *args: str) -> Any: - if not self._sock: - self.connect() - self._sock.sendall(encode_command(*args)) - return read_reply(self._sock) - - def _traced(self, op_name: str, fn, *args, **kwargs): - if self._tracer: - with self._tracer.start_as_current_span(f"pomai_cache.{op_name}"): - return fn(*args, **kwargs) - return fn(*args, **kwargs) - - # --- Standard KV --- - - def get(self, key: str) -> Optional[str]: - return self._traced("get", self._execute, "GET", key) - - def set(self, key: str, value: str, ttl_ms: Optional[int] = None) -> str: - if ttl_ms is not None: - return self._traced("set", self._execute, "SET", key, value, "PX", str(ttl_ms)) - return self._traced("set", self._execute, "SET", key, value) - - def delete(self, *keys: str) -> int: - return self._traced("delete", self._execute, "DEL", *keys) - - # --- AI Artifact Operations --- - - def put_artifact(self, artifact_type: str, key: str, meta: Dict[str, Any], - payload: bytes | str, depends_on: Optional[List[str]] = None) -> str: - meta.setdefault("artifact_type", artifact_type) - meta.setdefault("owner", "default") - meta.setdefault("schema_version", "v1") - meta_json = json.dumps(meta) - payload_str = _bytes_to_payload(payload) - args = ["AI.PUT", artifact_type, key, meta_json, payload_str] - if depends_on: - args.append("DEPENDS_ON") - args.extend(depends_on) - return self._traced("put_artifact", self._execute, *args) - - def get_artifact(self, key: str) -> Optional[Tuple[Dict[str, Any], bytes]]: - result = self._traced("get_artifact", self._execute, "AI.GET", key) - if result is None: - return None - meta = json.loads(result[0]) - payload = result[1].encode() if isinstance(result[1], str) else result[1] - return meta, payload - - def put_embedding(self, model_id: str, input_hash: str, vector, - payload: bytes | str = b"", dim: Optional[int] = None, - dtype: str = "float32", ttl_ms: Optional[int] = None, - **extra_meta) -> str: - if dim is None: - if HAS_NUMPY and isinstance(vector, np.ndarray): - dim = vector.shape[-1] - elif hasattr(vector, "shape"): - dim = vector.shape[-1] - else: - dim = len(vector) - - key = f"emb:{model_id}:{input_hash}:{dim}:{dtype}" - meta = { - "artifact_type": "embedding", - "owner": "vector", - "schema_version": "v1", - "model_id": model_id, - **extra_meta, - } - if ttl_ms: - meta["ttl_deadline"] = ttl_ms - return self.put_artifact("embedding", key, meta, payload) - - def put_response(self, prompt_hash: str, params_hash: str, model_id: str, - response: str, inference_tokens: int = 0, - dollar_cost: float = 0.0, **extra_meta) -> str: - key = f"rsp:{prompt_hash}:{params_hash}:{model_id}" - meta = { - "artifact_type": "response", - "owner": "response", - "schema_version": "v1", - "model_id": model_id, - "inference_tokens": inference_tokens, - "dollar_cost": dollar_cost, - **extra_meta, - } - return self.put_artifact("response", key, meta, response) - - # --- Similarity Search --- - - def sim_put(self, key: str, vector, payload: bytes | str, - meta: Optional[Dict[str, Any]] = None) -> str: - vec_str = _vector_to_str(vector) - payload_str = _bytes_to_payload(payload) - meta_json = json.dumps(meta) if meta else "" - args = ["AI.SIM.PUT", key, vec_str, payload_str] - if meta_json: - args.append(meta_json) - return self._traced("sim_put", self._execute, *args) - - def sim_get(self, vector, top_k: int = 1, - threshold: float = 0.9) -> List[Dict[str, Any]]: - vec_str = _vector_to_str(vector) - result = self._traced( - "sim_get", self._execute, - "AI.SIM.GET", vec_str, "TOPK", str(top_k), "THRESHOLD", str(threshold), - ) - if not result: - return [] - results = [] - for i in range(0, len(result), 4): - results.append({ - "key": result[i], - "score": float(result[i + 1]), - "meta": json.loads(result[i + 2]), - "payload": result[i + 3], - }) - return results - - # --- Streaming --- - - def stream_begin(self, key: str, meta: Dict[str, Any]) -> str: - meta.setdefault("artifact_type", "response") - meta.setdefault("owner", "response") - meta.setdefault("schema_version", "v1") - return self._traced("stream_begin", self._execute, - "AI.STREAM.BEGIN", key, json.dumps(meta)) - - def stream_append(self, key: str, chunk: str | bytes) -> str: - return self._traced("stream_append", self._execute, - "AI.STREAM.APPEND", key, _bytes_to_payload(chunk)) - - def stream_end(self, key: str) -> str: - return self._traced("stream_end", self._execute, "AI.STREAM.END", key) - - def stream_get(self, key: str) -> Optional[Tuple[Dict[str, Any], bytes]]: - result = self._traced("stream_get", self._execute, "AI.STREAM.GET", key) - if result is None: - return None - meta = json.loads(result[0]) - payload = result[1].encode() if isinstance(result[1], str) else result[1] - return meta, payload - - # --- Invalidation --- - - def invalidate_epoch(self, epoch: str) -> int: - return self._traced("invalidate", self._execute, - "AI.INVALIDATE", "EPOCH", epoch) - - def invalidate_model(self, model_id: str) -> int: - return self._traced("invalidate", self._execute, - "AI.INVALIDATE", "MODEL", model_id) - - def invalidate_cascade(self, key: str) -> int: - return self._traced("invalidate_cascade", self._execute, - "AI.INVALIDATE", "CASCADE", key) - - # --- Cost & Budget --- - - def cost_report(self) -> Dict[str, Any]: - result = self._traced("cost_report", self._execute, "AI.COST.REPORT") - report = {} - for line in result.strip().split("\n"): - if ":" in line: - k, v = line.split(":", 1) - try: - report[k] = float(v) if "." in v else int(v) - except ValueError: - report[k] = v - return report - - def set_budget(self, max_dollar_per_hour: float) -> str: - return self._traced("set_budget", self._execute, - "AI.BUDGET", str(max_dollar_per_hour)) - - # --- Info --- - - def stats(self) -> str: - return self._traced("stats", self._execute, "AI.STATS") - - def info(self) -> str: - return self._traced("info", self._execute, "INFO") - - -class AsyncPomaiCache: - """Async wrapper using asyncio. Requires an event loop.""" - - def __init__(self, host: str = "127.0.0.1", port: int = 6379, - timeout: float = 5.0, password: Optional[str] = None): - self._sync = PomaiCache(host, port, timeout, password) - - async def connect(self) -> "AsyncPomaiCache": - import asyncio - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self._sync.connect) - return self - - async def close(self): - import asyncio - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self._sync.close) - - async def __aenter__(self): - return await self.connect() - - async def __aexit__(self, *args): - await self.close() - - async def _run(self, method, *args, **kwargs): - import asyncio - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, lambda: method(*args, **kwargs)) - - async def get(self, key: str): - return await self._run(self._sync.get, key) - - async def set(self, key: str, value: str, ttl_ms: Optional[int] = None): - return await self._run(self._sync.set, key, value, ttl_ms) - - async def put_artifact(self, artifact_type: str, key: str, meta: Dict, payload, **kw): - return await self._run(self._sync.put_artifact, artifact_type, key, meta, payload, **kw) - - async def get_artifact(self, key: str): - return await self._run(self._sync.get_artifact, key) - - async def sim_put(self, key: str, vector, payload, meta=None): - return await self._run(self._sync.sim_put, key, vector, payload, meta) - - async def sim_get(self, vector, top_k: int = 1, threshold: float = 0.9): - return await self._run(self._sync.sim_get, vector, top_k, threshold) - - async def cost_report(self): - return await self._run(self._sync.cost_report) - - async def stats(self): - return await self._run(self._sync.stats) diff --git a/sdk/python/pomai_cache/decorators.py b/sdk/python/pomai_cache/decorators.py deleted file mode 100644 index c791bcb..0000000 --- a/sdk/python/pomai_cache/decorators.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Decorator support for Pomai Cache memoization.""" - -from __future__ import annotations - -import functools -import hashlib -import json -from typing import Any, Callable, Optional - - -def memoize(cache=None, artifact_type: str = "response", - model_id: str = "", ttl_ms: Optional[int] = None): - """Decorator that caches function results in Pomai Cache. - - Usage: - cache = PomaiCache("localhost", 6379) - cache.connect() - - @memoize(cache=cache, artifact_type="response", model_id="gpt-4") - def generate(prompt: str) -> str: - return call_llm(prompt) - """ - def decorator(fn: Callable) -> Callable: - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if cache is None: - return fn(*args, **kwargs) - - sig = json.dumps({"args": [str(a) for a in args], - "kwargs": {k: str(v) for k, v in sorted(kwargs.items())}}, - sort_keys=True) - key_hash = hashlib.sha256(sig.encode()).hexdigest()[:16] - cache_key = f"memo:{fn.__name__}:{key_hash}" - - existing = cache.get_artifact(cache_key) - if existing is not None: - _, payload = existing - return payload.decode() if isinstance(payload, bytes) else payload - - result = fn(*args, **kwargs) - - meta = { - "artifact_type": artifact_type, - "owner": "default", - "schema_version": "v1", - } - if model_id: - meta["model_id"] = model_id - - payload = result if isinstance(result, (str, bytes)) else json.dumps(result) - try: - cache.put_artifact(artifact_type, cache_key, meta, payload) - except Exception: - pass - - return result - return wrapper - return decorator diff --git a/sdk/python/pomai_cache/resp.py b/sdk/python/pomai_cache/resp.py deleted file mode 100644 index 2fe60f2..0000000 --- a/sdk/python/pomai_cache/resp.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Minimal RESP protocol encoder/decoder for Pomai Cache.""" - -from __future__ import annotations - -import socket -from typing import Any - - -def encode_command(*args: str) -> bytes: - """Encode a command as a RESP array of bulk strings.""" - parts = [f"*{len(args)}\r\n"] - for a in args: - encoded = a if isinstance(a, str) else str(a) - parts.append(f"${len(encoded.encode())}\r\n{encoded}\r\n") - return "".join(parts).encode() - - -def read_line(sock: socket.socket) -> str: - buf = b"" - while not buf.endswith(b"\r\n"): - ch = sock.recv(1) - if not ch: - raise ConnectionError("Connection closed") - buf += ch - return buf[:-2].decode() - - -def read_reply(sock: socket.socket) -> Any: - """Read a single RESP reply from the socket.""" - line = read_line(sock) - prefix = line[0] - body = line[1:] - - if prefix == "+": - return body - elif prefix == "-": - raise RuntimeError(f"RESP error: {body}") - elif prefix == ":": - return int(body) - elif prefix == "$": - length = int(body) - if length < 0: - return None - data = b"" - while len(data) < length + 2: - chunk = sock.recv(length + 2 - len(data)) - if not chunk: - raise ConnectionError("Connection closed") - data += chunk - return data[:-2].decode() - elif prefix == "*": - count = int(body) - if count < 0: - return None - return [read_reply(sock) for _ in range(count)] - else: - return line diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml deleted file mode 100644 index 449da14..0000000 --- a/sdk/python/pyproject.toml +++ /dev/null @@ -1,22 +0,0 @@ -[build-system] -requires = ["setuptools>=68.0", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "pomai-cache" -version = "0.1.0" -description = "Python SDK for Pomai Cache — the AI-first cache" -readme = "README.md" -requires-python = ">=3.9" -license = {text = "MIT"} -authors = [{name = "Pomai Cache Team"}] -dependencies = [] - -[project.optional-dependencies] -numpy = ["numpy>=1.24"] -torch = ["torch>=2.0"] -otel = ["opentelemetry-api>=1.20"] -all = ["numpy>=1.24", "opentelemetry-api>=1.20"] - -[tool.setuptools.packages.find] -include = ["pomai_cache*"] diff --git a/src/bindings/c_api.cc b/src/bindings/c_api.cc new file mode 100644 index 0000000..3a33b8c --- /dev/null +++ b/src/bindings/c_api.cc @@ -0,0 +1,97 @@ +#include "pomaicache_c.h" +#include "pomaicache.h" + +#include +#include +#include + +struct pomai_t { + pomaicache::PomaiCache impl; + + explicit pomai_t(const pomaicache::Config &cfg) : impl(cfg) {} +}; + +extern "C" { + +pomai_t *pomai_create(const pomai_config_t *cfg) { + if (!cfg) + return nullptr; + pomaicache::Config c; + c.memory_limit_bytes = static_cast(cfg->memory_limit_bytes); + if (cfg->data_dir) + c.data_dir = cfg->data_dir; + try { + return new pomai_t(c); + } catch (...) { + return nullptr; + } +} + +void pomai_destroy(pomai_t *db) { + delete db; +} + +int pomai_set(pomai_t *db, + const char *key, size_t key_len, + const void *value, size_t value_len, + uint64_t ttl_ms) { + if (!db || !key || !value) + return 0; + std::string_view k(key, key_len); + auto *bytes = static_cast(value); + std::span v(bytes, value_len); + pomaicache::Ttl ttl{ttl_ms}; + return db->impl.Set(k, v, ttl) ? 1 : 0; +} + +int pomai_get(pomai_t *db, + const char *key, size_t key_len, + void **out_value, size_t *out_len) { + if (!db || !key || !out_value || !out_len) + return 0; + std::string_view k(key, key_len); + auto v = db->impl.Get(k); + if (!v.has_value()) + return 0; + auto &vec = *v; + auto *buf = new std::byte[vec.size()]; + std::memcpy(buf, vec.data(), vec.size()); + *out_value = buf; + *out_len = vec.size(); + return 1; +} + +void pomai_free(void *ptr) { + auto *b = static_cast(ptr); + delete[] b; +} + +int pomai_prompt_put(pomai_t *db, + const uint64_t *tokens, size_t len, + const void *artifact, size_t artifact_len, + uint64_t ttl_ms) { + if (!db || !tokens || !artifact) + return 0; + std::span t(tokens, len); + auto *bytes = static_cast(artifact); + std::span a(bytes, artifact_len); + pomaicache::Ttl ttl{ttl_ms}; + return db->impl.PromptPut(t, a, ttl) ? 1 : 0; +} + +int pomai_prompt_get(pomai_t *db, + const uint64_t *tokens, size_t len, + pomai_prompt_result_t *out) { + if (!db || !tokens || !out) + return 0; + std::span t(tokens, len); + auto r = db->impl.PromptGet(t); + out->hit = r.hit ? 1 : 0; + out->cached_tokens = r.cached_tokens; + out->suffix_tokens = r.suffix_tokens; + out->savings_ratio = r.savings_ratio; + return 1; +} + +} // extern "C" + diff --git a/src/bindings/python_bindings.cc b/src/bindings/python_bindings.cc new file mode 100644 index 0000000..a4e822a --- /dev/null +++ b/src/bindings/python_bindings.cc @@ -0,0 +1,75 @@ +#include "pomaicache_c.h" + +#include +#include +#include + +namespace py = pybind11; + +namespace { + +class PyCache { +public: + explicit PyCache(const std::string &data_dir, std::uint64_t memory_limit_bytes) { + pomai_config_t cfg{}; + cfg.memory_limit_bytes = memory_limit_bytes; + cfg.data_dir = data_dir.c_str(); + handle_ = pomai_create(&cfg); + if (!handle_) { + throw std::runtime_error("pomai_create failed"); + } + } + + ~PyCache() { + if (handle_) + pomai_destroy(handle_); + handle_ = nullptr; + } + + void prompt_put(const std::vector &tokens, + py::buffer artifact, + std::uint64_t ttl_ms) { + py::buffer_info info = artifact.request(); + if (info.ndim != 1) { + throw std::runtime_error("artifact must be 1D buffer"); + } + auto *data = static_cast(info.ptr); + const auto len = static_cast(info.size * info.itemsize); + if (!pomai_prompt_put(handle_, tokens.data(), tokens.size(), data, len, + ttl_ms)) { + throw std::runtime_error("pomai_prompt_put failed"); + } + } + + py::dict prompt_get(const std::vector &tokens) { + pomai_prompt_result_t out{}; + if (!pomai_prompt_get(handle_, tokens.data(), tokens.size(), &out)) { + throw std::runtime_error("pomai_prompt_get failed"); + } + py::dict d; + d["hit"] = static_cast(out.hit); + d["cached_tokens"] = out.cached_tokens; + d["suffix_tokens"] = out.suffix_tokens; + d["savings_ratio"] = out.savings_ratio; + return d; + } + +private: + pomai_t *handle_{nullptr}; +}; + +} // namespace + +PYBIND11_MODULE(pomaicache, m) { + py::class_(m, "Cache") + .def(py::init(), + py::arg("data_dir") = "./data", + py::arg("memory_limit_bytes") = 128 * 1024 * 1024) + .def("prompt_put", &PyCache::prompt_put, + py::arg("tokens"), + py::arg("artifact"), + py::arg("ttl_ms") = 300000) + .def("prompt_get", &PyCache::prompt_get, + py::arg("tokens")); +} + diff --git a/src/engine/engine.cpp b/src/engine/engine.cpp index 7e4b7b3..976eae4 100644 --- a/src/engine/engine.cpp +++ b/src/engine/engine.cpp @@ -414,6 +414,8 @@ bool Engine::reload_params(const std::string &path, std::string *err) { p.w_mem = clamp_d(d, 0.0, 1000.0); if (extract_double(text, "w_risk", d)) p.w_risk = clamp_d(d, 0.0, 1000.0); + if (extract_double(text, "prompt_reuse_weight", d)) + p.prompt_reuse_weight = clamp_d(d, 0.0, 1000.0); if (extract_double(text, "admit_threshold", d)) p.admit_threshold = clamp_d(d, -1e9, 1e9); if (extract_double(text, "evict_pressure", d)) diff --git a/src/engine/prompt_cache.cpp b/src/engine/prompt_cache.cpp new file mode 100644 index 0000000..b7a7b16 --- /dev/null +++ b/src/engine/prompt_cache.cpp @@ -0,0 +1,267 @@ +#include "pomai_cache/prompt_cache.hpp" + +#include +#include +#include + +namespace pomai_cache { +namespace { + +std::uint64_t to_ms(TimePoint tp) { + return static_cast( + std::chrono::duration_cast( + tp.time_since_epoch()) + .count()); +} + +} // namespace + +PromptCacheManager::PromptCacheManager(Engine &engine, AiArtifactCache &ai_cache, + PromptCacheConfig config) + : engine_(engine), ai_cache_(ai_cache), cfg_(std::move(config)) {} + +bool PromptCacheManager::put_prefix( + const std::string &tokenizer_id, const std::string &prompt_prefix_hash, + const std::vector &serialized_tokens, + std::uint64_t cached_tokens, std::optional ttl_ms, + std::string *err) { + if (!cfg_.enabled) + return false; + if (serialized_tokens.empty() || cached_tokens == 0) + return false; + + const auto now = Clock::now(); + const auto ttl = ttl_ms.value_or(cfg_.default_ttl_ms); + const auto expiry = to_ms(now) + ttl; + + const auto canonical_key = + canonical_prompt_key(tokenizer_id, prompt_prefix_hash); + + // Prepare AI metadata for the prefix payload. We model cached prompt + // prefixes as regular `prompt` artifacts so they participate in the same + // eviction and TTL behaviour as other prompt entries. + ArtifactMeta meta; + meta.artifact_type = "prompt"; + meta.owner = "prompt"; + meta.schema_version = "v1"; + meta.tokenizer_id = tokenizer_id; + meta.created_at_ms = to_ms(now); + meta.ttl_ms = ttl; + meta.size_bytes = serialized_tokens.size(); + meta.content_hash = AiArtifactCache::fast_hash_hex(serialized_tokens); + + // Treat cached prefix tokens as 10x cheaper than re-generating them from the + // model. This is exposed through AI.STATS and cost_report. + meta.inference_tokens = cached_tokens; + meta.miss_cost = 2.0; // logical miss cost for a prompt + meta.dollar_cost = (meta.miss_cost * static_cast(cached_tokens) * + 0.001) / + 10.0; + + const auto meta_json = AiArtifactCache::meta_to_json(meta); + + std::string local_err; + std::string *put_err = err ? err : &local_err; + if (!ai_cache_.put("prompt", canonical_key, meta_json, serialized_tokens, + put_err)) { + return false; + } + + auto &entry = entries_[canonical_key]; + entry.canonical_key = canonical_key; + entry.tokenizer_id = tokenizer_id; + entry.prompt_prefix_hash = prompt_prefix_hash; + entry.cached_tokens = cached_tokens; + entry.size_bytes = serialized_tokens.size(); + entry.expiry_epoch_ms = expiry; + + auto &vec = by_tokenizer_[tokenizer_id]; + if (std::find(vec.begin(), vec.end(), canonical_key) == vec.end()) + vec.push_back(canonical_key); + + const auto gen = ++expiry_generation_[canonical_key]; + expiry_heap_.push({expiry, canonical_key, gen}); + + stats_.cached_prefix_bytes += serialized_tokens.size(); + stats_.entry_count = entries_.size(); + return true; +} + +PromptReuseResult PromptCacheManager::reuse_for_query( + const std::string &tokenizer_id, const std::string &prompt_full_hash, + const std::vector &serialized_query, + std::optional prefix_min_tokens_override) { + PromptReuseResult result; + if (!cfg_.enabled || serialized_query.empty()) + return result; + + maybe_expire(); + + ++stats_.total_queries; + + auto it_vec = by_tokenizer_.find(tokenizer_id); + if (it_vec == by_tokenizer_.end()) { + ++stats_.misses; + return result; + } + + const std::size_t min_tokens = std::max( + cfg_.prefix_min_tokens, prefix_min_tokens_override.value_or(0)); + + std::size_t best_tokens = 0; + std::string best_key; + + for (const auto &key : it_vec->second) { + auto it = entries_.find(key); + if (it == entries_.end()) + continue; + auto &e = it->second; + if (e.cached_tokens < min_tokens) + continue; + + auto val = ai_cache_.get(key); + if (!val.has_value()) + continue; + const auto &prefix_bytes = val->payload; + if (prefix_bytes.size() > serialized_query.size()) + continue; + + // Byte-wise prefix check against the cached payload. Callers guarantee that + // the serialized representation preserves prompt-prefix relationships at + // the byte level, so this is sufficient to validate reuse. + bool is_prefix = std::equal(prefix_bytes.begin(), prefix_bytes.end(), + serialized_query.begin()); + if (!is_prefix) + continue; + if (e.cached_tokens > best_tokens) { + best_tokens = static_cast(e.cached_tokens); + best_key = key; + } + } + + if (best_key.empty()) { + ++stats_.misses; + return result; + } + + auto &entry = entries_[best_key]; + entry.reuse_count += 1; + + const std::size_t suffix_tokens = + best_tokens >= serialized_query.size() + ? 0 + : static_cast(serialized_query.size() - best_tokens); + + result.hit = true; + result.prompt_prefix_hash = entry.prompt_prefix_hash; + result.cached_tokens = best_tokens; + result.suffix_tokens = suffix_tokens; + const auto denom = + static_cast(best_tokens + std::max(1, suffix_tokens)); + result.savings_ratio = static_cast(best_tokens) / denom; + + ++stats_.hits; + // Track cached token reuse as discounted cost in savings ratio. + const double total_q = + static_cast(std::max(1, stats_.total_queries)); + const double prev_sum = + stats_.average_savings_ratio * (total_q - 1.0); + stats_.average_savings_ratio = (prev_sum + result.savings_ratio) / total_q; + + // Touch underlying engine entry so eviction policy sees reuse. + engine_.get(best_key); + + return result; +} + +std::size_t PromptCacheManager::invalidate_prefix( + const std::string &tokenizer_id, const std::string &prompt_prefix_hash) { + const auto canonical_key = + canonical_prompt_key(tokenizer_id, prompt_prefix_hash); + + auto it = entries_.find(canonical_key); + if (it == entries_.end()) + return 0; + + auto &entry = it->second; + if (stats_.cached_prefix_bytes >= entry.size_bytes) + stats_.cached_prefix_bytes -= entry.size_bytes; + + auto vec_it = by_tokenizer_.find(tokenizer_id); + if (vec_it != by_tokenizer_.end()) { + auto &v = vec_it->second; + v.erase(std::remove(v.begin(), v.end(), canonical_key), v.end()); + if (v.empty()) + by_tokenizer_.erase(vec_it); + } + + entries_.erase(it); + stats_.entry_count = entries_.size(); + + // Best-effort invalidation of the backing AI artifact. + ai_cache_.invalidate_prefix(canonical_key); + return 1; +} + +void PromptCacheManager::tick() { + if (!cfg_.enabled) + return; + maybe_expire(); +} + +PromptCacheStats PromptCacheManager::stats() const { return stats_; } + +void PromptCacheManager::maybe_expire() { + const auto now = now_ms(); + std::size_t cleaned = 0; + constexpr std::size_t kMaxPerTick = 256; + + while (!expiry_heap_.empty() && cleaned < kMaxPerTick) { + const auto &node = expiry_heap_.top(); + if (node.expiry_epoch_ms > now) + break; + const auto key = node.canonical_key; + const auto gen = node.generation; + expiry_heap_.pop(); + + auto it_gen = expiry_generation_.find(key); + if (it_gen == expiry_generation_.end() || it_gen->second != gen) + continue; + + auto it = entries_.find(key); + if (it == entries_.end()) + continue; + + auto &entry = it->second; + if (stats_.cached_prefix_bytes >= entry.size_bytes) + stats_.cached_prefix_bytes -= entry.size_bytes; + + auto vec_it = by_tokenizer_.find(entry.tokenizer_id); + if (vec_it != by_tokenizer_.end()) { + auto &v = vec_it->second; + v.erase(std::remove(v.begin(), v.end(), key), v.end()); + if (v.empty()) + by_tokenizer_.erase(vec_it); + } + + entries_.erase(it); + expiry_generation_.erase(key); + ++cleaned; + } + stats_.entry_count = entries_.size(); +} + +std::uint64_t PromptCacheManager::now_ms() const { + return to_ms(Clock::now()); +} + +std::size_t +PromptCacheManager::key_count_for_tokenizer(const std::string &tokenizer_id) const { + auto it = by_tokenizer_.find(tokenizer_id); + if (it == by_tokenizer_.end()) + return 0; + return it->second.size(); +} + +} // namespace pomai_cache + diff --git a/src/policy/policies.cpp b/src/policy/policies.cpp index ce1b556..8ec6ce9 100644 --- a/src/policy/policies.cpp +++ b/src/policy/policies.cpp @@ -144,9 +144,12 @@ class PomaiCostPolicy final : public IEvictionPolicy { 1.0, std::chrono::duration(now - e.last_access).count()); const double freq_signal = cms_freq > 0 ? static_cast(cms_freq) : 0.0; - const double p_reuse = std::min( + double p_reuse = std::min( 1.0, (static_cast(e.hit_count) + freq_signal + 1.0) / (age_s + 1.0)); + if (e.owner == "prompt" && params_.prompt_reuse_weight > 0.0) { + p_reuse *= (1.0 + params_.prompt_reuse_weight); + } const double mem_cost = static_cast(e.size_bytes) / 1024.0 + static_cast(e.size_bytes % 64) * 0.01; const double risk = diff --git a/src/pomai_embedded.cc b/src/pomai_embedded.cc new file mode 100644 index 0000000..3885268 --- /dev/null +++ b/src/pomai_embedded.cc @@ -0,0 +1,117 @@ +#include "pomaicache.h" + +#include "pomai_cache/ai_cache.hpp" +#include "pomai_cache/engine.hpp" +#include "pomai_cache/prompt_cache.hpp" + +#include + +namespace pomaicache { + +class PomaiCacheImpl { +public: + explicit PomaiCacheImpl(const Config &cfg) + : engine_cfg_(), + policy_(pomai_cache::make_policy_by_name("pomai_cost")), + engine_(engine_cfg_, std::move(policy_)), + ai_cache_(engine_), + prompt_cfg_(), + prompt_cache_(engine_, ai_cache_, prompt_cfg_) { + engine_cfg_.memory_limit_bytes = cfg.memory_limit_bytes; + engine_cfg_.data_dir = cfg.data_dir; + } + + bool set(std::string_view key, + std::span value, + Ttl ttl) { + std::vector v(value.size()); + std::memcpy(v.data(), value.data(), value.size()); + std::optional ttl_ms; + if (ttl.ms > 0) + ttl_ms = ttl.ms; + std::string err; + return engine_.set(std::string(key), v, ttl_ms, "default", &err); + } + + std::optional> get(std::string_view key) { + auto v = engine_.get(std::string(key)); + if (!v.has_value()) + return std::nullopt; + std::vector out(v->size()); + std::memcpy(out.data(), v->data(), v->size()); + return out; + } + + bool prompt_put(std::span tokens, + std::span artifact, + Ttl ttl) { + if (tokens.empty()) + return false; + std::vector serialized(artifact.size()); + std::memcpy(serialized.data(), artifact.data(), artifact.size()); + std::vector token_bytes(tokens.size() * sizeof(std::uint64_t)); + std::memcpy(token_bytes.data(), tokens.data(), token_bytes.size()); + const auto hash = pomai_cache::AiArtifactCache::fast_hash_hex(token_bytes); + std::optional ttl_ms; + if (ttl.ms > 0) + ttl_ms = ttl.ms; + std::string err; + return prompt_cache_.put_prefix("tok", hash, serialized, + static_cast(tokens.size()), + ttl_ms, &err); + } + + PromptResult prompt_get(std::span tokens) { + PromptResult r; + if (tokens.empty()) + return r; + std::vector token_bytes(tokens.size() * sizeof(std::uint64_t)); + std::memcpy(token_bytes.data(), tokens.data(), token_bytes.size()); + const auto hash = pomai_cache::AiArtifactCache::fast_hash_hex(token_bytes); + auto reuse = prompt_cache_.reuse_for_query("tok", hash, token_bytes); + r.hit = reuse.hit; + r.cached_tokens = reuse.cached_tokens; + r.suffix_tokens = reuse.suffix_tokens; + r.savings_ratio = reuse.savings_ratio; + return r; + } + +private: + pomai_cache::EngineConfig engine_cfg_; + std::unique_ptr policy_; + pomai_cache::Engine engine_; + pomai_cache::AiArtifactCache ai_cache_; + pomai_cache::PromptCacheConfig prompt_cfg_; + pomai_cache::PromptCacheManager prompt_cache_; +}; + +PomaiCache::PomaiCache(const Config &cfg) + : impl_(std::make_unique(cfg)) {} + +PomaiCache::~PomaiCache() = default; + +PomaiCache::PomaiCache(PomaiCache &&) noexcept = default; +PomaiCache &PomaiCache::operator=(PomaiCache &&) noexcept = default; + +bool PomaiCache::Set(std::string_view key, + std::span value, + Ttl ttl) { + return impl_->set(key, value, ttl); +} + +std::optional> PomaiCache::Get(std::string_view key) { + return impl_->get(key); +} + +bool PomaiCache::PromptPut(std::span tokens, + std::span artifact, + Ttl ttl) { + return impl_->prompt_put(tokens, artifact, ttl); +} + +PromptResult PomaiCache::PromptGet(std::span tokens) { + return impl_->prompt_get(tokens); +} + +} // namespace pomaicache + diff --git a/src/server/http.cpp b/src/server/http.cpp deleted file mode 100644 index b8119d0..0000000 --- a/src/server/http.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "pomai_cache/http.hpp" -#include -#include - -namespace pomai_cache { - -void HttpParser::feed(std::string_view data) { - if (view_.empty()) { - buffer_.assign(data); - view_ = buffer_; - } else { - size_t offset = view_.data() - buffer_.data(); - buffer_.append(data); - view_ = std::string_view(buffer_).substr(offset); - } -} - -bool HttpParser::parse_request_line() { - auto pos = view_.find("\r\n"); - if (pos == std::string_view::npos) return false; - - auto line = view_.substr(0, pos); - view_.remove_prefix(pos + 2); - - auto sp1 = line.find(' '); - if (sp1 == std::string_view::npos) { state_ = State::ERROR; return false; } - current_req_.method = std::string(line.substr(0, sp1)); - - auto sp2 = line.find(' ', sp1 + 1); - if (sp2 == std::string_view::npos) { state_ = State::ERROR; return false; } - - auto full_path = line.substr(sp1 + 1, sp2 - sp1 - 1); - auto q_pos = full_path.find('?'); - if (q_pos != std::string_view::npos) { - current_req_.path = std::string(full_path.substr(0, q_pos)); - auto query_str = full_path.substr(q_pos + 1); - - // Parse query params (simple parsing) - size_t start = 0; - while (start < query_str.size()) { - auto amp = query_str.find('&', start); - auto pair = query_str.substr(start, amp == std::string_view::npos ? std::string_view::npos : amp - start); - auto eq = pair.find('='); - if (eq != std::string_view::npos) { - current_req_.query_params[std::string(pair.substr(0, eq))] = std::string(pair.substr(eq + 1)); - } else { - current_req_.query_params[std::string(pair)] = ""; - } - if (amp == std::string_view::npos) break; - start = amp + 1; - } - } else { - current_req_.path = std::string(full_path); - } - - state_ = State::HEADERS; - return true; -} - -bool HttpParser::parse_headers() { - while (true) { - auto pos = view_.find("\r\n"); - if (pos == std::string_view::npos) return false; - - if (pos == 0) { - view_.remove_prefix(2); - auto it = current_req_.headers.find("Content-Length"); - if (it != current_req_.headers.end()) { - expected_body_len_ = std::stoi(it->second); - state_ = expected_body_len_ > 0 ? State::BODY : State::COMPLETE; - } else { - expected_body_len_ = 0; - state_ = State::COMPLETE; - } - return true; - } - - auto line = view_.substr(0, pos); - view_.remove_prefix(pos + 2); - - auto colon = line.find(':'); - if (colon != std::string_view::npos) { - auto key = std::string(line.substr(0, colon)); - auto val = line.substr(colon + 1); - while (!val.empty() && (val[0] == ' ' || val[0] == '\t')) val.remove_prefix(1); - current_req_.headers[key] = std::string(val); - } - } -} - -std::optional HttpParser::next_request() { - while (!view_.empty()) { - switch (state_) { - case State::REQUEST_LINE: - if (!parse_request_line()) return std::nullopt; - break; - case State::HEADERS: - if (!parse_headers()) return std::nullopt; - break; - case State::BODY: - if (view_.size() >= static_cast(expected_body_len_)) { - current_req_.body = std::string(view_.substr(0, expected_body_len_)); - view_.remove_prefix(expected_body_len_); - state_ = State::COMPLETE; - } else { - return std::nullopt; - } - break; - case State::COMPLETE: { - auto req = std::move(current_req_); - current_req_ = HttpRequest(); - state_ = State::REQUEST_LINE; - if (view_.empty()) buffer_.clear(); - return req; - } - case State::ERROR: - return std::nullopt; - } - } - if (state_ == State::COMPLETE) { - auto req = std::move(current_req_); - current_req_ = HttpRequest(); - state_ = State::REQUEST_LINE; - if (view_.empty()) buffer_.clear(); - return req; - } - return std::nullopt; -} - -std::string http_response(int status_code, const std::string& status_text, const std::string& body, const std::string& content_type) { - std::ostringstream oss; - oss << "HTTP/1.1 " << status_code << " " << status_text << "\r\n"; - oss << "Content-Length: " << body.size() << "\r\n"; - oss << "Content-Type: " << content_type << "\r\n"; - oss << "Connection: keep-alive\r\n"; - oss << "\r\n"; - oss << body; - return oss.str(); -} - -} // namespace pomai_cache diff --git a/src/server/server_main.cpp b/src/server/server_main.cpp deleted file mode 100644 index 12677a7..0000000 --- a/src/server/server_main.cpp +++ /dev/null @@ -1,464 +0,0 @@ -#include "pomai_cache/ai_cache.hpp" -#include "pomai_cache/engine.hpp" -#include "pomai_cache/engine_shard.hpp" -#include "pomai_cache/http.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace { -volatile std::sig_atomic_t running = 1; -void on_sigint(int) { running = 0; } - -std::string upper(std::string_view s) { - std::string res; - res.reserve(s.size()); - for (char c : s) res += static_cast(std::toupper(static_cast(c))); - return res; -} - -bool parse_u64(std::string_view s, std::uint64_t &out) { - auto [ptr, ec] = std::from_chars(s.data(), s.data() + s.size(), out); - return ec == std::errc{}; -} - -struct ClientState { - pomai_cache::HttpParser parser; - std::string out; - std::string sending; - int fd; - bool is_sending{false}; - char buf[4096]; -}; - -class UringWorker { -public: - UringWorker(int port, const pomai_cache::EngineConfig& cfg, int id) - : port_(port), cfg_(cfg), id_(id) {} - - void run() { - auto policy = pomai_cache::make_policy_by_name("pomai_cost"); - pomai_cache::EngineShard::InitThreadLocal(id_, cfg_, std::move(policy)); - auto* shard = pomai_cache::EngineShard::tlocal(); - pomai_cache::ShardSet::instance().add_shard(shard); - - int listen_fd = socket(AF_INET, SOCK_STREAM, 0); - int one = 1; - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); - - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = htons(port_); - if (bind(listen_fd, reinterpret_cast(&addr), sizeof(addr)) < 0) { - std::cerr << "Worker " << id_ << " bind failed\n"; - return; - } - listen(listen_fd, 128); - - struct io_uring ring; - io_uring_queue_init(1024, &ring, 0); - - add_accept_sqe(&ring, listen_fd); - io_uring_submit(&ring); - - std::unordered_map> clients; - - while (running) { - shard->engine().tick(); - - struct io_uring_cqe *cqe; - struct __kernel_timespec ts{0, 10000000}; // 10ms - int ret = io_uring_wait_cqe_timeout(&ring, &cqe, &ts); - - if (ret < 0) { - if (ret == -ETIME) { - io_uring_submit(&ring); - continue; - } - break; - } - - int head; - unsigned count = 0; - io_uring_for_each_cqe(&ring, head, cqe) { - count++; - auto* data = reinterpret_cast(cqe->user_data); - uint64_t type = reinterpret_cast(data) & 0x7; - int fd = static_cast(reinterpret_cast(data) >> 3); - - if (fd == listen_fd) { - int cfd = cqe->res; - if (cfd >= 0) { - auto client = std::make_unique(); - client->fd = cfd; - clients[cfd] = std::move(client); - add_recv_sqe(&ring, cfd, clients[cfd]->buf); - } - add_accept_sqe(&ring, listen_fd); - } else { - auto it = clients.find(fd); - if (it == clients.end()) continue; - auto& st = *it->second; - - if (type == 1) { // RECV - int r = cqe->res; - if (r <= 0) { - close(fd); - clients.erase(it); - } else { - st.parser.feed(std::string_view(st.buf, r)); - while (auto req = st.parser.next_request()) { - handle_http_request(st, *req); - } - if (!st.is_sending && !st.out.empty()) { - st.is_sending = true; - st.sending.swap(st.out); - add_send_sqe(&ring, fd, st.sending); - } - add_recv_sqe(&ring, fd, st.buf); - } - } else if (type == 2) { // SEND - st.is_sending = false; - if (cqe->res > 0) { - st.sending.erase(0, cqe->res); - } - if (!st.sending.empty()) { - st.is_sending = true; - add_send_sqe(&ring, fd, st.sending); - } else if (!st.out.empty()) { - st.is_sending = true; - st.sending.swap(st.out); - add_send_sqe(&ring, fd, st.sending); - } - } - } - } - io_uring_cq_advance(&ring, count); - io_uring_submit(&ring); - } - io_uring_queue_exit(&ring); - close(listen_fd); - } - -private: - void add_accept_sqe(struct io_uring *ring, int fd) { - struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_accept(sqe, fd, nullptr, nullptr, 0); - io_uring_sqe_set_data(sqe, reinterpret_cast(static_cast(fd) << 3)); - } - - void add_recv_sqe(struct io_uring *ring, int fd, char* buf) { - struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_recv(sqe, fd, buf, 4096, 0); - io_uring_sqe_set_data(sqe, reinterpret_cast((static_cast(fd) << 3) | 1)); - } - - void add_send_sqe(struct io_uring *ring, int fd, const std::string& data) { - struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_send(sqe, fd, data.data(), data.size(), 0); - io_uring_sqe_set_data(sqe, reinterpret_cast((static_cast(fd) << 3) | 2)); - } - -private: - std::vector split_path(const std::string& path) { - std::vector parts; - std::size_t start = 0; - while (start < path.size()) { - if (path[start] == '/') { - start++; - continue; - } - auto end = path.find('/', start); - if (end == std::string::npos) { - parts.push_back(path.substr(start)); - break; - } - parts.push_back(path.substr(start, end - start)); - start = end + 1; - } - return parts; - } - - void handle_http_request(ClientState& st, const pomai_cache::HttpRequest& req) { - auto parts = split_path(req.path); - if (parts.empty()) { - st.out += pomai_cache::http_response(404, "Not Found", "Path missing"); - return; - } - - std::string base = parts[0]; - - // INFO - if (base == "info" && req.method == "GET") { - auto shards = pomai_cache::ShardSet::instance().all_shards(); - std::string combined; - for (auto* s : shards) combined += s->engine().info(); - st.out += pomai_cache::http_response(200, "OK", combined); - return; - } - - // CONFIG - if (base == "config" && req.method == "GET") { - if (parts.size() >= 2 && parts[1] == "policy") { - auto shards = pomai_cache::ShardSet::instance().all_shards(); - std::string name = shards.empty() ? "unknown" : shards[0]->engine().policy().name(); - st.out += pomai_cache::http_response(200, "OK", name); - } else { - st.out += pomai_cache::http_response(400, "Bad Request", "CONFIG param not supported"); - } - return; - } - - // KEY Ops: /key/ - if (base == "key" && parts.size() >= 2) { - std::string key = parts[1]; - auto* shard = pomai_cache::ShardSet::instance().get_shard(key); - if (!shard) { - st.out += pomai_cache::http_response(503, "Service Unavailable", "No shards"); - return; - } - auto& engine = shard->engine(); - - if (req.method == "GET") { - auto v = engine.get(key); - if (v) { - st.out += pomai_cache::http_response(200, "OK", std::string(v->begin(), v->end())); - } else { - st.out += pomai_cache::http_response(404, "Not Found", "Key not found"); - } - } else if (req.method == "POST") { - std::vector val(req.body.begin(), req.body.end()); - std::optional ttl_ms; - - auto it = req.query_params.find("ex"); - if (it != req.query_params.end()) { - std::uint64_t v = 0; - if (parse_u64(it->second, v)) ttl_ms = v * 1000; - } - it = req.query_params.find("px"); - if (it != req.query_params.end()) { - std::uint64_t v = 0; - if (parse_u64(it->second, v)) ttl_ms = v; - } - - std::string set_err; - if (engine.set(key, val, ttl_ms, "default", &set_err)) { - std::vector jcmd = {"SET", key, req.body}; - if (ttl_ms) { jcmd.push_back("PX"); jcmd.push_back(std::to_string(*ttl_ms)); } - shard->journal().record(pomai_cache::OpCode::SET, jcmd); - st.out += pomai_cache::http_response(200, "OK", "OK"); - } else { - st.out += pomai_cache::http_response(400, "Bad Request", set_err); - } - } else if (req.method == "DELETE") { - int d = engine.del({key}); - st.out += pomai_cache::http_response(200, "OK", std::to_string(d)); - } else { - st.out += pomai_cache::http_response(405, "Method Not Allowed", "Use GET, POST or DELETE"); - } - return; - } - - // AI Operations - if (base == "ai" && parts.size() >= 2) { - std::string op = parts[1]; - - if (op == "stats" && req.method == "GET") { - auto shards = pomai_cache::ShardSet::instance().all_shards(); - std::string combined; - for (auto* s : shards) combined += s->ai_cache().stats(); - st.out += pomai_cache::http_response(200, "OK", combined); - return; - } - - if (op == "cost_report" && req.method == "GET") { - auto shards = pomai_cache::ShardSet::instance().all_shards(); - std::ostringstream os; - double total_saved = 0; - std::uint64_t total_tokens = 0, total_latency = 0, total_hits = 0; - for (auto* s : shards) { - auto r = s->ai_cache().cost_report(); - total_saved += r.total_dollar_saved; - total_tokens += r.total_tokens_saved; - total_latency += r.total_latency_saved_ms; - total_hits += r.total_hits; - } - os << "total_dollar_saved:" << total_saved << "\n"; - os << "total_tokens_saved:" << total_tokens << "\n"; - os << "total_latency_saved_ms:" << total_latency << "\n"; - os << "total_hits:" << total_hits << "\n"; - st.out += pomai_cache::http_response(200, "OK", os.str()); - return; - } - - if (op == "budget" && req.method == "POST") { - auto it = req.query_params.find("value"); - if (it != req.query_params.end()) { - double budget = std::stod(it->second); - auto shards = pomai_cache::ShardSet::instance().all_shards(); - for (auto* s : shards) s->ai_cache().set_budget(budget / static_cast(shards.size())); - st.out += pomai_cache::http_response(200, "OK", "OK"); - } else { - st.out += pomai_cache::http_response(400, "Bad Request", "Missing value"); - } - return; - } - - if (op == "invalidate" && req.method == "POST" && parts.size() >= 4) { - std::string subcmd = upper(parts[2]); - std::string arg = parts[3]; - std::size_t total = 0; - auto shards = pomai_cache::ShardSet::instance().all_shards(); - for (auto* s : shards) { - if (subcmd == "EPOCH") total += s->ai_cache().invalidate_epoch(arg); - else if (subcmd == "MODEL") total += s->ai_cache().invalidate_model(arg); - else if (subcmd == "PREFIX") total += s->ai_cache().invalidate_prefix(arg); - else if (subcmd == "CASCADE") total += s->ai_cache().invalidate_cascade(arg); - } - st.out += pomai_cache::http_response(200, "OK", std::to_string(total)); - return; - } - - if (op == "sim" && parts.size() >= 3) { - std::string subcmd = parts[2]; - - if (subcmd == "put" && req.method == "POST" && parts.size() >= 4) { - std::string key = parts[3]; - auto* shard = pomai_cache::ShardSet::instance().get_shard(key); - if (!shard) { st.out += pomai_cache::http_response(503, "Service Unavailable", "no shard"); return; } - - auto it = req.query_params.find("vec"); - if (it == req.query_params.end()) { st.out += pomai_cache::http_response(400, "Bad Request", "missing vec"); return; } - - std::vector vec; - std::istringstream vss(it->second); - float val; - while (vss >> val) { vec.push_back(val); if (vss.peek() == ',') vss.ignore(); } - - std::vector payload(req.body.begin(), req.body.end()); - - auto meta_it = req.query_params.find("meta"); - std::string meta_json = meta_it != req.query_params.end() ? meta_it->second : "{\"artifact_type\":\"embedding\",\"owner\":\"vector\",\"schema_version\":\"v1\"}"; - - std::string err; - if (shard->ai_cache().sim_put(key, vec, payload, meta_json, &err)) - st.out += pomai_cache::http_response(200, "OK", "OK"); - else - st.out += pomai_cache::http_response(400, "Bad Request", err); - return; - } - - if (subcmd == "get" && req.method == "GET") { - auto it = req.query_params.find("vec"); - if (it == req.query_params.end()) { st.out += pomai_cache::http_response(400, "Bad Request", "missing vec"); return; } - std::vector query; - std::istringstream vss(it->second); - float val; - while (vss >> val) { query.push_back(val); if (vss.peek() == ',') vss.ignore(); } - - std::size_t top_k = 1; - float threshold = 0.9f; - auto kt = req.query_params.find("topk"); - if (kt != req.query_params.end()) top_k = std::stoull(kt->second); - auto tt = req.query_params.find("threshold"); - if (tt != req.query_params.end()) threshold = std::stof(tt->second); - - auto shards = pomai_cache::ShardSet::instance().all_shards(); - std::ostringstream arr; - for (auto* s : shards) { - auto results = s->ai_cache().sim_get(query, top_k, threshold); - for (const auto& r : results) { - arr << "key:" << r.key << " score:" << r.score << "\n"; - arr << "meta:" << pomai_cache::AiArtifactCache::meta_to_json(r.value.meta) << "\n"; - arr << "body:" << std::string(r.value.payload.begin(), r.value.payload.end()) << "\n"; - } - } - st.out += pomai_cache::http_response(200, "OK", arr.str()); - return; - } - } - - if (op == "put" && req.method == "POST" && parts.size() >= 4) { - std::string type = parts[2]; - std::string key = parts[3]; - - auto* shard = pomai_cache::ShardSet::instance().get_shard(key); - if (!shard) { st.out += pomai_cache::http_response(503, "Service Unavailable", "no shard"); return; } - - std::vector payload(req.body.begin(), req.body.end()); - auto it = req.query_params.find("meta"); - std::string meta = it != req.query_params.end() ? it->second : "{}"; - - std::string err; - if (shard->ai_cache().put(type, key, meta, payload, &err)) { - std::vector jcmd = {"AI.PUT", type, key, meta, req.body}; - shard->journal().record(pomai_cache::OpCode::AI_PUT, jcmd); - st.out += pomai_cache::http_response(200, "OK", "OK"); - } else { - st.out += pomai_cache::http_response(400, "Bad Request", err); - } - return; - } - - if (op == "get" && req.method == "GET" && parts.size() >= 3) { - std::string key = parts[2]; - auto* shard = pomai_cache::ShardSet::instance().get_shard(key); - if (!shard) { st.out += pomai_cache::http_response(503, "Service Unavailable", "no shard"); return; } - - auto v = shard->ai_cache().get(key); - if (!v) { - st.out += pomai_cache::http_response(404, "Not Found", ""); - } else { - std::string resp_body = pomai_cache::AiArtifactCache::meta_to_json(v->meta) + "\n" + std::string(v->payload.begin(), v->payload.end()); - st.out += pomai_cache::http_response(200, "OK", resp_body); - } - return; - } - } - - st.out += pomai_cache::http_response(400, "Bad Request", "Unknown command"); - } - - int port_; - pomai_cache::EngineConfig cfg_; - int id_; -}; - -} // namespace - -int main(int argc, char **argv) { - int port = 6379; - std::size_t memory_limit = 128 * 1024 * 1024; - std::string data_dir = "./data"; - - for (int i = 1; i < argc; ++i) { - std::string a = argv[i]; - if (a == "--port" && i + 1 < argc) port = std::stoi(argv[++i]); - else if (a == "--memory" && i + 1 < argc) memory_limit = std::stoull(argv[++i]); - } - - pomai_cache::EngineConfig cfg; - cfg.memory_limit_bytes = memory_limit; // No division because it is single-threaded - cfg.data_dir = data_dir; - - std::cout << "Starting PomaiCache on single core...\n"; - - std::signal(SIGINT, on_sigint); - - // Single-threaded so just call run() on main thread - UringWorker(port, cfg, 0).run(); - - return 0; -} diff --git a/tests/test_http.cpp b/tests/test_http.cpp deleted file mode 100644 index 14e65c9..0000000 --- a/tests/test_http.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// tests/test_http.cpp -#include -#include "pomai_cache/http.hpp" - -using namespace pomai_cache; - -TEST_CASE("HttpParser: simple GET", "[http]") { - HttpParser parser; - parser.feed("GET /key/a HTTP/1.1\r\nHost: localhost\r\n\r\n"); - auto req = parser.next_request(); - REQUIRE(req.has_value()); - CHECK(req->method == "GET"); - CHECK(req->path == "/key/a"); - CHECK(req->headers["Host"] == "localhost"); - CHECK(req->body.empty()); -} - -TEST_CASE("HttpParser: simple POST with body", "[http]") { - HttpParser parser; - parser.feed("POST /key/a HTTP/1.1\r\nContent-Length: 5\r\n\r\nhello"); - auto req = parser.next_request(); - REQUIRE(req.has_value()); - CHECK(req->method == "POST"); - CHECK(req->path == "/key/a"); - CHECK(req->body == "hello"); -} - -TEST_CASE("HttpParser: query params", "[http]") { - HttpParser parser; - parser.feed("GET /ai/sim/get?vec=1,2,3&topk=5 HTTP/1.1\r\n\r\n"); - auto req = parser.next_request(); - REQUIRE(req.has_value()); - CHECK(req->path == "/ai/sim/get"); - CHECK(req->query_params["vec"] == "1,2,3"); - CHECK(req->query_params["topk"] == "5"); -} diff --git a/tests/test_integration.cpp b/tests/test_integration.cpp index 824dc47..b311df7 100644 --- a/tests/test_integration.cpp +++ b/tests/test_integration.cpp @@ -1,207 +1,100 @@ #include -#include -#include -#include -#include +#include "pomaicache.h" +#include "pomai_cache/ai_cache.hpp" +#include "pomai_cache/engine.hpp" + #include -#include #include -#include -#include -#include -#include -#include #include -#include -#include - -namespace { - -std::optional read_reply(int fd) { - std::string out; - char buf[4096]; - while (true) { - int r = recv(fd, buf, 4096, 0); - if (r <= 0) break; - out.append(buf, r); - if (out.find("\r\n\r\n") != std::string::npos) { - auto pos = out.find("Content-Length: "); - if (pos != std::string::npos) { - auto end = out.find("\r\n", pos); - int len = std::stoi(out.substr(pos + 16, end - pos - 16)); - auto header_end = out.find("\r\n\r\n") + 4; - if (out.size() >= header_end + len) { - return out; - } - } else { - return out; - } - } - } - return out.empty() ? std::nullopt : std::make_optional(out); -} -int connect_port(int port) { - int fd = socket(AF_INET, SOCK_STREAM, 0); - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); - if (connect(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) - return -1; - timeval tv{2, 0}; - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - return fd; +TEST_CASE("integration: embedded core operations", "[integration]") { + pomaicache::Config cfg; + cfg.memory_limit_bytes = 16 * 1024 * 1024; + cfg.data_dir = "./data_integration_core"; + + pomaicache::PomaiCache cache(cfg); + + const std::string key = "a"; + const std::string value = "1"; + std::span v{ + reinterpret_cast(value.data()), value.size()}; + + REQUIRE(cache.Set(key, v, pomaicache::Ttl{0})); + + auto got = cache.Get(key); + REQUIRE(got.has_value()); + std::string roundtrip(reinterpret_cast(got->data()), + got->size()); + CHECK(roundtrip == value); + + // Overwrite and read again to ensure basic churn works. + const std::string value2 = "2"; + std::span v2{ + reinterpret_cast(value2.data()), value2.size()}; + REQUIRE(cache.Set(key, v2, pomaicache::Ttl{0})); + auto got2 = cache.Get(key); + REQUIRE(got2.has_value()); + std::string roundtrip2(reinterpret_cast(got2->data()), + got2->size()); + CHECK(roundtrip2 == value2); } -struct ServerProc { - int port; - pid_t pid; -}; - -ServerProc spawn_server() { - static int attempt = 0; - int port = 22000 + ((::getpid() + attempt * 137) % 20000); - ++attempt; - pid_t pid = fork(); - if (pid == 0) { - execl("./pomai_cache_server", "./pomai_cache_server", "--port", - std::to_string(port).c_str(), "--params", - "../config/policy_params.json", nullptr); - _exit(1); +TEST_CASE("integration: embedded churn under load", "[integration][adversarial]") { + pomaicache::Config cfg; + cfg.memory_limit_bytes = 16 * 1024 * 1024; + cfg.data_dir = "./data_integration_churn"; + + pomaicache::PomaiCache cache(cfg); + + // Insert a bunch of small keys to exercise caps/churn behavior. + for (int i = 0; i < 1000; ++i) { + std::string key = "churn" + std::to_string(i); + std::string value = "val" + std::to_string(i); + std::span v{ + reinterpret_cast(value.data()), value.size()}; + REQUIRE(cache.Set(key, v, pomaicache::Ttl{0})); } - for (int i = 0; i < 50; ++i) { - int fd = connect_port(port); - if (fd >= 0) { - close(fd); - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + // Spot-check a few keys. + for (int i = 0; i < 10; ++i) { + std::string key = "churn" + std::to_string(i * 10); + auto got = cache.Get(key); + REQUIRE(got.has_value()); } - return {port, pid}; } -void stop_server(const ServerProc &s) { - kill(s.pid, SIGINT); - waitpid(s.pid, nullptr, 0); -} -} // namespace - -TEST_CASE("integration: HTTP core commands and clean shutdown", - "[integration]") { - auto s = spawn_server(); - int fd = connect_port(s.port); - REQUIRE(fd >= 0); - - auto req1 = "POST /key/a HTTP/1.1\r\nContent-Length: 1\r\n\r\n1"; - send(fd, req1, strlen(req1), 0); - REQUIRE(read_reply(fd).value().find("200 OK") != std::string::npos); - - auto req2 = "GET /key/a HTTP/1.1\r\n\r\n"; - send(fd, req2, strlen(req2), 0); - REQUIRE(read_reply(fd).value().find("1") != std::string::npos); - - auto req3 = "POST /key/a?ex=1 HTTP/1.1\r\nContent-Length: 1\r\n\r\n1"; - send(fd, req3, strlen(req3), 0); - REQUIRE(read_reply(fd).value().find("200 OK") != std::string::npos); - - auto req4 = "GET /info HTTP/1.1\r\n\r\n"; - send(fd, req4, strlen(req4), 0); - REQUIRE(read_reply(fd).value().find("200 OK") != std::string::npos); - - auto req5 = "GET /config/policy HTTP/1.1\r\n\r\n"; - send(fd, req5, strlen(req5), 0); - REQUIRE(read_reply(fd).value().find("200 OK") != std::string::npos); - - auto req6 = "DELETE /key/a HTTP/1.1\r\n\r\n"; - send(fd, req6, strlen(req6), 0); - REQUIRE(read_reply(fd).value().find("200 OK") != std::string::npos); - - const std::string bad_req = "NOPE /key/a HTTP/1.1\r\n\r\n"; - send(fd, bad_req.data(), bad_req.size(), 0); - auto bad = read_reply(fd); - REQUIRE(bad.has_value()); - CHECK(bad->find("405") != std::string::npos); - - close(fd); - stop_server(s); -} +TEST_CASE("integration: embedded AI artifact commands", "[integration][ai]") { + pomai_cache::EngineConfig cfg; + cfg.memory_limit_bytes = 16 * 1024 * 1024; + cfg.data_dir = "./data_integration_ai"; -TEST_CASE("integration: adversarial caps and churn", - "[integration][adversarial]") { - auto s = spawn_server(); - int fd = connect_port(s.port); - REQUIRE(fd >= 0); - - std::string big(1024 * 1024 + 8, 'x'); - std::string req = "POST /key/big HTTP/1.1\r\nContent-Length: " + std::to_string(big.size()) + "\r\n\r\n" + big; - send(fd, req.data(), req.size(), 0); - auto rep = read_reply(fd); - REQUIRE(rep.has_value()); - CHECK(rep->find("400") != std::string::npos); - - for (int i = 0; i < 500; ++i) { - std::string sreq = "POST /key/churn" + std::to_string(i) + " HTTP/1.1\r\nContent-Length: 3\r\n\r\nval"; - send(fd, sreq.data(), sreq.size(), 0); - REQUIRE(read_reply(fd).has_value()); - } + auto policy = pomai_cache::make_policy_by_name("pomai_cost"); + pomai_cache::Engine engine(cfg, std::move(policy)); + pomai_cache::AiArtifactCache ai(engine); - std::string ireq = "GET /info HTTP/1.1\r\n\r\n"; - send(fd, ireq.data(), ireq.size(), 0); - auto info = read_reply(fd); - REQUIRE(info.has_value()); - CHECK(info->find("evictions") != std::string::npos); + const std::string key = "emb:m:h:3:float"; + const std::string type = "embedding"; + const std::string payload_str = "abc"; + std::vector payload(payload_str.begin(), payload_str.end()); - for (int i = 0; i < 128; ++i) { - std::string t = "POST /key/ttl" + std::to_string(i) + "?px=1 HTTP/1.1\r\nContent-Length: 1\r\n\r\nv"; - send(fd, t.data(), t.size(), 0); - REQUIRE(read_reply(fd).has_value()); - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - send(fd, ireq.data(), ireq.size(), 0); - auto info2 = read_reply(fd); - REQUIRE(info2.has_value()); - CHECK(info2->find("expiration_backlog") != std::string::npos); - - close(fd); - stop_server(s); -} + std::string meta = + R"({"artifact_type":"embedding","owner":"vector","schema_version":"v1","model_id":"m","snapshot_epoch":"ep9"})"; + + std::string err; + REQUIRE(ai.put(type, key, meta, payload, &err)); + + auto got = ai.get(key); + REQUIRE(got.has_value()); + std::string body(got->payload.begin(), got->payload.end()); + CHECK(body == payload_str); + + auto stats = ai.stats(); + CHECK(stats.find("dedup_hits") != std::string::npos); + + auto removed = ai.invalidate_epoch("ep9"); + CHECK(removed >= 1); -TEST_CASE("integration: AI artifact commands", "[integration][ai]") { - auto s = spawn_server(); - int fd = connect_port(s.port); - REQUIRE(fd >= 0); - - std::string p1 = "POST /ai/put/embedding/emb:m:h:3:float?meta={\"artifact_type\":\"embedding\",\"owner\":\"vector\",\"schema_version\":\"v1\",\"model_id\":\"m\",\"snapshot_epoch\":\"ep9\"} HTTP/1.1\r\nContent-Length: 3\r\n\r\nabc"; - send(fd, p1.data(), p1.size(), 0); - auto put = read_reply(fd); - REQUIRE(put.has_value()); - REQUIRE(put->find("200 OK") != std::string::npos); - - std::string g1 = "GET /ai/get/emb:m:h:3:float HTTP/1.1\r\n\r\n"; - send(fd, g1.data(), g1.size(), 0); - auto get = read_reply(fd); - REQUIRE(get.has_value()); - CHECK(get->find("200 OK") != std::string::npos); - - std::string s1 = "GET /ai/stats HTTP/1.1\r\n\r\n"; - send(fd, s1.data(), s1.size(), 0); - auto stats = read_reply(fd); - REQUIRE(stats.has_value()); - CHECK(stats->find("dedup_hits") != std::string::npos); - - std::string i1 = "POST /ai/invalidate/EPOCH/ep9 HTTP/1.1\r\nContent-Length: 0\r\n\r\n"; - send(fd, i1.data(), i1.size(), 0); - auto inv = read_reply(fd); - REQUIRE(inv.has_value()); - CHECK(inv->find("1") != std::string::npos); - - send(fd, g1.data(), g1.size(), 0); - auto miss = read_reply(fd); - REQUIRE(miss.has_value()); - CHECK(miss->find("404") != std::string::npos); - - close(fd); - stop_server(s); + auto miss = ai.get(key); + CHECK_FALSE(miss.has_value()); } diff --git a/tests/test_prompt_cache.cpp b/tests/test_prompt_cache.cpp new file mode 100644 index 0000000..d950bfa --- /dev/null +++ b/tests/test_prompt_cache.cpp @@ -0,0 +1,53 @@ +#include "pomai_cache/ai_cache.hpp" +#include "pomai_cache/engine.hpp" +#include "pomai_cache/prompt_cache.hpp" + +#include + +#include +#include + +using namespace pomai_cache; + +TEST_CASE("PromptCacheManager prefix matching and reuse", "[prompt_cache]") { + Engine e({4 * 1024 * 1024, 256, 1024 * 1024}, + make_policy_by_name("pomai_cost")); + AiArtifactCache ai(e); + PromptCacheConfig cfg; + cfg.enabled = true; + cfg.default_ttl_ms = 60'000; + cfg.prefix_min_tokens = 2; + PromptCacheManager pcm(e, ai, cfg); + + std::vector prefix{'h', 'e', 'l', 'l', 'o'}; + REQUIRE(pcm.put_prefix("tok", "pfx1", prefix, 5)); + + std::vector full{'h', 'e', 'l', 'l', 'o', ' ', 'x'}; + auto reuse = pcm.reuse_for_query("tok", "full1", full); + CHECK(reuse.hit); + CHECK(reuse.cached_tokens == 5); + CHECK(reuse.suffix_tokens >= 1); + CHECK(reuse.savings_ratio > 0.0); +} + +TEST_CASE("PromptCacheManager TTL expiration", "[prompt_cache][ttl]") { + Engine e({4 * 1024 * 1024, 256, 1024 * 1024}, + make_policy_by_name("pomai_cost")); + AiArtifactCache ai(e); + PromptCacheConfig cfg; + cfg.enabled = true; + cfg.default_ttl_ms = 10; + cfg.prefix_min_tokens = 1; + PromptCacheManager pcm(e, ai, cfg); + + std::vector prefix{'a', 'b', 'c'}; + REQUIRE(pcm.put_prefix("tok", "short", prefix, 3)); + + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + pcm.tick(); + + std::vector full{'a', 'b', 'c', 'x'}; + auto reuse = pcm.reuse_for_query("tok", "full", full); + CHECK_FALSE(reuse.hit); +} +