Skip to content

Commit c5fa764

Browse files
author
lixizan
committed
simplify segment_tree
1 parent 1496ea2 commit c5fa764

File tree

5 files changed

+87
-130
lines changed

5 files changed

+87
-130
lines changed

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,15 @@ ability to perform interval queries and interval updates in `O(logn)` time.
356356
package main
357357

358358
import (
359-
tree "github.com/lxzan/dao/segment_tree"
359+
"fmt"
360+
st "github.com/lxzan/dao/segment_tree"
360361
)
361362

362363
func main() {
363-
var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
364-
var lines = tree.New[tree.Int64Schema, tree.Int64](data)
365-
var result = lines.Query(0, 10)
366-
println(result.MinValue, result.MaxValue, result.Sum)
364+
var a = []int{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
365+
var t = st.New(a, st.NewIntSummary[int], st.MergeIntSummary[int])
366+
var r = t.Query(3, 6)
367+
fmt.Printf("%v\n", r)
367368
}
368369

369370
```

README_CN.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,16 @@ func main() {
346346
package main
347347

348348
import (
349-
tree "github.com/lxzan/dao/segment_tree"
349+
"fmt"
350+
st "github.com/lxzan/dao/segment_tree"
350351
)
351352

352353
func main() {
353-
var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
354-
var lines = tree.New[tree.Int64Schema, tree.Int64](data)
355-
var result = lines.Query(0, 10)
356-
println(result.MinValue, result.MaxValue, result.Sum)
354+
var a = []int{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
355+
var t = st.New(a, st.NewIntSummary[int], st.MergeIntSummary[int])
356+
var r = t.Query(3, 6)
357+
fmt.Printf("%v\n", r)
357358
}
358-
359359
```
360360

361361
### 基准测试

segment_tree/impl.go

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,33 @@ package segment_tree
22

33
import (
44
"github.com/lxzan/dao/algo"
5+
"github.com/lxzan/dao/types/cmp"
56
)
67

7-
type Int64 int64
8+
type (
9+
NewSummary[T any, S any] func(T, Operate) S
810

9-
// Init 初始化摘要结构
10-
func (c Int64) Init(op Operate) Int64Schema {
11-
var val = int64(c)
12-
var result = Int64Schema{
13-
MaxValue: val,
14-
MinValue: val,
15-
Sum: val,
16-
}
17-
if op == OperateQuery {
18-
result.Sum = 0
19-
}
20-
return result
21-
}
11+
MergeSummary[S any] func(a, b S) S
12+
)
2213

23-
func (c Int64) Value() int64 {
24-
return int64(c)
14+
type IntSummary[T cmp.Integer] struct {
15+
MaxValue T
16+
MinValue T
17+
Sum T
2518
}
2619

27-
type Int64Schema struct {
28-
MaxValue int64
29-
MinValue int64
30-
Sum int64
20+
func NewIntSummary[T cmp.Integer](num T, op Operate) IntSummary[T] {
21+
var r = IntSummary[T]{MaxValue: num, MinValue: num, Sum: num}
22+
if op == OperateQuery {
23+
r.Sum = 0
24+
}
25+
return r
3126
}
3227

33-
// Merge 合并摘要信息
34-
func (c Int64Schema) Merge(d Int64Schema) Int64Schema {
35-
return Int64Schema{
36-
MaxValue: algo.Max(c.MaxValue, d.MaxValue),
37-
MinValue: algo.Min(c.MinValue, d.MinValue),
38-
Sum: c.Sum + d.Sum,
28+
func MergeIntSummary[T cmp.Integer](a, b IntSummary[T]) IntSummary[T] {
29+
return IntSummary[T]{
30+
MaxValue: algo.Max(a.MaxValue, b.MaxValue),
31+
MinValue: algo.Min(a.MinValue, b.MinValue),
32+
Sum: a.Sum + b.Sum,
3933
}
4034
}

segment_tree/segement_tree_test.go

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,41 @@ package segment_tree
33
import (
44
"github.com/lxzan/dao/algo"
55
"github.com/lxzan/dao/internal/utils"
6+
"github.com/stretchr/testify/assert"
67
"testing"
78
)
89

910
func TestSegmentTree_Query(t *testing.T) {
1011
var n = 10000
11-
var arr = make([]Int64, 0)
12+
var arr = make([]int, 0)
1213
for i := 0; i < n; i++ {
13-
arr = append(arr, Int64(utils.Rand.Intn(n)))
14+
arr = append(arr, utils.Rand.Intn(n))
1415
}
15-
16-
var tree = New[Int64Schema, Int64](arr)
17-
18-
for i := 0; i < 1000; i++ {
19-
var left = utils.Rand.Intn(n)
20-
var right = utils.Rand.Intn(n)
21-
if left > right {
22-
left, right = right, left
23-
}
24-
var result1 = tree.Query(left, right+1)
25-
26-
var result2 = Int64Schema{
27-
MaxValue: arr[left].Value(),
28-
MinValue: arr[left].Value(),
29-
Sum: 0,
30-
}
31-
for j := left; j <= right; j++ {
32-
result2.Sum += arr[j].Value()
33-
result2.MaxValue = algo.Max(result2.MaxValue, arr[j].Value())
34-
result2.MinValue = algo.Min(result2.MinValue, arr[j].Value())
35-
}
36-
37-
if result1.Sum != result2.Sum || result1.MinValue != result2.MinValue || result1.MaxValue != result2.MaxValue {
38-
t.Fatal("error!")
39-
}
40-
}
41-
42-
for i := 0; i < 1000; i++ {
43-
var index = utils.Rand.Intn(n)
44-
var value = Int64(utils.Rand.Intn(n))
45-
tree.Update(index, value)
46-
}
47-
48-
for i := 0; i < 1000; i++ {
49-
var left = utils.Rand.Intn(n)
50-
var right = utils.Rand.Intn(n)
51-
if left > right {
52-
left, right = right, left
53-
}
54-
var result1 = tree.Query(left, right+1)
55-
56-
var result2 = Int64Schema{
57-
MaxValue: arr[left].Value(),
58-
MinValue: arr[left].Value(),
59-
Sum: 0,
60-
}
61-
for j := left; j <= right; j++ {
62-
result2.Sum += arr[j].Value()
63-
result2.MaxValue = algo.Max(result2.MaxValue, arr[j].Value())
64-
result2.MinValue = algo.Min(result2.MinValue, arr[j].Value())
65-
}
66-
67-
if result1.Sum != result2.Sum || result1.MinValue != result2.MinValue || result1.MaxValue != result2.MaxValue {
68-
t.Fatal("error!")
16+
var stree = New(arr, NewIntSummary[int], MergeIntSummary[int])
17+
for i := 0; i < 100; i++ {
18+
var x, y = utils.Alphabet.Intn(n), utils.Alphabet.Intn(n)
19+
if x == y {
20+
continue
21+
}
22+
if x > y {
23+
x, y = y, x
24+
}
25+
26+
var flag = utils.Alphabet.Intn(4)
27+
switch flag {
28+
case 0:
29+
stree.Update(x, y)
30+
default:
31+
r0 := stree.Query(x, y)
32+
r1 := NewIntSummary(arr[x], OperateQuery)
33+
for j := x; j < y; j++ {
34+
r1.MaxValue = algo.Max(r1.MaxValue, arr[j])
35+
r1.MinValue = algo.Min(r1.MinValue, arr[j])
36+
r1.Sum += arr[j]
37+
}
38+
assert.Equal(t, r0.MaxValue, r1.MaxValue)
39+
assert.Equal(t, r0.MinValue, r1.MinValue)
40+
assert.Equal(t, r0.Sum, r1.Sum)
6941
}
7042
}
7143
}

segment_tree/segment_tree.go

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,95 +8,85 @@ const (
88
OperateUpdate Operate = 2
99
)
1010

11-
type (
12-
Initer[T any] interface {
13-
Init(op Operate) T
14-
}
15-
16-
Merger[T any] interface {
17-
Merge(T) T
18-
}
19-
)
20-
21-
type Element[S Merger[S], T Initer[S]] struct {
11+
type element[T any, S any] struct {
2212
left int
2313
right int
24-
son *Element[S, T]
25-
daughter *Element[S, T]
14+
son *element[T, S]
15+
daughter *element[T, S]
2616
data S
2717
}
2818

29-
type SegmentTree[S Merger[S], T Initer[S]] struct {
30-
root *Element[S, T]
31-
arr []T
19+
type SegmentTree[T any, S any] struct {
20+
root *element[T, S]
21+
arr []T
22+
newSummary NewSummary[T, S]
23+
mergeSummary MergeSummary[S]
3224
}
3325

34-
func New[S Merger[S], T Initer[S]](arr []T) *SegmentTree[S, T] {
35-
var obj = &SegmentTree[S, T]{
36-
root: &Element[S, T]{
26+
func New[T any, S any](arr []T, newSummary NewSummary[T, S], mergeSummary MergeSummary[S]) *SegmentTree[T, S] {
27+
var obj = &SegmentTree[T, S]{
28+
root: &element[T, S]{
3729
left: 0,
3830
right: len(arr) - 1,
3931
},
40-
arr: arr,
32+
arr: arr,
33+
newSummary: newSummary,
34+
mergeSummary: mergeSummary,
4135
}
4236
obj.build(obj.root)
4337
return obj
4438
}
4539

46-
func (c *SegmentTree[S, T]) build(cur *Element[S, T]) {
40+
func (c *SegmentTree[T, S]) build(cur *element[T, S]) {
4741
if cur.left == cur.right {
48-
cur.data = c.arr[cur.left].Init(OperateCreate)
42+
cur.data = c.newSummary(c.arr[cur.left], OperateCreate)
4943
return
5044
}
51-
5245
var mid = (cur.left + cur.right) / 2
53-
cur.son = &Element[S, T]{
46+
cur.son = &element[T, S]{
5447
left: cur.left,
5548
right: mid,
5649
}
57-
cur.daughter = &Element[S, T]{
50+
cur.daughter = &element[T, S]{
5851
left: mid + 1,
5952
right: cur.right,
6053
}
6154
c.build(cur.son)
6255
c.build(cur.daughter)
63-
cur.data = cur.son.data.Merge(cur.daughter.data)
56+
cur.data = c.mergeSummary(cur.son.data, cur.daughter.data)
6457
}
6558

6659
// Query 查询 begin <= index < end 区间
67-
func (c *SegmentTree[S, T]) Query(begin int, end int) S {
68-
var result S
69-
result = c.arr[begin].Init(OperateQuery)
60+
func (c *SegmentTree[T, S]) Query(begin int, end int) S {
61+
result := c.newSummary(c.arr[begin], OperateQuery)
7062
c.doQuery(c.root, begin, end-1, &result)
7163
return result
7264
}
7365

74-
func (c *SegmentTree[S, T]) doQuery(cur *Element[S, T], left int, right int, result *S) {
66+
func (c *SegmentTree[T, S]) doQuery(cur *element[T, S], left int, right int, result *S) {
7567
if cur.left >= left && cur.right <= right {
76-
*result = cur.data.Merge(*result)
68+
*result = c.mergeSummary(*result, cur.data)
7769
} else if !(cur.left > right || cur.right < left) {
7870
c.doQuery(cur.son, left, right, result)
7971
c.doQuery(cur.daughter, left, right, result)
8072
}
8173
}
8274

8375
// Update 更新
84-
func (c *SegmentTree[S, T]) Update(i int, v T) {
76+
func (c *SegmentTree[T, S]) Update(i int, v T) {
8577
c.arr[i] = v
8678
c.rebuild(c.root, i)
8779
}
8880

89-
func (c *SegmentTree[S, T]) rebuild(cur *Element[S, T], i int) {
81+
func (c *SegmentTree[T, S]) rebuild(cur *element[T, S], i int) {
9082
if !(i >= cur.left && i <= cur.right) {
9183
return
9284
}
93-
9485
if cur.left == cur.right && cur.left == i {
95-
cur.data = c.arr[cur.left].Init(OperateUpdate)
86+
cur.data = c.newSummary(c.arr[cur.left], OperateUpdate)
9687
return
9788
}
98-
9989
c.rebuild(cur.son, i)
10090
c.rebuild(cur.daughter, i)
101-
cur.data = cur.son.data.Merge(cur.daughter.data)
91+
cur.data = c.mergeSummary(cur.son.data, cur.daughter.data)
10292
}

0 commit comments

Comments
 (0)