Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 4 additions & 31 deletions src/1-ds/ds_fenwick.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "../0-common/common.hpp"

// what: maintain prefix sums with point updates and range sum queries.
// time: build O(n), update/query O(log n); memory: O(n)
// time: init O(n), update/query O(log n); memory: O(n)
// constraint: 1-indexed [1, n]; a[0] unused; kth needs all values >= 0.
// usage: fenwick fw; fw.build(a); fw.add(p, x); fw.sum(l, r); fw.kth(k);
// usage: fenwick fw; fw.init(n); fw.add(p, x); fw.set(p, v); fw.sum(l, r); fw.kth(k);
struct fenwick {
int n;
vector<ll> a, t;
Expand All @@ -13,16 +13,6 @@ struct fenwick {
a.assign(n + 1, 0);
t.assign(n + 1, 0);
}
void build(const vector<ll> &v) {
// goal: build fenwick in O(n) from initial array.
n = sz(v) - 1;
a = v;
t = a;
for (int i = 1; i <= n; i++) {
int j = i + (i & -i);
if (j <= n) t[j] += t[i];
}
}
void add(int p, ll val) {
// goal: a[p] += val.
a[p] += val;
Expand Down Expand Up @@ -79,9 +69,9 @@ struct fenw_range { // 1-indexed
};

// what: 2D point updates with axis-aligned rectangle sum queries.
// time: build O(n m), update/query O(log n log m); memory: O(n m)
// time: init O(n m), update/query O(log n log m); memory: O(n m)
// constraint: 1-indexed [1..n] x [1..m]; a[0][*], a[*][0] unused; no bounds check.
// usage: fenw_2d fw; fw.build(a); fw.add(x, y, v); fw.sum(x1, y1, x2, y2);
// usage: fenw_2d fw; fw.init(n, m); fw.add(x, y, v); fw.set(x, y, val); fw.sum(x1, y1, x2, y2);
struct fenw_2d { // 1-indexed
int n, m;
vector<vector<ll>> a, t;
Expand All @@ -91,23 +81,6 @@ struct fenw_2d { // 1-indexed
a.assign(n + 1, vector<ll>(m + 1, 0));
t.assign(n + 1, vector<ll>(m + 1, 0));
}
void build(const vector<vector<ll>> &v) {
// goal: build 2D fenwick in O(n*m).
n = sz(v) - 1;
m = n ? sz(v[1]) - 1 : 0;
a = v;
t = a;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++) {
int nj = j + (j & -j);
if (nj <= m) t[i][nj] += t[i][j];
}
for (int j = 1; j <= m; j++)
for (int i = 1; i <= n; i++) {
int ni = i + (i & -i);
if (ni <= n) t[ni][j] += t[i][j];
}
}
void add(int x, int y, ll val) {
// goal: a[x][y] += val.
a[x][y] += val;
Expand Down
93 changes: 30 additions & 63 deletions src/1-ds/ds_segtree.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
#include "../0-common/common.hpp"

// what: point update + range sum on a fixed-size array using a tree.
// time: build O(n), update/query O(log n); memory: O(n)
// constraint: 1-indexed [1, n]; a[0] unused.
// usage: seg_tree st; st.build(a); st.set(p, v); st.query(l, r);
// what: point update + range sum on a fixed-size array, plus kth by prefix sum.
// time: init O(n), update/query/kth O(log n); memory: O(n)
// constraint: 1-indexed [1, n]; a[0] unused; kth needs all values >= 0.
// usage: seg_tree st; st.init(n); st.set(p, v); st.query(l, r); st.kth(k);
struct seg_tree {
int flag;
vector<ll> t;
void build(const vector<ll> &a) {
// goal: build tree from 1-indexed array.
int n = sz(a) - 1;
void init(int n) {
// goal: allocate tree for size n (all zeros).
flag = 1;
while (flag < n) flag <<= 1;
t.assign(2 * flag, 0);
for (int i = 1; i <= n; i++) t[flag + i - 1] = a[i];
for (int i = flag - 1; i >= 1; i--) t[i] = t[i << 1] + t[i << 1 | 1];
}
void set(int p, ll val) {
// goal: set a[p] = val.
Expand All @@ -28,21 +25,29 @@ struct seg_tree {
int mid = (nl + nr) >> 1;
return query(l, r, v << 1, nl, mid) + query(l, r, v << 1 | 1, mid + 1, nr);
}
int kth(ll k) const {
// result: smallest idx with prefix sum >= k.
assert(k > 0 && t[1] >= k);
int v = 1;
while (v < flag) {
if (k <= t[v << 1]) v <<= 1;
else k -= t[v << 1], v = v << 1 | 1;
}
return v - flag + 1;
}
};

// what: iterative segment tree for point update and range sum.
// time: build O(n), update/query O(log n); memory: O(n)
// time: init O(n), update/query O(log n); memory: O(n)
// constraint: 1-indexed [1, n]; a[0] unused.
// usage: seg_tree_it st; st.build(a); st.set(p, v); st.query(l, r);
// usage: seg_tree_it st; st.init(n); st.set(p, v); st.query(l, r);
struct seg_tree_it { // 1-indexed
int n;
vector<ll> t;
void build(const vector<ll> &a) {
// goal: build tree from 1-indexed array.
n = sz(a) - 1;
void init(int n_) {
// goal: allocate tree for size n (all zeros).
n = n_;
t.assign(2 * n, 0);
for (int i = 1; i <= n; i++) t[n + i - 1] = a[i];
for (int i = n - 1; i >= 1; i--) t[i] = t[i << 1] + t[i << 1 | 1];
}
void set(int p, ll val) {
// goal: set a[p] = val.
Expand All @@ -59,48 +64,19 @@ struct seg_tree_it { // 1-indexed
}
};

// what: find k-th element by prefix sum on a frequency array.
// time: update/query O(log n); memory: O(n)
// constraint: 1-indexed [1, n], values >= 0.
// usage: seg_tree_kth st; st.init(n); st.add(p, v); st.kth(k);
struct seg_tree_kth {
int flag;
vector<ll> t;
void init(int n) {
// goal: allocate tree for size n.
flag = 1;
while (flag < n) flag <<= 1;
t.assign(flag << 1, 0);
}
void add(int p, ll val) {
// goal: add val to frequency at p.
for (t[p += flag - 1] += val; p > 1; p >>= 1) t[p >> 1] = t[p] + t[p ^ 1];
}
ll kth(ll k, int v = 1) const {
// result: smallest index with prefix sum >= k.
assert(t[v] >= k);
if (v >= flag) return v - flag + 1;
if (k <= t[v << 1]) return kth(k, v << 1);
return kth(k - t[v << 1], v << 1 | 1);
}
};

// what: range add and range sum with lazy propagation.
// time: update/query O(log n); memory: O(n)
// time: init O(n), update/query O(log n); memory: O(n)
// constraint: 1-indexed [1, n]; a[0] unused.
// usage: seg_tree_lz st; st.build(a); st.add(l, r, v); st.query(l, r);
// usage: seg_tree_lz st; st.init(n); st.add(l, r, v); st.query(l, r);
struct seg_tree_lz {
int flag;
vector<ll> t, lz;
void build(const vector<ll> &a) {
// goal: build tree and clear lazy tags.
int n = sz(a) - 1;
void init(int n) {
// goal: allocate tree and clear lazy tags (all zeros).
flag = 1;
while (flag < n) flag <<= 1;
t.assign(2 * flag, 0);
lz.assign(2 * flag, 0);
for (int i = 1; i <= n; i++) t[flag + i - 1] = a[i];
for (int i = flag - 1; i >= 1; i--) t[i] = t[i << 1] + t[i << 1 | 1];
}
void add(int l, int r, ll val) { add(l, r, val, 1, 1, flag); }
ll query(int l, int r) { return query(l, r, 1, 1, flag); }
Expand Down Expand Up @@ -182,25 +158,16 @@ struct seg_sparse {
};

// what: 2D point updates with rectangle sum queries on a square grid.
// time: build O(n^2), update/query O(log^2 n); memory: O(n^2)
// time: init O(n^2), update/query O(log^2 n); memory: O(n^2)
// constraint: 1-indexed square [1..n] x [1..n]; a[0][*], a[*][0] unused.
// usage: seg_2d st; st.build(a); st.set(x, y, v); st.query(x1, x2, y1, y2);
// usage: seg_2d st; st.init(n); st.set(x, y, v); st.query(x1, x2, y1, y2);
struct seg_2d { // 1-indexed
int n;
vector<vector<ll>> t;
void build(const vector<vector<ll>> &a) {
// goal: build 2D tree from initial grid.
n = sz(a) - 1;
void init(int n_) {
// goal: allocate 2D tree (all zeros).
n = n_;
t.assign(2 * n, vector<ll>(2 * n, 0));
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
t[i + n - 1][j + n - 1] = a[i][j];
for (int i = n; i < 2 * n; i++)
for (int j = n - 1; j > 0; j--)
t[i][j] = t[i][j << 1] + t[i][j << 1 | 1];
for (int i = n - 1; i > 0; i--)
for (int j = 1; j < 2 * n; j++)
t[i][j] = t[i << 1][j] + t[i << 1 | 1][j];
}
void set(int x, int y, ll val) {
// goal: set a[x][y] = val.
Expand Down
78 changes: 25 additions & 53 deletions src/1-ds/ds_segtree_pst.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "../0-common/common.hpp"

// what: persistent segment tree for point set with range sum queries, plus kth by prefix sum.
// time: build O(n), update/query O(log n); memory: O(n log n)
// constraint: 1-indexed [1, n]; a[0] unused; kth needs all values >= 0.
// usage: seg_pst st; st.build(n, a); st.set(p, v); st.query(l, r, ver); st.kth(k, ver);
// time: init O(1), update/query/kth O(log n); memory: O(q log n)
// constraint: version 0 is all zeros; 1-indexed [1, n]; kth needs all values >= 0.
// usage: seg_pst st; st.init(n); st.set(p, v); st.query(l, r, ver); st.kth(k, ver);
struct seg_pst {
struct node {
int l, r;
Expand All @@ -13,64 +13,36 @@ struct seg_pst {
vector<node> t;
vector<int> root;

void newnd() { t.push_back({-1, -1, 0}); }
void build(int n_, const vector<ll> &a) {
// goal: build initial version.
int newnd(const node &nd) {
t.push_back(nd);
return sz(t) - 1;
}
void init(int n_) {
// goal: version 0 = all zeros (root = 0).
n = n_;
t.clear();
root.clear();
newnd();
root.push_back(0);
build(1, n, root[0], a);
t.assign(1, {0, 0, 0});
root.assign(1, 0);
}
void build(int l, int r, int v, const vector<ll> &a) {
// goal: build node v for range [l, r].
if (l == r) {
t[v].val = a[l];
return;
int set(int p, ll val, int nl, int nr, int v) {
// goal: update along path while sharing unchanged nodes.
int u = newnd(t[v]);
if (nl == nr) {
t[u].val = val;
return u;
}
newnd();
t[v].l = sz(t) - 1;
newnd();
t[v].r = sz(t) - 1;
int mid = (l + r) >> 1;
build(l, mid, t[v].l, a);
build(mid + 1, r, t[v].r, a);
t[v].val = t[t[v].l].val + t[t[v].r].val;
int mid = (nl + nr) >> 1;
if (p <= mid) t[u].l = set(p, val, nl, mid, t[v].l);
else t[u].r = set(p, val, mid + 1, nr, t[v].r);
t[u].val = t[t[u].l].val + t[t[u].r].val;
return u;
}
void set(int p, ll val) {
// goal: create new version with a[p] = val.
newnd();
root.push_back(sz(t) - 1);
set(p, val, 1, n, root[sz(root) - 2], root.back());
}
void set(int p, ll val, int l, int r, int v1, int v2) {
// goal: update along path while sharing unchanged nodes.
if (p < l || r < p) {
t[v2] = t[v1];
return;
}
if (l == r) {
t[v2].val = val;
return;
}
int mid = (l + r) >> 1;
if (p <= mid) {
t[v2].r = t[v1].r;
newnd();
t[v2].l = sz(t) - 1;
set(p, val, l, mid, t[v1].l, t[v2].l);
} else {
t[v2].l = t[v1].l;
newnd();
t[v2].r = sz(t) - 1;
set(p, val, mid + 1, r, t[v1].r, t[v2].r);
}
t[v2].val = t[t[v2].l].val + t[t[v2].r].val;
// goal: create new version from last with a[p] = val.
root.push_back(set(p, val, 1, n, root.back()));
}
ll query(int l, int r, int v, int nl, int nr) const {
// result: sum on [l, r] in a specific version.
if (r < nl || nr < l) return 0;
if (v == 0 || r < nl || nr < l) return 0;
if (l <= nl && nr <= r) return t[v].val;
int mid = (nl + nr) >> 1;
return query(l, r, t[v].l, nl, mid) + query(l, r, t[v].r, mid + 1, nr);
Expand Down
3 changes: 1 addition & 2 deletions src/3-tree/tree_hld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ struct hld_tree {
void build(int root = 1) {
dfs_sz(root, 0);
dfs_hld(root, root);
vector<ll> a(n + 1, 0);
seg.build(a);
seg.init(n);
}
void set(int v, ll val) { seg.set(in[v], val); }
ll query(int a, int b) const {
Expand Down
15 changes: 9 additions & 6 deletions tests/1-ds/test_ds_fenwick.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ ll sum_2d(const vector<vector<ll>> &a, int x1, int y1, int x2, int y2) {

void test_fenwick_basic() {
fenwick fw;
vector<ll> a = {0, 5};
fw.build(a);
fw.init(1);
fw.set(1, 5);
assert(fw.sum(1, 1) == 5);
assert(fw.kth(1) == 1);
fw.set(1, 0);
Expand All @@ -51,7 +51,8 @@ void test_fenwick_random() {
for (int i = 1; i <= n; i++) a[i] = rnd(0, 5);

fenwick fw;
fw.build(a);
fw.init(n);
for (int i = 1; i <= n; i++) fw.add(i, a[i]);

for (int it = 0; it < 5000; it++) {
int op = (int)rnd(0, 3);
Expand Down Expand Up @@ -109,8 +110,8 @@ void test_fenwick_rp_random() {

void test_fenwick_2d_basic() {
fenw_2d fw;
vector<vector<ll>> a = {{0, 0}, {0, 3}};
fw.build(a);
fw.init(1, 1);
fw.set(1, 1, 3);
assert(fw.sum(1, 1, 1, 1) == 3);
fw.set(1, 1, -2);
assert(fw.sum(1, 1, 1, 1) == -2);
Expand All @@ -123,7 +124,9 @@ void test_fenwick_2d_random() {
for (int j = 1; j <= m; j++) a[i][j] = rnd(-3, 3);

fenw_2d fw;
fw.build(a);
fw.init(n, m);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++) fw.add(i, j, a[i][j]);

for (int it = 0; it < 4000; it++) {
int op = (int)rnd(0, 2);
Expand Down
Loading