From 19d7843282f11375dbf81bbfa82f483cd33ce5a9 Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Sat, 3 Jan 2026 20:36:30 +0900 Subject: [PATCH] feat: add kth feature in pst --- src/1-ds/segment_tree.cpp | 16 ++++++++-- src/8-misc/facts.txt | 0 tests/1-ds/test_segment_tree.cpp | 50 ++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 src/8-misc/facts.txt diff --git a/src/1-ds/segment_tree.cpp b/src/1-ds/segment_tree.cpp index c122ca6..9ed43da 100644 --- a/src/1-ds/segment_tree.cpp +++ b/src/1-ds/segment_tree.cpp @@ -138,10 +138,10 @@ struct seg_tree_lz { } }; -// what: keep all versions of point updates with range sum queries. +// what: keep all versions of point updates with range sum queries, plus kth by prefix sum. // time: build O(n), update/query O(log n); memory: O(n log n) -// constraint: 1-indexed [1, n]; a[0] unused. -// usage: seg_pst st; st.build(n, a); st.set(p, v); st.query(l, r, ver); +// constraint: 1-indexed [1, n]; a[0] unused; kth needs all values >= 0. +// usage: seg_pst st; st.build(n, a); st.set(p, v); st.query(l, r, ver); st.kth(k, ver); struct seg_pst { struct node { int l, r; @@ -214,6 +214,16 @@ struct seg_pst { return query(l, r, t[v].l, nl, mid) + query(l, r, t[v].r, mid + 1, nr); } ll query(int l, int r, int ver) const { return query(l, r, root[ver], 1, n); } + int kth(ll k, int v, int nl, int nr) const { + // result: smallest idx with prefix sum >= k (in this subtree). + assert(k > 0 && t[v].val >= k); + if (nl == nr) return nl; + int mid = (nl + nr) >> 1; + ll lv = t[t[v].l].val; + if (k <= lv) return kth(k, t[v].l, nl, mid); + return kth(k - lv, t[v].r, mid + 1, nr); + } + int kth(ll k, int ver) const { return kth(k, root[ver], 1, n); } }; // what: sparse segment tree for large coordinate range (point add, range sum). diff --git a/src/8-misc/facts.txt b/src/8-misc/facts.txt new file mode 100644 index 0000000..e69de29 diff --git a/tests/1-ds/test_segment_tree.cpp b/tests/1-ds/test_segment_tree.cpp index 400b45c..dfbd828 100644 --- a/tests/1-ds/test_segment_tree.cpp +++ b/tests/1-ds/test_segment_tree.cpp @@ -196,6 +196,54 @@ void test_pst_random() { } } +void test_pst_kth_basic() { + int n = 5; + vector a = {0, 2, 0, 1, 3, 0}; + seg_pst st; + st.build(n, a); + assert(st.kth(1, 0) == 1); + assert(st.kth(2, 0) == 1); + assert(st.kth(3, 0) == 3); + assert(st.kth(6, 0) == 4); + st.set(2, 4); + assert(st.kth(3, 1) == 2); + assert(st.kth(10, 1) == 4); + st.set(4, 0); + assert(st.kth(7, 2) == 3); +} + +void test_pst_kth_random() { + int n = 40; + vector> ver; + vector a(n + 1, 0); + for (int i = 1; i <= n; i++) a[i] = rnd(0, 3); + ver.push_back(a); + + seg_pst st; + st.build(n, a); + + for (int it = 0; it < 2000; it++) { + int op = (int)rnd(0, 1); + if (op == 0) { + int p = (int)rnd(1, n); + ll v = rnd(0, 5); + vector nw = ver.back(); + nw[p] = v; + ver.push_back(nw); + st.set(p, v); + } else { + int id = (int)rnd(0, (int)ver.size() - 1); + ll tot = 0; + for (int i = 1; i <= n; i++) tot += ver[id][i]; + if (tot == 0) continue; + ll k = rnd(1, tot); + assert(st.kth(k, id) == kth_naive_freq(ver[id], k)); + assert(st.kth(1, id) == kth_naive_freq(ver[id], 1)); + assert(st.kth(tot, id) == kth_naive_freq(ver[id], tot)); + } + } +} + void test_dyseg_basic() { seg_sparse st; st.add(MAXL, 5); @@ -368,6 +416,8 @@ int main() { test_seglz_random(); test_pst_basic(); test_pst_random(); + test_pst_kth_basic(); + test_pst_kth_random(); test_dyseg_basic(); test_dyseg_random(); test_seg2d_basic();