forked from go-gorm/dbresolver
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpolicy_test.go
More file actions
519 lines (427 loc) · 14.5 KB
/
policy_test.go
File metadata and controls
519 lines (427 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
package dbresolver
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"gorm.io/gorm"
)
// TestPolicy_RoundRobinPolicy drives both round-robin policies across a
// three-entry pool slice and expects element i%3 on iteration i.
//
// NOTE(review): p1, p2 and p3 are nil interface values, so every element of
// pools compares equal to every other element; the assertions below can never
// fire no matter which pool Resolve returns — the test is vacuous as written.
// Consider using distinct non-nil instances (e.g. &struct{ gorm.ConnPool }{}
// as the HealthTracker tests do) and hoisting a single policy value out of
// the loop — TODO confirm the policies' starting index and whether their
// rotation state is per-instance before making that change.
func TestPolicy_RoundRobinPolicy(t *testing.T) {
	var p1, p2, p3 gorm.ConnPool
	var pools = []gorm.ConnPool{
		p1, p2, p3,
	}
	for i := 0; i < 10; i++ {
		// A fresh policy value is constructed on every iteration.
		if pools[i%3] != RoundRobinPolicy().Resolve(pools) {
			t.Errorf("RoundRobinPolicy failed")
		}
		if pools[i%3] != StrictRoundRobinPolicy().Resolve(pools) {
			t.Errorf("StrictRoundRobinPolicy failed")
		}
	}
}
// BenchmarkPolicy_StrictRoundRobinPolicy measures StrictRoundRobinPolicy's
// Resolve under parallel load, with a shared atomic counter mirroring the
// expected rotation.
//
// NOTE(review): the pools are nil interface values, so the inequality check
// cannot actually distinguish pools; the benchmark mainly exercises the
// policy's internal synchronization — TODO confirm whether distinct pool
// instances were intended here.
func BenchmarkPolicy_StrictRoundRobinPolicy(b *testing.B) {
	var p1, p2, p3 gorm.ConnPool
	pools := []gorm.ConnPool{
		p1, p2, p3,
	}
	var i int64
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			if pools[int(atomic.AddInt64(&i, 1))%3] != StrictRoundRobinPolicy().Resolve(pools) {
				// Fixed: message previously blamed RoundRobinPolicy.
				b.Errorf("StrictRoundRobinPolicy failed")
			}
		}
	})
}
// TestHealthTracker_SingleSuccess covers the backward-compatible default:
// a tracker built with NewHealthTracker needs exactly one success to move a
// pool from the half-open (probing) state back to healthy.
func TestHealthTracker_SingleSuccess(t *testing.T) {
	tracker := NewHealthTracker(100 * time.Millisecond)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	tracker.MarkBad(cp)
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad immediately after MarkBad")
	}

	// Let the cooldown elapse so the pool enters the half-open state.
	time.Sleep(150 * time.Millisecond)
	if tracker.IsBad(cp) {
		t.Error("Pool should not be bad after cooldown (half-open state)")
	}
	if got := tracker.GetState(cp); got != "probing (0/1)" {
		t.Errorf("Expected state 'probing (0/1)', got '%s'", got)
	}

	// A single success completes recovery (successesNeeded == 1).
	tracker.MarkHealthy(cp)
	if tracker.IsBad(cp) {
		t.Error("Pool should be healthy after 1 success")
	}
	if got := tracker.GetState(cp); got != "healthy" {
		t.Errorf("Expected state 'healthy', got '%s'", got)
	}
}
// TestHealthTracker_MultipleSuccessesRequired verifies that a tracker
// configured for 3 consecutive successes stays in the probing state until
// the third success lands.
func TestHealthTracker_MultipleSuccessesRequired(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(100*time.Millisecond, 3)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	tracker.MarkBad(cp)
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad immediately after MarkBad")
	}

	// Cooldown elapses; pool becomes half-open.
	time.Sleep(150 * time.Millisecond)
	if tracker.IsBad(cp) {
		t.Error("Pool should not be bad after cooldown (half-open state)")
	}

	// Successes 1 and 2 keep the pool probing.
	tracker.MarkHealthy(cp)
	if got := tracker.GetState(cp); got != "probing (1/3)" {
		t.Errorf("Expected state 'probing (1/3)', got '%s'", got)
	}
	tracker.MarkHealthy(cp)
	if got := tracker.GetState(cp); got != "probing (2/3)" {
		t.Errorf("Expected state 'probing (2/3)', got '%s'", got)
	}

	// The third consecutive success completes recovery.
	tracker.MarkHealthy(cp)
	if got := tracker.GetState(cp); got != "healthy" {
		t.Errorf("Expected state 'healthy', got '%s'", got)
	}
	if tracker.IsBad(cp) {
		t.Error("Pool should be healthy after 3 successes")
	}
}
// TestHealthTracker_FailureResetsCounter verifies that a failure during
// probing discards accumulated successes and restarts the cooldown.
func TestHealthTracker_FailureResetsCounter(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(100*time.Millisecond, 3)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	tracker.MarkBad(cp)
	time.Sleep(150 * time.Millisecond)

	// Accumulate two of the three required successes.
	tracker.MarkHealthy(cp)
	tracker.MarkHealthy(cp)
	if got := tracker.GetState(cp); got != "probing (2/3)" {
		t.Errorf("Expected state 'probing (2/3)', got '%s'", got)
	}

	// A failure mid-probe sends the pool back to bad and restarts cooldown.
	tracker.MarkBad(cp)
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad immediately after MarkBad")
	}

	// After the second cooldown the counter must be back at zero.
	time.Sleep(150 * time.Millisecond)
	if got := tracker.GetState(cp); got != "probing (0/3)" {
		t.Errorf("Expected state 'probing (0/3)' after reset, got '%s'", got)
	}

	// Recovery must start over from scratch.
	tracker.MarkHealthy(cp)
	if got := tracker.GetState(cp); got != "probing (1/3)" {
		t.Errorf("Expected state 'probing (1/3)', got '%s'", got)
	}
}
// TestHealthTracker_NoProbingDuringCooldown verifies that a pool stays
// excluded (IsBad == true) for the full cooldown window and only becomes
// half-open once the window expires.
func TestHealthTracker_NoProbingDuringCooldown(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(200*time.Millisecond, 2)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	tracker.MarkBad(cp)
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad during cooldown")
	}
	if got := tracker.GetState(cp); got != "bad" {
		t.Errorf("Expected state 'bad', got '%s'", got)
	}

	// Halfway through the cooldown the pool must still be excluded.
	time.Sleep(100 * time.Millisecond)
	if !tracker.IsBad(cp) {
		t.Error("Pool should still be bad before cooldown expires")
	}

	// Once the cooldown expires the pool becomes half-open (probing).
	time.Sleep(150 * time.Millisecond)
	if tracker.IsBad(cp) {
		t.Error("Pool should not be bad after cooldown expires")
	}
	if got := tracker.GetState(cp); got != "probing (0/2)" {
		t.Errorf("Expected state 'probing (0/2)', got '%s'", got)
	}
}
// TestHealthTracker_FlappingPrevention verifies that requiring several
// consecutive successes stops a flapping replica from being declared healthy
// prematurely: a failure just before the threshold resets everything.
func TestHealthTracker_FlappingPrevention(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(500*time.Millisecond, 5)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	// Flapping scenario: bad → cooldown → four successes (one short of 5).
	tracker.MarkBad(cp)
	time.Sleep(550 * time.Millisecond)
	for n := 0; n < 4; n++ {
		tracker.MarkHealthy(cp)
	}
	if got := tracker.GetState(cp); got != "probing (4/5)" {
		t.Errorf("Expected state 'probing (4/5)', got '%s'", got)
	}

	// Failing before the threshold sends the pool straight back to bad.
	tracker.MarkBad(cp)
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad after failure during probing")
	}
	if got := tracker.GetState(cp); got != "bad" {
		t.Errorf("Expected state 'bad', got '%s'", got)
	}

	// After the next cooldown the success counter starts from zero again.
	time.Sleep(550 * time.Millisecond)
	if got := tracker.GetState(cp); got != "probing (0/5)" {
		t.Errorf("Expected state 'probing (0/5)' after reset, got '%s'", got)
	}
}
// TestHealthTracker_MultiplePools verifies that the tracker keeps fully
// independent state per pool: progress on one must not leak into another.
func TestHealthTracker_MultiplePools(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(100*time.Millisecond, 2)
	// Two distinct instances so the tracker keys them separately.
	first := &struct{ gorm.ConnPool }{}
	second := &struct{ gorm.ConnPool }{}

	tracker.MarkBad(first)
	tracker.MarkBad(second)
	time.Sleep(150 * time.Millisecond)

	// A success on the first pool only.
	tracker.MarkHealthy(first)
	if got := tracker.GetState(first); got != "probing (1/2)" {
		t.Errorf("Pool1: expected 'probing (1/2)', got '%s'", got)
	}
	if got := tracker.GetState(second); got != "probing (0/2)" {
		t.Errorf("Pool2: expected 'probing (0/2)', got '%s'", got)
	}

	// Completing recovery on the first pool leaves the second untouched.
	tracker.MarkHealthy(first)
	if tracker.GetState(first) != "healthy" {
		t.Error("Pool1 should be healthy")
	}
	if tracker.GetState(second) != "probing (0/2)" {
		t.Error("Pool2 should still be probing")
	}
}
// TestHealthTracker_ConcurrentAccess hammers MarkBad and MarkHealthy from
// many goroutines to exercise the tracker's internal locking (run with
// -race for full effect).
func TestHealthTracker_ConcurrentAccess(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(50*time.Millisecond, 3)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	const iterations = 100
	var wg sync.WaitGroup

	// Concurrent MarkBad calls.
	wg.Add(iterations)
	for n := 0; n < iterations; n++ {
		go func() {
			defer wg.Done()
			tracker.MarkBad(cp)
		}()
	}
	wg.Wait()
	if !tracker.IsBad(cp) {
		t.Error("Pool should be bad after concurrent MarkBad calls")
	}

	// Cooldown, then concurrent MarkHealthy calls.
	time.Sleep(100 * time.Millisecond)
	wg.Add(iterations)
	for n := 0; n < iterations; n++ {
		go func() {
			defer wg.Done()
			tracker.MarkHealthy(cp)
		}()
	}
	wg.Wait()

	// 100 successes comfortably exceed the threshold of 3.
	if got := tracker.GetState(cp); got != "healthy" {
		t.Errorf("Pool should be healthy after concurrent MarkHealthy calls, got state: %s", got)
	}
}
// TestHealthTracker_ZeroSuccessesNeeded verifies that a non-positive
// successesNeeded is clamped to the minimum of 1.
func TestHealthTracker_ZeroSuccessesNeeded(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(100*time.Millisecond, 0)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	tracker.MarkBad(cp)
	time.Sleep(150 * time.Millisecond)

	// With the threshold clamped to 1, a single success must recover the pool.
	tracker.MarkHealthy(cp)
	if tracker.GetState(cp) != "healthy" {
		t.Error("Pool should be healthy after 1 success (default minimum)")
	}
}
// TestHealthTracker_isTracking exercises the isTracking helper through a
// full lifecycle: untracked → bad → probing with progress → recovered.
func TestHealthTracker_isTracking(t *testing.T) {
	tracker := NewHealthTrackerWithSuccesses(100*time.Millisecond, 3)
	cp := &struct{ gorm.ConnPool }{} // distinct pointer identity keys the tracker

	// Before any failure the pool is not tracked at all.
	tracking, current, needed := tracker.isTracking(cp)
	if tracking {
		t.Error("Should not be tracking initially")
	}
	if current != 0 || needed != 3 {
		t.Errorf("Expected 0/3, got %d/%d", current, needed)
	}

	// After MarkBad plus cooldown the pool is tracked with zero progress.
	tracker.MarkBad(cp)
	time.Sleep(150 * time.Millisecond)
	tracking, current, needed = tracker.isTracking(cp)
	if !tracking {
		t.Error("Should be tracking after MarkBad")
	}
	if current != 0 || needed != 3 {
		t.Errorf("Expected 0/3, got %d/%d", current, needed)
	}

	// One success advances progress to 1/3 while still tracking.
	tracker.MarkHealthy(cp)
	tracking, current, needed = tracker.isTracking(cp)
	if !tracking {
		t.Error("Should still be tracking after 1 success")
	}
	if current != 1 || needed != 3 {
		t.Errorf("Expected 1/3, got %d/%d", current, needed)
	}

	// Two more successes complete recovery and end the tracking.
	tracker.MarkHealthy(cp)
	tracker.MarkHealthy(cp)
	tracking, current, needed = tracker.isTracking(cp)
	if tracking {
		t.Error("Should not be tracking after full recovery")
	}
	if current != 0 || needed != 3 {
		t.Errorf("Expected 0/3, got %d/%d", current, needed)
	}
}
// mockNetError is a minimal net.Error implementation used by the classifier
// tests to exercise the type-based (rather than string-based) network-error
// detection path.
type mockNetError struct{ msg string }

// Compile-time check that *mockNetError satisfies net.Error.
var _ net.Error = (*mockNetError)(nil)

// Error returns the configured message.
func (e *mockNetError) Error() string { return e.msg }

// Timeout reports false so classifiers must key off the net.Error type,
// not the timeout flag.
func (e *mockNetError) Timeout() bool { return false }

// Temporary reports false; it exists only to satisfy the (deprecated)
// net.Error interface method.
func (e *mockNetError) Temporary() bool { return false }
// TestDefaultErrorClassifier pins the conservative default classification:
// only unambiguous instance-level failures mark a replica as bad; ambiguous
// connection drops and application-level errors do not.
func TestDefaultErrorClassifier(t *testing.T) {
	cases := []struct {
		label   string
		input   error
		wantBad bool
	}{
		{"nil error", nil, false},
		{"non-network error", errors.New("record not found"), false},
		{"sql no rows", errors.New("sql: no rows in result set"), false},
		// Unambiguous instance-level failures → bad.
		{"connection refused", errors.New("connection refused"), true},
		{"no such host", errors.New("dial tcp: no such host"), true},
		{"net.Error (timeout)", &mockNetError{msg: "connection timed out"}, true},
		{"pgx read-only", errors.New("ValidateConnect failed: not read only"), true},
		// Ambiguous connection-level errors — NOT classified by default.
		{"broken pipe", errors.New("write: broken pipe"), false},
		{"connection reset", errors.New("read: connection reset by peer"), false},
		{"server closed", errors.New("server closed the connection"), false},
		{"eof", errors.New("EOF"), false},
		{"i/o timeout string", errors.New("i/o timeout"), false},
		// Application-level errors — never a replica failure.
		{"context canceled", context.Canceled, false},
		{"context deadline exceeded", context.DeadlineExceeded, false},
		{"wrapped context deadline", fmt.Errorf("query failed: %w", context.DeadlineExceeded), false},
		// Connection pool exhaustion — a concurrency issue, not a bad replica.
		{"pool exhausted", errors.New("timeout acquiring conn from pool"), false},
		{"conn busy", errors.New("conn busy"), false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if got := DefaultErrorClassifier(tc.input); got != tc.wantBad {
				t.Errorf("DefaultErrorClassifier(%v) = %v, want %v", tc.input, got, tc.wantBad)
			}
		})
	}
}
// TestStrictConnectionErrorClassifier pins the aggressive classifier: it
// accepts everything the default does plus the ambiguous connection-level
// errors, while still excluding application-level and pool errors.
func TestStrictConnectionErrorClassifier(t *testing.T) {
	cases := []struct {
		label   string
		input   error
		wantBad bool
	}{
		{"nil error", nil, false},
		{"non-network error", errors.New("record not found"), false},
		{"sql no rows", errors.New("sql: no rows in result set"), false},
		// Everything DefaultErrorClassifier flags is also flagged here.
		{"connection refused", errors.New("connection refused"), true},
		{"no such host", errors.New("dial tcp: no such host"), true},
		{"net.Error", &mockNetError{msg: "connection timed out"}, true},
		{"pgx read-only", errors.New("ValidateConnect failed: not read only"), true},
		// Ambiguous connection-level errors are additionally flagged.
		{"broken pipe", errors.New("write: broken pipe"), true},
		{"connection reset", errors.New("read: connection reset by peer"), true},
		{"server closed", errors.New("server closed the connection"), true},
		{"eof lowercase", errors.New("unexpected eof"), true},
		{"EOF uppercase", errors.New("EOF"), true},
		{"i/o timeout", errors.New("i/o timeout"), true},
		// Still excluded even in strict mode — never replica failures.
		{"context canceled", context.Canceled, false},
		{"context deadline exceeded", context.DeadlineExceeded, false},
		{"wrapped context deadline", fmt.Errorf("query failed: %w", context.DeadlineExceeded), false},
		{"pool exhausted", errors.New("timeout acquiring conn from pool"), false},
		{"conn busy", errors.New("conn busy"), false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if got := StrictConnectionErrorClassifier(tc.input); got != tc.wantBad {
				t.Errorf("StrictConnectionErrorClassifier(%v) = %v, want %v", tc.input, got, tc.wantBad)
			}
		})
	}
}
// TestErrorClassifier_DefaultIsStrictSubset checks the invariant that the
// strict classifier flags at least everything the default one flags — i.e.
// Default's bad-set is a subset of Strict's.
func TestErrorClassifier_DefaultIsStrictSubset(t *testing.T) {
	samples := []error{
		nil,
		errors.New("connection refused"),
		errors.New("no such host"),
		errors.New("broken pipe"),
		errors.New("EOF"),
		errors.New("connection reset by peer"),
		errors.New("server closed the connection"),
		errors.New("i/o timeout"),
		&mockNetError{msg: "timeout"},
		errors.New("ValidateConnect failed: not read only"),
		errors.New("record not found"),
		fmt.Errorf("wrapped: %w", &net.OpError{Op: "dial", Err: errors.New("refused")}),
	}
	for _, sample := range samples {
		if DefaultErrorClassifier(sample) && !StrictConnectionErrorClassifier(sample) {
			t.Errorf("Default classified %v as bad but Strict did not — Strict must be a superset", sample)
		}
	}
}