Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions internal/adapters/repository/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,33 +458,39 @@ func (r *PostgresRepository) CreateZoneWithRecords(ctx context.Context, zone *do
}()

// 1. Insert Zone
zoneQuery := `INSERT INTO dns_zones (id, tenant_id, name, vpc_id, description, role, master_server, created_at, updated_at)
zoneQuery := `INSERT INTO dns_zones (id, tenant_id, name, vpc_id, description, role, master_server, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`
_, errExec := tx.ExecContext(ctx, zoneQuery, zone.ID, zone.TenantID, zone.Name, zone.VPCID, zone.Description, zone.Role, zone.MasterServer, zone.CreatedAt, zone.UpdatedAt)
if errExec != nil {
return errExec
}

// 2. Insert Records row-by-row (UNNEST batch not used here — sqlmock
// doesn't support slice args for UNNEST in transaction context)
recordQuery := `INSERT INTO dns_records (id, zone_id, name, type, content, ttl, priority, weight, port, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`
for _, rec := range records {
priority := 0
if rec.Priority != nil {
priority = *rec.Priority
}
weight := 0
if rec.Weight != nil {
weight = *rec.Weight
}
port := 0
if rec.Port != nil {
port = *rec.Port
// 2. Batch insert records using multi-row VALUES
if len(records) > 0 {
valueStrings := make([]string, 0, len(records))
valueArgs := make([]interface{}, 0, len(records)*11)
for i, rec := range records {
offset := i * 11
valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)",
offset+1, offset+2, offset+3, offset+4, offset+5, offset+6, offset+7, offset+8, offset+9, offset+10, offset+11))
priority := 0
if rec.Priority != nil {
priority = *rec.Priority
}
weight := 0
if rec.Weight != nil {
weight = *rec.Weight
}
port := 0
if rec.Port != nil {
port = *rec.Port
}
valueArgs = append(valueArgs, rec.ID, rec.ZoneID, rec.Name, rec.Type, rec.Content, rec.TTL, priority, weight, port, rec.CreatedAt, rec.UpdatedAt)
}
_, errExecRecord := tx.ExecContext(ctx, recordQuery, rec.ID, rec.ZoneID, rec.Name, rec.Type, rec.Content, rec.TTL, priority, weight, port, rec.CreatedAt, rec.UpdatedAt)
if errExecRecord != nil {
return errExecRecord
batchQuery := fmt.Sprintf("INSERT INTO dns_records (id, zone_id, name, type, content, ttl, priority, weight, port, created_at, updated_at) VALUES %s", strings.Join(valueStrings, ","))
_, errExec = tx.ExecContext(ctx, batchQuery, valueArgs...)
if errExec != nil {
return errExec
}
}

Expand Down
84 changes: 55 additions & 29 deletions internal/dns/server/ratelimit.go
Original file line number Diff line number Diff line change
@@ -1,54 +1,90 @@
package server

import (
"container/heap"
"sync"
"time"
)

// rateLimiter implements a simple per-IP token bucket
// rateLimiter implements a simple per-IP token bucket with O(1) eviction.
type rateLimiter struct {
mu sync.Mutex
buckets map[string]*bucket
rate float64 // tokens per second
burst int // max tokens
maxBuckets int // maximum buckets to store (bounds memory)
idleHeap bucketIdleHeap
}

type bucket struct {
tokens float64
tokens float64
last time.Time
heapIdx int // index in idleHeap, -1 if not in heap
}

type bucketIdleEntry struct {
ip string
b *bucket
}

type bucketIdleHeap []*bucketIdleEntry

func (h bucketIdleHeap) Len() int { return len(h) }
func (h bucketIdleHeap) Less(i, j int) bool {
return h[i].b.last.Before(h[j].b.last)
}
func (h bucketIdleHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h *bucketIdleHeap) Push(x any) {
*h = append(*h, x.(*bucketIdleEntry))
}
func (h *bucketIdleHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}

func newRateLimiter(rate float64, burst int, maxBuckets int) *rateLimiter {
h := bucketIdleHeap{}
heap.Init(&h)
return &rateLimiter{
buckets: make(map[string]*bucket),
rate: rate,
burst: burst,
maxBuckets: maxBuckets,
idleHeap: h,
}
}

func (rl *rateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now()
b, exists := rl.buckets[ip]
if !exists {
// Evict an idle bucket if at capacity
if len(rl.buckets) >= rl.maxBuckets {
rl.evictIdleBucket()
rl.evictOldestBucket()
}
b = &bucket{
tokens: float64(rl.burst),
last: time.Now(),
last: now,
}
entry := &bucketIdleEntry{ip: ip, b: b}
heap.Push(&rl.idleHeap, entry)
b.heapIdx = len(rl.idleHeap) - 1
rl.buckets[ip] = b
}

now := time.Now()
elapsed := now.Sub(b.last).Seconds()
b.last = now

// Update heap position after last change
heap.Fix(&rl.idleHeap, b.heapIdx)

// Refill
b.tokens += elapsed * rl.rate
if b.tokens > float64(rl.burst) {
Expand All @@ -64,31 +100,20 @@ func (rl *rateLimiter) Allow(ip string) bool {
return false
}

// evictIdleBucket removes a bucket that hasn't been used recently.
// Performs a bounded scan of up to 8 entries to find an idle bucket.
func (rl *rateLimiter) evictIdleBucket() {
now := time.Now()
found := -1
foundIP := ""
count := 0
for ip, b := range rl.buckets {
if now.Sub(b.last) > 1*time.Minute {
delete(rl.buckets, ip)
return
// evictOldestBucket removes the bucket with the oldest last timestamp in O(log n).
func (rl *rateLimiter) evictOldestBucket() {
for len(rl.idleHeap) > 0 {
entry := heap.Pop(&rl.idleHeap).(*bucketIdleEntry)
if entry == nil {
continue
}
if found == -1 {
found = count
foundIP = ip
}
count++
if count >= 8 {
break
// If bucket still exists in map, delete it; otherwise it was already
// evicted by Cleanup() and this is a stale heap entry — discard it.
if _, ok := rl.buckets[entry.ip]; ok {
delete(rl.buckets, entry.ip)
return
}
}
// If no idle bucket found, evict the first candidate
if foundIP != "" {
delete(rl.buckets, foundIP)
}
}

// Cleanup removes old buckets to prevent memory leaks.
Expand All @@ -99,6 +124,7 @@ func (rl *rateLimiter) Cleanup() {
now := time.Now()
for ip, b := range rl.buckets {
if now.Sub(b.last) > 10*time.Minute {
heap.Remove(&rl.idleHeap, b.heapIdx)
delete(rl.buckets, ip)
}
}
Expand All @@ -117,4 +143,4 @@ func (rl *rateLimiter) CleanupLoop(done <-chan struct{}) {
rl.Cleanup()
}
}
}
}
Loading