Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions botrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package botrate

import (
"context"
"os"
"strings"
"testing"
"time"

"github.com/cnlangzi/knownbots"
"golang.org/x/time/rate"
)

Expand Down Expand Up @@ -66,14 +68,50 @@ func TestLimiter_New_WithOptions(t *testing.T) {
}

func TestLimiter_Allow_VerifiedBot(t *testing.T) {
l, err := New()
botDir := t.TempDir()
botConfDir := botDir + "/conf.d"
if err := os.MkdirAll(botConfDir, 0755); err != nil {
t.Fatalf("Failed to create config dir: %v", err)
}

customBotYAML := `kind: SearchEngine
name: testbot
parser: txt
ua: "TestBot"
custom:
- "192.168.100.0/24"
`
if err := os.WriteFile(botConfDir+"/testbot.yaml", []byte(customBotYAML), 0644); err != nil {
t.Fatalf("Failed to write bot config: %v", err)
}

kb, err := knownbots.New(knownbots.WithRoot(botDir))
if err != nil {
t.Fatalf("Failed to create knownbots validator: %v", err)
}
defer kb.Close()

l, err := New(WithKnownbots(kb))
if err != nil {
t.Fatalf("New() returned error: %v", err)
}
defer l.Close()

result := l.Allow("Googlebot/2.1", "66.249.66.1")
_ = result
allowed, reason := l.Allow("TestBot/1.0", "192.168.100.42")
if !allowed {
t.Error("verified bot should be allowed")
}
if reason != "" {
t.Errorf("reason should be empty for allowed request, got %s", reason)
}

allowed, reason = l.Allow("TestBot/1.0", "10.0.0.1")
if allowed {
t.Error("fake bot should be blocked")
}
if reason != ReasonFakeBot {
t.Errorf("expected reason %s, got %s", ReasonFakeBot, reason)
}
}

func TestLimiter_Wait_VerifiedBot(t *testing.T) {
Expand All @@ -83,7 +121,7 @@ func TestLimiter_Wait_VerifiedBot(t *testing.T) {
}
defer l.Close()

err = l.Wait(context.Background(), "Googlebot/2.1", "66.249.66.1")
err, _ = l.Wait(context.Background(), "Googlebot/2.1", "66.249.66.1")
_ = err
}

Expand All @@ -99,7 +137,7 @@ func TestLimiter_Wait_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

err = l.Wait(ctx, "Mozilla/5.0", "192.168.1.1")
err, _ = l.Wait(ctx, "Mozilla/5.0", "192.168.1.1")

if err != nil && err != context.Canceled && err != ErrLimit {
t.Errorf("expected nil, context.Canceled, or ErrLimit, got %v", err)
Expand All @@ -116,7 +154,7 @@ func TestLimiter_Allow_NormalUser(t *testing.T) {
}
defer l.Close()

allowed := l.Allow("Mozilla/5.0", "192.168.1.1")
allowed, _ := l.Allow("Mozilla/5.0", "192.168.1.1")

if !allowed {
t.Error("normal user should be allowed")
Expand All @@ -130,7 +168,7 @@ func TestLimiter_Allow_BotLike(t *testing.T) {
}
defer l.Close()

allowed := l.Allow("Python-urllib/3.11", "192.168.1.1")
allowed, _ := l.Allow("Python-urllib/3.11", "192.168.1.1")
_ = allowed
}

Expand All @@ -144,14 +182,14 @@ func TestLimiter_Allow_BlacklistedIP(t *testing.T) {
}
defer l.Close()

allowed := l.Allow("Mozilla/5.0", "192.168.1.1")
allowed, _ := l.Allow("Mozilla/5.0", "192.168.1.1")
if !allowed {
t.Error("first request should be allowed")
}

time.Sleep(time.Millisecond * 200)

allowed = l.Allow("Mozilla/5.0", "192.168.1.1")
allowed, _ = l.Allow("Mozilla/5.0", "192.168.1.1")
_ = allowed
}

Expand All @@ -165,7 +203,7 @@ func TestLimiter_Wait_NormalUser(t *testing.T) {
}
defer l.Close()

err = l.Wait(context.Background(), "Mozilla/5.0", "192.168.1.1")
err, _ = l.Wait(context.Background(), "Mozilla/5.0", "192.168.1.1")

if err != nil {
t.Errorf("normal user should not return error, got %v", err)
Expand All @@ -184,7 +222,7 @@ func TestLimiter_Wait_BotLike(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()

_ = l.Wait(ctx, "Python-urllib/3.11", "192.168.1.1")
_, _ = l.Wait(ctx, "Python-urllib/3.11", "192.168.1.1")
}

func TestLimiter_Close(t *testing.T) {
Expand All @@ -211,7 +249,9 @@ func TestLimiter_Allow_ManyRequests(t *testing.T) {
ip := "192.168.1." + string(rune('0'+i%256))
ua := "UserAgent/" + string(rune('A'+i%26))

if !l.Allow(ua, ip) {
allowed, _ := l.Allow(ua, ip)

if !allowed {
t.Errorf("request %d should be allowed", i)
}
}
Expand All @@ -224,7 +264,8 @@ func TestLimiter_Allow_IPv6(t *testing.T) {
}
defer l.Close()

if !l.Allow("Mozilla/5.0", "2001:0db8:85a3:0000:0000:8a2e:0370:7334") {
allowed, _ := l.Allow("Mozilla/5.0", "2001:0db8:85a3:0000:0000:8a2e:0370:7334")
if !allowed {
t.Error("IPv6 request should be allowed")
}
}
Expand All @@ -236,7 +277,8 @@ func TestLimiter_Allow_EmptyUserAgent(t *testing.T) {
}
defer l.Close()

if !l.Allow("", "192.168.1.1") {
allowed, _ := l.Allow("", "192.168.1.1")
if !allowed {
t.Error("empty UA should be allowed")
}
}
Expand All @@ -248,7 +290,8 @@ func TestLimiter_Allow_EmptyIP(t *testing.T) {
}
defer l.Close()

if !l.Allow("Mozilla/5.0", "") {
allowed, _ := l.Allow("Mozilla/5.0", "")
if !allowed {
t.Error("empty IP should be allowed")
}
}
Expand All @@ -266,8 +309,8 @@ func TestLimiter_WithKnownbots(t *testing.T) {
}
defer l2.Close()

_ = l1.Allow("Googlebot/2.1", "66.249.66.1")
_ = l2.Allow("Googlebot/2.1", "66.249.66.1")
_, _ = l1.Allow("Googlebot/2.1", "66.249.66.1")
_, _ = l2.Allow("Googlebot/2.1", "66.249.66.1")
}

func TestLimiter_RateLimitPersistence(t *testing.T) {
Expand All @@ -281,8 +324,8 @@ func TestLimiter_RateLimitPersistence(t *testing.T) {
}
defer l.Close()

_ = l.Allow("Python-urllib/3.11", "192.168.1.1")
_ = l.Allow("Python-urllib/3.11", "192.168.1.1")
_, _ = l.Allow("Python-urllib/3.11", "192.168.1.1")
_, _ = l.Allow("Python-urllib/3.11", "192.168.1.1")
}

func TestLimiter_DifferentBots(t *testing.T) {
Expand All @@ -298,7 +341,7 @@ func TestLimiter_DifferentBots(t *testing.T) {
}

for _, bot := range bots {
_ = l.Allow(bot, "66.249.66.1")
_, _ = l.Allow(bot, "66.249.66.1")
}
}

Expand Down Expand Up @@ -334,7 +377,7 @@ func TestLimiter_BotScenarios(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
allowed := l.Allow(tc.ua, tc.ip)
allowed, _ := l.Allow(tc.ua, tc.ip)
_ = allowed
})
}
Expand All @@ -355,7 +398,7 @@ func TestLimiter_InvalidIPFormat(t *testing.T) {
}

for _, ip := range invalidIPs {
_ = l.Allow("Mozilla/5.0", ip)
_, _ = l.Allow("Mozilla/5.0", ip)
}
}

Expand All @@ -368,7 +411,8 @@ func TestLimiter_LongUserAgent(t *testing.T) {

longUA := strings.Repeat("Mozilla/5.0 ", 1000)

if !l.Allow(longUA, "192.168.1.1") {
allowed, _ := l.Allow(longUA, "192.168.1.1")
if !allowed {
t.Error("long UA should be allowed")
}
}
Expand All @@ -381,9 +425,11 @@ func TestLimiter_LongPath(t *testing.T) {
defer l.Close()

longPath := "/" + strings.Repeat("a", 10000)
_, _ = l.Allow("Mozilla/5.0", "192.168.1.1")
_ = longPath

if !l.Allow("Mozilla/5.0", "192.168.1.1") {
allowed, _ := l.Allow("Mozilla/5.0", "192.168.1.1")
if !allowed {
t.Error("long path should be allowed")
}
}
Expand Down
3 changes: 2 additions & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ func main() {
ua := r.UserAgent()
ip := extractIP(r)

if !limiter.Allow(ua, ip) {
allowed, _ := limiter.Allow(ua, ip)
if !allowed {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.22

require (
github.com/bits-and-blooms/bloom/v3 v3.7.1
github.com/cnlangzi/knownbots v1.0.4
github.com/cnlangzi/knownbots v1.0.5
golang.org/x/time v0.7.0
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ github.com/bits-and-blooms/bitset v1.24.2 h1:M7/NzVbsytmtfHbumG+K2bremQPMJuqv1JD
github.com/bits-and-blooms/bitset v1.24.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/bits-and-blooms/bloom/v3 v3.7.1 h1:WXovk4TRKZttAMJfoQx6K2DM0zNIt8w+c67UqO+etV0=
github.com/bits-and-blooms/bloom/v3 v3.7.1/go.mod h1:rZzYLLje2dfzXfAkJNxQQHsKurAyK55KUnL43Euk0hU=
github.com/cnlangzi/knownbots v1.0.4 h1:vyEqWKGf+2j4wlfYE1uNetXzwmPnajBwGia6IN12LwM=
github.com/cnlangzi/knownbots v1.0.4/go.mod h1:dDHujBVMOX5YDalVjmBfVzC3AwMTpCDMnB+mo+0DLUU=
github.com/cnlangzi/knownbots v1.0.5 h1:rKRhnXDjG0k0gfwo5ikHleeqnbjgSYwDyM+lfXy1QRE=
github.com/cnlangzi/knownbots v1.0.5/go.mod h1:dDHujBVMOX5YDalVjmBfVzC3AwMTpCDMnB+mo+0DLUU=
github.com/twmb/murmur3 v1.1.8 h1:8Yt9taO/WN3l08xErzjeschgZU2QSrwm1kclYq+0aRg=
github.com/twmb/murmur3 v1.1.8/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
Expand Down
54 changes: 40 additions & 14 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ var (
DefaultQueueCap = 10000
)

// Reason identifies why the limiter blocked a request. The empty Reason
// accompanies a request that was allowed.
type Reason string

// Block reasons reported by the limiter.
const (
	// ReasonFakeBot means the request claimed to be a bot but its
	// verification failed or its status was unknown.
	ReasonFakeBot = Reason("fake_bot")

	// ReasonRateLimited means the IP had been flagged by behavior analysis
	// and the request was rejected by the rate limit applied to it.
	ReasonRateLimited = Reason("rate_limited")
)

// Limiter provides bot-aware rate limiting.
type Limiter struct {
cfg Config
Expand Down Expand Up @@ -65,65 +78,78 @@ func New(opts ...Option) (*Limiter, error) {
}

// Allow reports whether the request should proceed.
// Returns true if allowed, false if blocked.
func (l *Limiter) Allow(ua, ip string) bool {
// Returns:
// - allowed: true if allowed, false if blocked
// - reason: the reason for blocking when allowed is false
func (l *Limiter) Allow(ua, ip string) (allowed bool, reason Reason) {
// Layer 1: Bot verification
botResult := l.kb.Validate(ua, ip)

if botResult.IsBot {
switch botResult.Status {
case knownbots.StatusVerified:
// Verified bot: allow without rate limit
return true
return true, ""
case knownbots.StatusPending:
// RDNS lookup failed, allow and retry verification next time
return true
return true, ""
case knownbots.StatusFailed, knownbots.StatusUnknown:
// Fake bot (failed verification) or unknown: block immediately
return false
return false, ReasonFakeBot
}
}

// Layer 2: Blocklist check (only for normal users)
if l.analyzer.Blocked(ip) {
// Behavior anomaly: apply rate limit
return l.allowBlocked(ip)
if l.allowBlocked(ip) {
return true, ""
}
return false, ReasonRateLimited
}

// Layer 3: Normal user + not blocked
l.analyzer.Record(ip, ua)
return true
return true, ""
}

// Wait blocks until the request is allowed or the context is canceled.
// Returns nil if allowed, error if blocked or context canceled.
func (l *Limiter) Wait(ctx context.Context, ua, ip string) error {
// Returns:
// - err: nil if allowed, otherwise the blocking error (context canceled/timeout or ErrLimit)
// - reason: the reason for blocking (ReasonFakeBot or ReasonRateLimited)
func (l *Limiter) Wait(ctx context.Context, ua, ip string) (err error, reason Reason) {
// Layer 1: Bot verification
botResult := l.kb.Validate(ua, ip)

if botResult.IsBot {
switch botResult.Status {
case knownbots.StatusVerified:
// Verified bot: no rate limit needed
return nil
return nil, ""
case knownbots.StatusPending:
// RDNS lookup failed, allow and retry verification next time
return nil
return nil, ""
case knownbots.StatusFailed, knownbots.StatusUnknown:
// Fake bot: block immediately
return ErrLimit
return ErrLimit, ReasonFakeBot
}
}

// Layer 2: Blocklist check (only for normal users)
if l.analyzer.Blocked(ip) {
// Behavior anomaly: apply rate limit
return l.waitBlocked(ctx, ip)
err = l.waitBlocked(ctx, ip)
if err != nil {
// Context canceled/timeout while waiting
return err, ReasonRateLimited
}
// Rate limit hit (wait returned without error but context still active)
return ErrLimit, ReasonRateLimited
}

// Layer 3: Normal user + not blocked
l.analyzer.Record(ip, ua)
return nil
return nil, ""
}

func (l *Limiter) allowBlocked(ip string) bool {
Expand Down