Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cmd/clouddns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ func run(ctx context.Context) error {
}
}
}()

// Periodic zone/record count metrics
var metricsCollector *metrics.DerivedMetricCollector
go func() {
interval := 30 * time.Second
if os.Getenv("TEST_MODE") == "true" {
interval = 10 * time.Millisecond
}
counter := metrics.NewZoneRecordCounter(repo, interval)
counter.Start(runCtx)
metricsCollector = metrics.NewDerivedMetricCollector(interval)
<-runCtx.Done()
counter.Stop()
metricsCollector.Stop()
}()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

var cacheInvalidator ports.CacheInvalidator
Expand Down
48 changes: 46 additions & 2 deletions internal/core/services/dnssec_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"crypto/rand"
"crypto/x509"
"fmt"
"log/slog"
"math"
"strings"
"time"

"github.com/google/uuid"
Expand All @@ -19,12 +21,13 @@ import (

// DNSSECService provides functionality for managing DNSSEC keys and signing RRsets.
type DNSSECService struct {
repo ports.DNSRepository
repo ports.DNSRepository
logger *slog.Logger
}

// NewDNSSECService creates and returns a new DNSSECService instance.
func NewDNSSECService(repo ports.DNSRepository) *DNSSECService {
return &DNSSECService{repo: repo}
return &DNSSECService{repo: repo, logger: slog.Default()}
}

// GenerateKey creates a new ECDSA P-256 key pair for a zone
Expand Down Expand Up @@ -199,3 +202,44 @@ func (s *DNSSECService) SignRRSet(ctx context.Context, zoneName string, zoneID s

return sigs, nil
}

// KeyStats holds DNSSEC key statistics for metrics.
// One value is produced per active key, so a zone mid-rollover (or holding
// both a KSK and a ZSK) contributes multiple entries.
type KeyStats struct {
	ZoneID     string  // repository identifier of the owning zone
	ZoneName   string  // fully-qualified zone name, used as a metric label
	KeyType    string  // lower-cased key type (e.g. "ksk", "zsk")
	Algorithm  int     // DNSSEC algorithm number (e.g. 13 for ECDSA P-256)
	AgeSeconds float64 // seconds elapsed since the key's CreatedAt timestamp
}

// CollectKeyStats returns statistics for all active DNSSEC keys.
// Used by the metrics collector to update DNSSEC key age metrics.
//
// Collection is best-effort per zone: when keys for a single zone cannot be
// listed, the failure is logged at debug level and that zone is skipped so
// one bad zone does not abort the whole sweep. An error is returned only
// when the zone listing itself fails.
func (s *DNSSECService) CollectKeyStats(ctx context.Context) ([]KeyStats, error) {
	zones, err := s.repo.ListZones(ctx, "")
	if err != nil {
		return nil, err
	}

	now := time.Now()
	var collected []KeyStats
	for zi := range zones {
		z := &zones[zi]
		keys, keyErr := s.repo.ListKeysForZone(ctx, z.ID)
		if keyErr != nil {
			s.logger.Debug("failed to list keys for zone", "zone", z.Name, "error", keyErr)
			continue
		}
		for ki := range keys {
			key := &keys[ki]
			// Inactive keys are excluded from metrics.
			if !key.Active {
				continue
			}
			collected = append(collected, KeyStats{
				ZoneID:     z.ID,
				ZoneName:   z.Name,
				KeyType:    strings.ToLower(key.KeyType),
				Algorithm:  key.Algorithm,
				AgeSeconds: now.Sub(key.CreatedAt).Seconds(),
			})
		}
	}
	return collected, nil
}
65 changes: 60 additions & 5 deletions internal/core/services/dnssec_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ import (
)

type mockDNSSECRepo struct {
keys []domain.DNSSECKey
err error
keys []domain.DNSSECKey
zones []domain.Zone
err error
keysErr error
listErr error
}

func (m *mockDNSSECRepo) GetRecords(_ context.Context, _ string, _ domain.RecordType, _ string) ([]domain.Record, error) {
Expand Down Expand Up @@ -44,7 +47,10 @@ func (m *mockDNSSECRepo) CreateZoneWithRecords(_ context.Context, _ *domain.Zone
func (m *mockDNSSECRepo) CreateRecord(_ context.Context, _ *domain.Record) error { return nil }
func (m *mockDNSSECRepo) BatchCreateRecords(_ context.Context, _ []domain.Record) error { return nil }
func (m *mockDNSSECRepo) ListZones(_ context.Context, _ string) ([]domain.Zone, error) {
return nil, nil
if m.listErr != nil {
return nil, m.listErr
}
return m.zones, nil
}
func (m *mockDNSSECRepo) DeleteZone(_ context.Context, _, _ string) error { return nil }
func (m *mockDNSSECRepo) DeleteRecord(_ context.Context, _, _, _ string) error { return nil }
Expand Down Expand Up @@ -90,8 +96,8 @@ func (m *mockDNSSECRepo) CreateKey(_ context.Context, key *domain.DNSSECKey) err
}

func (m *mockDNSSECRepo) ListKeysForZone(_ context.Context, zoneID string) ([]domain.DNSSECKey, error) {
if m.err != nil {
return nil, m.err
if m.keysErr != nil {
return nil, m.keysErr
}
var result []domain.DNSSECKey
for _, k := range m.keys {
Expand Down Expand Up @@ -292,6 +298,55 @@ func TestSignRRSet(t *testing.T) {
}
}

// TestCollectKeyStats_AllZonesFail verifies that CollectKeyStats returns an
// empty slice (not an error) when ListZones succeeds but every
// ListKeysForZone call fails — the "all-zones-fail" edge case.
func TestCollectKeyStats_AllZonesFail(t *testing.T) {
	repo := &mockDNSSECRepo{
		zones: []domain.Zone{
			{ID: "z1", Name: "example.com."},
			{ID: "z2", Name: "test.com."},
		},
		keysErr: errors.New("db error on ListKeysForZone"),
	}

	stats, err := NewDNSSECService(repo).CollectKeyStats(context.Background())

	if err != nil {
		t.Fatalf("CollectKeyStats should not return error on ListKeysForZone failure, got: %v", err)
	}
	if got := len(stats); got != 0 {
		t.Errorf("Expected empty stats slice when all zones fail, got %d", got)
	}
}

// TestCollectKeyStats_Normal verifies CollectKeyStats returns correct stats
// when keys exist for zones.
func TestCollectKeyStats_Normal(t *testing.T) {
	repo := &mockDNSSECRepo{
		zones: []domain.Zone{
			{ID: "z1", Name: "example.com."},
		},
		keys: []domain.DNSSECKey{
			{ID: "k1", ZoneID: "z1", KeyType: "ZSK", Active: true, Algorithm: 13, CreatedAt: time.Now()},
		},
	}
	svc := NewDNSSECService(repo)
	ctx := context.Background()

	stats, err := svc.CollectKeyStats(ctx)
	if err != nil {
		t.Fatalf("CollectKeyStats failed: %v", err)
	}
	// Fatalf (not Errorf): indexing stats[0] below would panic on an empty
	// slice, masking the real failure with an index-out-of-range crash.
	if len(stats) != 1 {
		t.Fatalf("Expected 1 stat, got %d", len(stats))
	}
	if stats[0].KeyType != "zsk" {
		t.Errorf("Expected key type 'zsk', got %s", stats[0].KeyType)
	}
}

func TestAutomateLifecycle_Rollover(t *testing.T) {
repo := &mockDNSSECRepo{}
svc := NewDNSSECService(repo)
Expand Down
13 changes: 12 additions & 1 deletion internal/dns/server/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package server
import (
"container/heap"
"sync"
"sync/atomic"
"time"

"github.com/poyrazK/cloudDNS/internal/infrastructure/metrics"
)

// rateLimiter implements a simple per-IP token bucket with O(1) eviction.
Expand All @@ -14,6 +17,7 @@ type rateLimiter struct {
burst int // max tokens
maxBuckets int // maximum buckets to store (bounds memory)
idleHeap bucketIdleHeap
rateLimited atomic.Uint64
}

type bucket struct {
Expand Down Expand Up @@ -99,6 +103,8 @@ func (rl *rateLimiter) Allow(ip string) bool {
return true
}

rl.rateLimited.Add(1)
metrics.RateLimitedTotal.Inc()
return false
}

Expand Down Expand Up @@ -145,4 +151,9 @@ func (rl *rateLimiter) CleanupLoop(done <-chan struct{}) {
rl.Cleanup()
}
}
}
}

// RateLimited returns the total number of queries rejected by rate limiting
// since this limiter was created. Safe for concurrent use: the counter is an
// atomic and is incremented by Allow on each rejection.
func (rl *rateLimiter) RateLimited() uint64 {
	return rl.rateLimited.Load()
}
19 changes: 19 additions & 0 deletions internal/dns/server/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,22 @@ func TestRateLimiter_MaxBuckets(t *testing.T) {
t.Errorf("Should still have 5 buckets after eviction, got %d", bucketCount)
}
}

// TestRateLimiter_RateLimited verifies that the rejection counter reflects
// exactly the number of denied requests: with burst 1, the first request
// drains the bucket and the second is the single rejection recorded.
func TestRateLimiter_RateLimited(t *testing.T) {
	rl := newRateLimiter(1.0, 1, 100) // 1 token/sec, burst 1

	// First request should succeed (bucket just created with full burst).
	if !rl.Allow("192.168.1.1") {
		t.Fatal("first request should be allowed")
	}

	// Bucket is now empty; this request must be rejected.
	if rl.Allow("192.168.1.1") {
		t.Fatal("second request should be rate limited")
	}

	// Exactly one rejection occurred, so the counter must read 1 — a
	// stronger check than merely "non-zero".
	if got := rl.RateLimited(); got != 1 {
		t.Errorf("expected rate limited count 1, got %d", got)
	}
}
13 changes: 12 additions & 1 deletion internal/dns/server/recursive.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/poyrazK/cloudDNS/internal/dns/packet"
"github.com/poyrazK/cloudDNS/internal/infrastructure/metrics"
)

type recursiveResolver struct {
Expand Down Expand Up @@ -83,6 +84,7 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet.
// Check total resolution timeout
if time.Since(resolveStart) >= recursiveTimeout {
s.Logger.Warn("recursive resolution timed out during root iteration", "name", name)
metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc()
return nil, errors.New(errRecursiveTimeout)
}
rootNS := roots[i]
Expand Down Expand Up @@ -142,11 +144,13 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet.
continue
}

metrics.RecursiveResolutionsTotal.WithLabelValues("success").Inc()
return resp, nil
}

// NXDOMAIN is a definitive answer, so we stop here
if resp.Header.ResCode == 3 {
metrics.RecursiveResolutionsTotal.WithLabelValues("nxdomain").Inc()
return resp, nil
}

Expand All @@ -171,22 +175,29 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet.
// Check total resolution timeout before attempting fallbacks
if time.Since(resolveStart) >= recursiveTimeout {
s.Logger.Warn("recursive resolution timed out before fallback", "name", name)
metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc()
return nil, errors.New(errRecursiveTimeout)
}
s.Logger.Info("iterative resolution failed or inconclusive, trying fallbacks", "name", name)
for _, fallback := range resolver.fallbacks {
// Check total resolution timeout before each fallback query
if time.Since(resolveStart) >= recursiveTimeout {
s.Logger.Warn("recursive resolution timed out during fallback", "name", name)
metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc()
return nil, errors.New(errRecursiveTimeout)
}
serverAddr := net.JoinHostPort(fallback, "53")
// Use sendQueryInternal with RecursionDesired=true for fallbacks
resp, err := s.sendQueryInternal(serverAddr, name, qType, true)
if err == nil && (resp.Header.ResCode == 0 || resp.Header.ResCode == 3) {
if err == nil && resp.Header.ResCode == 0 {
metrics.RecursiveResolutionsTotal.WithLabelValues("success").Inc()
s.Logger.Info("fallback resolution successful", "name", name, "fallback", fallback)
return resp, nil
}
if err == nil && resp.Header.ResCode == 3 {
metrics.RecursiveResolutionsTotal.WithLabelValues("nxdomain").Inc()
return resp, nil
}
if err != nil {
s.Logger.Warn("fallback query failed", "fallback", fallback, "error", err)
}
Expand Down
33 changes: 33 additions & 0 deletions internal/dns/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,29 @@ func (s *Server) automateDNSSEC() {
s.Logger.Error("DNSSEC automation failed for zone", "zone", z.Name, "error", errAutomate)
}
}

// Update DNSSEC key metrics after automation
s.updateDNSSECMetrics(ctx)
}

// updateDNSSECMetrics refreshes the DNSSEC key gauges from the current key
// inventory. The per-key gauges are Reset first so keys that were deleted or
// deactivated since the last sweep disappear from the metric output.
//
// A zone may legitimately hold several active keys at once (KSK + ZSK, or two
// keys of the same type mid-rollover), so this must:
//   - count unique zones, not keys, for DNSSECZonesSigned;
//   - Inc() per key so duplicate (zone, type, algorithm) tuples accumulate
//     instead of being pinned at 1 by Set(1);
//   - report the oldest age per (zone, type) rather than letting whichever
//     key is iterated last overwrite its siblings' ages.
func (s *Server) updateDNSSECMetrics(ctx context.Context) {
	if s.DNSSEC == nil {
		return
	}
	stats, err := s.DNSSEC.CollectKeyStats(ctx)
	if err != nil {
		s.Logger.Debug("failed to collect DNSSEC key stats", "error", err)
		return
	}
	metrics.DNSSECKeysTotal.Reset()
	metrics.DNSSECKeysAgeSeconds.Reset()

	type ageKey struct{ zone, keyType string }
	signedZones := make(map[string]struct{}, len(stats))
	oldestAge := make(map[ageKey]float64)
	for _, st := range stats {
		metrics.DNSSECKeysTotal.WithLabelValues(st.ZoneName, st.KeyType, fmt.Sprintf("%d", st.Algorithm)).Inc()
		k := ageKey{zone: st.ZoneName, keyType: st.KeyType}
		if cur, ok := oldestAge[k]; !ok || st.AgeSeconds > cur {
			oldestAge[k] = st.AgeSeconds
		}
		signedZones[st.ZoneID] = struct{}{}
	}
	for k, age := range oldestAge {
		metrics.DNSSECKeysAgeSeconds.WithLabelValues(k.zone, k.keyType).Set(age)
	}
	metrics.DNSSECZonesSigned.Set(float64(len(signedZones)))
}
Comment on lines +259 to 277
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

updateDNSSECMetrics has three correctness bugs.

  1. DNSSECZonesSigned counts keys, not zones. signedZones++ runs once per key, so a zone with both KSK and ZSK (or multiple algorithms / mid-rollover) inflates the gauge. Track unique zone names (or zone IDs) in a map[string]struct{} and report len(set).
  2. DNSSECKeysTotal.Set(1) cannot represent rollover. During rollover a zone legitimately has two active keys of the same (zone, key_type, algorithm) tuple, but Set(1) keeps the gauge at 1 and the second key disappears from the metric. Use Inc() after Reset() so each label combination accumulates the actual count.
  3. DNSSECKeysAgeSeconds.Set overwrites siblings. With overlap rollover you'll have two keys sharing (zone, key_type); the second Set discards the first key's age. Either expand the label set with the key tag/ID, or aggregate (e.g. report only the oldest active key's age).
♻️ Proposed fix
 	metrics.DNSSECKeysTotal.Reset()
 	metrics.DNSSECKeysAgeSeconds.Reset()
-	signedZones := 0
+	zoneSet := make(map[string]struct{}, len(stats))
+	oldestAge := make(map[string]float64) // key: zone|key_type
 	for _, st := range stats {
-		metrics.DNSSECKeysTotal.WithLabelValues(st.ZoneName, st.KeyType, fmt.Sprintf("%d", st.Algorithm)).Set(1)
-		metrics.DNSSECKeysAgeSeconds.WithLabelValues(st.ZoneName, st.KeyType).Set(st.AgeSeconds)
-		signedZones++
+		metrics.DNSSECKeysTotal.WithLabelValues(st.ZoneName, st.KeyType, fmt.Sprintf("%d", st.Algorithm)).Inc()
+		k := st.ZoneName + "|" + st.KeyType
+		if cur, ok := oldestAge[k]; !ok || st.AgeSeconds > cur {
+			oldestAge[k] = st.AgeSeconds
+		}
+		zoneSet[st.ZoneID] = struct{}{}
 	}
-	metrics.DNSSECZonesSigned.Set(float64(signedZones))
+	for k, age := range oldestAge {
+		parts := strings.SplitN(k, "|", 2)
+		metrics.DNSSECKeysAgeSeconds.WithLabelValues(parts[0], parts[1]).Set(age)
+	}
+	metrics.DNSSECZonesSigned.Set(float64(len(zoneSet)))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@internal/dns/server/server.go` around lines 254 - 271, updateDNSSECMetrics is
wrong: it increments signedZones per key, hides multi-key rollover counts, and
overwrites ages. Fix by (1) tracking unique zones with a map[string]struct{}
(e.g., seenZones) and after the loop set metrics.DNSSECZonesSigned to
len(seenZones); (2) after metrics.DNSSECKeysTotal.Reset() call
metrics.DNSSECKeysTotal.WithLabelValues(st.ZoneName, st.KeyType,
fmt.Sprintf("%d", st.Algorithm)).Inc() for each stat instead of Set(1) so
duplicate label tuples accumulate counts; and (3) avoid overwriting ages by
either adding a unique key identifier label (e.g., st.KeyTag or st.KeyID) to
metrics.DNSSECKeysAgeSeconds.WithLabelValues(...) before Set(st.AgeSeconds) or
by pre-aggregating (choose the oldest/desired age per (ZoneName,KeyType) in a
map and call Set once per group). Apply these changes inside updateDNSSECMetrics
using the existing stats elements returned by s.DNSSEC.CollectKeyStats.


// startInvalidationListener listens for cache invalidation events from Redis pub/sub.
Expand Down Expand Up @@ -821,6 +844,7 @@ func (s *Server) sendAXFRRecord(conn net.Conn, id uint16, q packet.DNSQuestion,
}
s.Logger.Debug("AXFR sent packet", "index", index, "type", pRec.Type)
packet.PutBuffer(resBuffer)
metrics.AXFRBytesTotal.Add(float64(len(fullResp)))
}

// sendTCPError sends a TCP DNS error response with the given RCODE.
Expand Down Expand Up @@ -944,17 +968,20 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac

if data, found := s.Cache.GetInto(cacheKey, request.Header.ID); found {
metrics.CacheOperations.WithLabelValues("l1", "hit").Inc()
metrics.RecordCacheHit()
metrics.QueriesTotal.WithLabelValues(qTypeLabel, "0", protocol).Inc()
metrics.QueryDuration.WithLabelValues("cache_l1").Observe(time.Since(start).Seconds())
err := sendFn(data)
lock.Unlock()
return err
}
metrics.CacheOperations.WithLabelValues("l1", "miss").Inc()
metrics.RecordCacheMiss()

if s.Redis != nil {
if data, remainingTTL, found := s.Redis.GetWithTTL(ctx, cacheKey); found {
metrics.CacheOperations.WithLabelValues("l2", "hit").Inc()
metrics.RecordCacheHit()
metrics.QueriesTotal.WithLabelValues(qTypeLabel, "0", protocol).Inc()
metrics.QueryDuration.WithLabelValues("cache_l2").Observe(time.Since(start).Seconds())
// Rewrite Transaction ID (data is a copy from Redis, safe to mutate)
Expand All @@ -969,6 +996,10 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac
}
s.Cache.Set(cacheKey, data, remainingTTL)
cachedData = data
} else if s.Redis != nil {
// Redis was checked but key not found = L2 miss
metrics.CacheOperations.WithLabelValues("l2", "miss").Inc()
metrics.RecordCacheMiss()
}
}

Expand Down Expand Up @@ -1069,6 +1100,7 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac
qTypeStr := queryTypeToRecordType(q.QType)
records, errRepo := s.Repo.GetRecords(ctx, q.Name, qTypeStr, clientIP)
metrics.QueryDuration.WithLabelValues("database").Observe(time.Since(dbStart).Seconds())
metrics.RecordCacheMiss() // DB lookup is a cache miss for ratio purposes

if errRepo == nil && len(records) > 0 {
for _, rec := range records {
Expand Down Expand Up @@ -1303,6 +1335,7 @@ func (s *Server) handleNotify(ctx context.Context, request *packet.DNSPacket, cl
return nil
}
s.Logger.Info("received NOTIFY", "zone", request.Questions[0].Name, "from", clientIP)
metrics.NotifiesTotal.WithLabelValues(request.Questions[0].Name, "accepted").Inc()

response := packet.NewDNSPacket()
response.Header.ID = request.Header.ID
Expand Down
Loading
Loading