diff --git a/cmd/clouddns/main.go b/cmd/clouddns/main.go index 7fefcb7..8809647 100644 --- a/cmd/clouddns/main.go +++ b/cmd/clouddns/main.go @@ -178,6 +178,21 @@ func run(ctx context.Context) error { } } }() + + // Periodic zone/record count metrics + var metricsCollector *metrics.DerivedMetricCollector + go func() { + interval := 30 * time.Second + if os.Getenv("TEST_MODE") == "true" { + interval = 10 * time.Millisecond + } + counter := metrics.NewZoneRecordCounter(repo, interval) + counter.Start(runCtx) + metricsCollector = metrics.NewDerivedMetricCollector(interval) + <-runCtx.Done() + counter.Stop() + metricsCollector.Stop() + }() } var cacheInvalidator ports.CacheInvalidator diff --git a/internal/core/services/dnssec_service.go b/internal/core/services/dnssec_service.go index f0f275a..d2e24dd 100644 --- a/internal/core/services/dnssec_service.go +++ b/internal/core/services/dnssec_service.go @@ -8,7 +8,9 @@ import ( "crypto/rand" "crypto/x509" "fmt" + "log/slog" "math" + "strings" "time" "github.com/google/uuid" @@ -19,12 +21,13 @@ import ( // DNSSECService provides functionality for managing DNSSEC keys and signing RRsets. type DNSSECService struct { - repo ports.DNSRepository + repo ports.DNSRepository + logger *slog.Logger } // NewDNSSECService creates and returns a new DNSSECService instance. func NewDNSSECService(repo ports.DNSRepository) *DNSSECService { - return &DNSSECService{repo: repo} + return &DNSSECService{repo: repo, logger: slog.Default()} } // GenerateKey creates a new ECDSA P-256 key pair for a zone @@ -199,3 +202,44 @@ func (s *DNSSECService) SignRRSet(ctx context.Context, zoneName string, zoneID s return sigs, nil } + +// KeyStats holds DNSSEC key statistics for metrics. +type KeyStats struct { + ZoneID string + ZoneName string + KeyType string + Algorithm int + AgeSeconds float64 +} + +// CollectKeyStats returns statistics for all active DNSSEC keys. +// Used by the metrics collector to update DNSSEC key age metrics. 
+func (s *DNSSECService) CollectKeyStats(ctx context.Context) ([]KeyStats, error) { + zones, err := s.repo.ListZones(ctx, "") + if err != nil { + return nil, err + } + + var stats []KeyStats + now := time.Now() + for _, zone := range zones { + keys, err := s.repo.ListKeysForZone(ctx, zone.ID) + if err != nil { + s.logger.Debug("failed to list keys for zone", "zone", zone.Name, "error", err) + continue + } + for _, k := range keys { + if !k.Active { + continue + } + stats = append(stats, KeyStats{ + ZoneID: zone.ID, + ZoneName: zone.Name, + KeyType: strings.ToLower(k.KeyType), + Algorithm: k.Algorithm, + AgeSeconds: now.Sub(k.CreatedAt).Seconds(), + }) + } + } + return stats, nil +} diff --git a/internal/core/services/dnssec_service_test.go b/internal/core/services/dnssec_service_test.go index 0ee4826..4afdaa8 100644 --- a/internal/core/services/dnssec_service_test.go +++ b/internal/core/services/dnssec_service_test.go @@ -13,8 +13,11 @@ import ( ) type mockDNSSECRepo struct { - keys []domain.DNSSECKey - err error + keys []domain.DNSSECKey + zones []domain.Zone + err error + keysErr error + listErr error } func (m *mockDNSSECRepo) GetRecords(_ context.Context, _ string, _ domain.RecordType, _ string) ([]domain.Record, error) { @@ -44,7 +47,10 @@ func (m *mockDNSSECRepo) CreateZoneWithRecords(_ context.Context, _ *domain.Zone func (m *mockDNSSECRepo) CreateRecord(_ context.Context, _ *domain.Record) error { return nil } func (m *mockDNSSECRepo) BatchCreateRecords(_ context.Context, _ []domain.Record) error { return nil } func (m *mockDNSSECRepo) ListZones(_ context.Context, _ string) ([]domain.Zone, error) { - return nil, nil + if m.listErr != nil { + return nil, m.listErr + } + return m.zones, nil } func (m *mockDNSSECRepo) DeleteZone(_ context.Context, _, _ string) error { return nil } func (m *mockDNSSECRepo) DeleteRecord(_ context.Context, _, _, _ string) error { return nil } @@ -90,8 +96,8 @@ func (m *mockDNSSECRepo) CreateKey(_ context.Context, key 
*domain.DNSSECKey) err } func (m *mockDNSSECRepo) ListKeysForZone(_ context.Context, zoneID string) ([]domain.DNSSECKey, error) { - if m.err != nil { - return nil, m.err + if m.keysErr != nil { + return nil, m.keysErr } var result []domain.DNSSECKey for _, k := range m.keys { @@ -292,6 +298,55 @@ func TestSignRRSet(t *testing.T) { } } +// TestCollectKeyStats_AllZonesFail verifies that CollectKeyStats returns an +// empty slice (not an error) when ListZones succeeds but all ListKeysForZone +// calls fail. This is the "all-zones-fail" edge case. +func TestCollectKeyStats_AllZonesFail(t *testing.T) { + repo := &mockDNSSECRepo{ + zones: []domain.Zone{ + {ID: "z1", Name: "example.com."}, + {ID: "z2", Name: "test.com."}, + }, + keysErr: errors.New("db error on ListKeysForZone"), + } + svc := NewDNSSECService(repo) + ctx := context.Background() + + stats, err := svc.CollectKeyStats(ctx) + if err != nil { + t.Fatalf("CollectKeyStats should not return error on ListKeysForZone failure, got: %v", err) + } + if len(stats) != 0 { + t.Errorf("Expected empty stats slice when all zones fail, got %d", len(stats)) + } +} + +// TestCollectKeyStats_Normal verifies CollectKeyStats returns correct stats +// when keys exist for zones. 
+func TestCollectKeyStats_Normal(t *testing.T) { + repo := &mockDNSSECRepo{ + zones: []domain.Zone{ + {ID: "z1", Name: "example.com."}, + }, + keys: []domain.DNSSECKey{ + {ID: "k1", ZoneID: "z1", KeyType: "ZSK", Active: true, Algorithm: 13, CreatedAt: time.Now()}, + }, + } + svc := NewDNSSECService(repo) + ctx := context.Background() + + stats, err := svc.CollectKeyStats(ctx) + if err != nil { + t.Fatalf("CollectKeyStats failed: %v", err) + } + if len(stats) != 1 { + t.Errorf("Expected 1 stat, got %d", len(stats)) + } + if stats[0].KeyType != "zsk" { + t.Errorf("Expected key type 'zsk', got %s", stats[0].KeyType) + } +} + func TestAutomateLifecycle_Rollover(t *testing.T) { repo := &mockDNSSECRepo{} svc := NewDNSSECService(repo) diff --git a/internal/dns/server/ratelimit.go b/internal/dns/server/ratelimit.go index 63f6792..54beca2 100644 --- a/internal/dns/server/ratelimit.go +++ b/internal/dns/server/ratelimit.go @@ -3,7 +3,10 @@ package server import ( "container/heap" "sync" + "sync/atomic" "time" + + "github.com/poyrazK/cloudDNS/internal/infrastructure/metrics" ) // rateLimiter implements a simple per-IP token bucket with O(1) eviction. @@ -14,6 +17,7 @@ type rateLimiter struct { burst int // max tokens maxBuckets int // maximum buckets to store (bounds memory) idleHeap bucketIdleHeap + rateLimited atomic.Uint64 } type bucket struct { @@ -99,6 +103,8 @@ func (rl *rateLimiter) Allow(ip string) bool { return true } + rl.rateLimited.Add(1) + metrics.RateLimitedTotal.Inc() return false } @@ -145,4 +151,9 @@ func (rl *rateLimiter) CleanupLoop(done <-chan struct{}) { rl.Cleanup() } } -} \ No newline at end of file +} + +// RateLimited returns the total number of queries rejected by rate limiting. 
+func (rl *rateLimiter) RateLimited() uint64 { + return rl.rateLimited.Load() +} diff --git a/internal/dns/server/ratelimit_test.go b/internal/dns/server/ratelimit_test.go index ac6c0a2..0d8daaa 100644 --- a/internal/dns/server/ratelimit_test.go +++ b/internal/dns/server/ratelimit_test.go @@ -111,3 +111,22 @@ func TestRateLimiter_MaxBuckets(t *testing.T) { t.Errorf("Should still have 5 buckets after eviction, got %d", bucketCount) } } + +func TestRateLimiter_RateLimited(t *testing.T) { + rl := newRateLimiter(1.0, 1, 100) // 1 token/sec, burst 1 + + // First request should succeed (bucket just created with full burst) + if !rl.Allow("192.168.1.1") { + t.Fatal("first request should be allowed") + } + + // Exhaust the bucket + if rl.Allow("192.168.1.1") { + t.Fatal("second request should be rate limited") + } + + // Now RateLimited() should be > 0 + if rl.RateLimited() == 0 { + t.Errorf("expected rate limited count > 0 after exhaustion") + } +} diff --git a/internal/dns/server/recursive.go b/internal/dns/server/recursive.go index 76f188e..ba51875 100644 --- a/internal/dns/server/recursive.go +++ b/internal/dns/server/recursive.go @@ -11,6 +11,7 @@ import ( "time" "github.com/poyrazK/cloudDNS/internal/dns/packet" + "github.com/poyrazK/cloudDNS/internal/infrastructure/metrics" ) type recursiveResolver struct { @@ -83,6 +84,7 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet. // Check total resolution timeout if time.Since(resolveStart) >= recursiveTimeout { s.Logger.Warn("recursive resolution timed out during root iteration", "name", name) + metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc() return nil, errors.New(errRecursiveTimeout) } rootNS := roots[i] @@ -142,11 +144,13 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet. 
continue } + metrics.RecursiveResolutionsTotal.WithLabelValues("success").Inc() return resp, nil } // NXDOMAIN is a definitive answer, so we stop here if resp.Header.ResCode == 3 { + metrics.RecursiveResolutionsTotal.WithLabelValues("nxdomain").Inc() return resp, nil } @@ -171,6 +175,7 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet. // Check total resolution timeout before attempting fallbacks if time.Since(resolveStart) >= recursiveTimeout { s.Logger.Warn("recursive resolution timed out before fallback", "name", name) + metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc() return nil, errors.New(errRecursiveTimeout) } s.Logger.Info("iterative resolution failed or inconclusive, trying fallbacks", "name", name) @@ -178,15 +183,21 @@ func (s *Server) resolveRecursive(name string, qType packet.QueryType) (*packet. // Check total resolution timeout before each fallback query if time.Since(resolveStart) >= recursiveTimeout { s.Logger.Warn("recursive resolution timed out during fallback", "name", name) + metrics.RecursiveResolutionsTotal.WithLabelValues("timeout").Inc() return nil, errors.New(errRecursiveTimeout) } serverAddr := net.JoinHostPort(fallback, "53") // Use sendQueryInternal with RecursionDesired=true for fallbacks resp, err := s.sendQueryInternal(serverAddr, name, qType, true) - if err == nil && (resp.Header.ResCode == 0 || resp.Header.ResCode == 3) { + if err == nil && resp.Header.ResCode == 0 { + metrics.RecursiveResolutionsTotal.WithLabelValues("success").Inc() s.Logger.Info("fallback resolution successful", "name", name, "fallback", fallback) return resp, nil } + if err == nil && resp.Header.ResCode == 3 { + metrics.RecursiveResolutionsTotal.WithLabelValues("nxdomain").Inc() + return resp, nil + } if err != nil { s.Logger.Warn("fallback query failed", "fallback", fallback, "error", err) } diff --git a/internal/dns/server/server.go b/internal/dns/server/server.go index 327c790..68138c0 100644 --- 
a/internal/dns/server/server.go
+++ b/internal/dns/server/server.go
@@ -251,6 +251,29 @@ func (s *Server) automateDNSSEC() {
 			s.Logger.Error("DNSSEC automation failed for zone", "zone", z.Name, "error", errAutomate)
 		}
 	}
+
+	// Update DNSSEC key metrics after automation
+	s.updateDNSSECMetrics(ctx)
+}
+
+func (s *Server) updateDNSSECMetrics(ctx context.Context) {
+	if s.DNSSEC == nil {
+		return
+	}
+	stats, err := s.DNSSEC.CollectKeyStats(ctx)
+	if err != nil {
+		s.Logger.Debug("failed to collect DNSSEC key stats", "error", err)
+		return
+	}
+	metrics.DNSSECKeysTotal.Reset()
+	metrics.DNSSECKeysAgeSeconds.Reset()
+	signedZones := make(map[string]struct{}, len(stats)) // distinct zone IDs, not a key count
+	for _, st := range stats {
+		metrics.DNSSECKeysTotal.WithLabelValues(st.ZoneName, st.KeyType, fmt.Sprintf("%d", st.Algorithm)).Inc() // Inc, not Set(1): keys sharing labels must accumulate
+		metrics.DNSSECKeysAgeSeconds.WithLabelValues(st.ZoneName, st.KeyType).Set(st.AgeSeconds)
+		signedZones[st.ZoneID] = struct{}{}
+	}
+	metrics.DNSSECZonesSigned.Set(float64(len(signedZones))) // a zone with both KSK and ZSK counts once
+}
 
 // startInvalidationListener listens for cache invalidation events from Redis pub/sub.
@@ -821,6 +844,7 @@ func (s *Server) sendAXFRRecord(conn net.Conn, id uint16, q packet.DNSQuestion,
 	}
 	s.Logger.Debug("AXFR sent packet", "index", index, "type", pRec.Type)
 	packet.PutBuffer(resBuffer)
+	metrics.AXFRBytesTotal.Add(float64(len(fullResp)))
 }
 
 // sendTCPError sends a TCP DNS error response with the given RCODE.
@@ -944,6 +968,7 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac
 
 	if data, found := s.Cache.GetInto(cacheKey, request.Header.ID); found {
 		metrics.CacheOperations.WithLabelValues("l1", "hit").Inc()
+		metrics.RecordCacheHit()
 		metrics.QueriesTotal.WithLabelValues(qTypeLabel, "0", protocol).Inc()
 		metrics.QueryDuration.WithLabelValues("cache_l1").Observe(time.Since(start).Seconds())
 		err := sendFn(data)
@@ -951,10 +976,11 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac
 		return err
 	}
 	metrics.CacheOperations.WithLabelValues("l1", "miss").Inc()
+	// Only L1 events feed the hit-ratio counters: CacheHitRatio is documented as
+	// the L1 ratio, and counting L2/DB outcomes too would record 2-3 misses for
+	// a single uncached query (and a hit after a miss on an L2 hit).
+	metrics.RecordCacheMiss()
 
 	if s.Redis != nil {
 		if data, remainingTTL, found := s.Redis.GetWithTTL(ctx, cacheKey); found {
 			metrics.CacheOperations.WithLabelValues("l2", "hit").Inc()
 			metrics.QueriesTotal.WithLabelValues(qTypeLabel, "0", protocol).Inc()
 			metrics.QueryDuration.WithLabelValues("cache_l2").Observe(time.Since(start).Seconds())
 			// Rewrite Transaction ID (data is a copy from Redis, safe to mutate)
@@ -969,6 +998,9 @@ func (s *Server) handlePacket(ctx context.Context, data []byte, srcAddr interfac
 			}
 			s.Cache.Set(cacheKey, data, remainingTTL)
 			cachedData = data
+		} else {
+			// Redis was checked but key not found = L2 miss (we are already
+			// inside the s.Redis != nil branch, so no second nil check).
+			metrics.CacheOperations.WithLabelValues("l2", "miss").Inc()
 		}
 	}
 
@@ -1303,6 +1332,7 @@ func (s *Server) handleNotify(ctx context.Context, request *packet.DNSPacket, cl
 		return nil
 	}
 	s.Logger.Info("received NOTIFY", "zone",
request.Questions[0].Name, "from", clientIP) + metrics.NotifiesTotal.WithLabelValues(request.Questions[0].Name, "accepted").Inc() response := packet.NewDNSPacket() response.Header.ID = request.Header.ID diff --git a/internal/infrastructure/metrics/metrics.go b/internal/infrastructure/metrics/metrics.go index d3b0339..24ba8c4 100644 --- a/internal/infrastructure/metrics/metrics.go +++ b/internal/infrastructure/metrics/metrics.go @@ -3,10 +3,38 @@ package metrics import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/poyrazK/cloudDNS/internal/core/domain" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) +// DNSBuckets are appropriate for DNS query latencies ranging from +// sub-millisecond (local cache) to seconds (AXFR/recursive lookups). +var DNSBuckets = []float64{ + 0.00005, // 50µs — L1 cache hit + 0.0001, // 100µs — L2 cache hit + 0.00025, // 250µs — fast DB + 0.0005, // 500µs — typical DB + 0.001, // 1ms + 0.0025, // 2.5ms + 0.005, // 5ms + 0.01, // 10ms — network latency + 0.025, // 25ms + 0.05, // 50ms — slow network + 0.1, // 100ms + 0.25, // 250ms + 0.5, // 500ms + 1.0, // 1s + 2.5, // 2.5s — AXFR timeout + 5.0, // 5s + 10.0, // 10s — slow AXFR +} + var ( // QueriesTotal tracks total DNS queries processed QueriesTotal *prometheus.CounterVec @@ -20,8 +48,96 @@ var ( DBConnectionsActive prometheus.Gauge // BGPAnnounced indicates if the node is currently announcing routes via BGP BGPAnnounced prometheus.Gauge + + // cacheMissCount and cacheHitCount track raw counts for ratio computation. + // Thread-safe via atomic operations. 
+ cacheHitCount atomic.Uint64 + cacheMissCount atomic.Uint64 + + // DNSSECKeysTotal tracks DNSSEC keys by zone, key type, and algorithm + DNSSECKeysTotal *prometheus.GaugeVec + // DNSSECKeysAgeSeconds tracks age of active signing keys + DNSSECKeysAgeSeconds *prometheus.GaugeVec + // DNSSECZonesSigned tracks number of zones with DNSSEC active + DNSSECZonesSigned prometheus.Gauge + + // NotifiesTotal tracks incoming NOTIFY requests + NotifiesTotal *prometheus.CounterVec + // RateLimitedTotal tracks queries rejected by rate limiter + RateLimitedTotal prometheus.Counter + // AXFRBytesTotal tracks bytes transferred via zone transfers + AXFRBytesTotal prometheus.Counter + // RecursiveResolutionsTotal tracks recursive resolution outcomes + RecursiveResolutionsTotal *prometheus.CounterVec + + // ZonesTotal tracks total hosted zones + ZonesTotal prometheus.Gauge + // RecordsTotal tracks total records across all zones + RecordsTotal prometheus.Gauge + // CacheHitRatio tracks the L1 cache hit ratio (computed periodically) + CacheHitRatio prometheus.Gauge ) +// DerivedMetricCollector periodically computes derived metrics (e.g., cache hit ratio) +// to avoid per-query overhead. +type DerivedMetricCollector struct { + interval time.Duration + stopCh chan struct{} + doneCh chan struct{} +} + +// NewDerivedMetricCollector creates a collector that updates derived metrics at the given interval. 
+func NewDerivedMetricCollector(interval time.Duration) *DerivedMetricCollector { + c := &DerivedMetricCollector{ + interval: interval, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go c.run() + return c +} + +func (c *DerivedMetricCollector) run() { + defer close(c.doneCh) + ticker := time.NewTicker(c.interval) + defer ticker.Stop() + for { + select { + case <-c.stopCh: + return + case <-ticker.C: + c.compute() + } + } +} + +func (c *DerivedMetricCollector) compute() { + hits := cacheHitCount.Load() + misses := cacheMissCount.Load() + total := hits + misses + if total > 0 { + CacheHitRatio.Set(float64(hits) / float64(total)) + } +} + +// Stop gracefully stops the collector goroutine. +func (c *DerivedMetricCollector) Stop() { + close(c.stopCh) + <-c.doneCh +} + +// RecordCacheHit records a cache hit for derived metric computation. +// Thread-safe via atomic operations. +func RecordCacheHit() { + cacheHitCount.Add(1) +} + +// RecordCacheMiss records a cache miss for derived metric computation. +// Thread-safe via atomic operations. 
+func RecordCacheMiss() { + cacheMissCount.Add(1) +} + func init() { // QueriesTotal tracks total DNS queries processed QueriesTotal = promauto.NewCounterVec(prometheus.CounterOpts{ @@ -29,11 +145,11 @@ func init() { Help: "Total number of DNS queries processed", }, []string{"qtype", "rcode", "protocol"}) - // QueryDuration tracks query processing time + // QueryDuration tracks query processing time (now with DNS-appropriate buckets) QueryDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ Name: "clouddns_query_duration_seconds", Help: "Histogram of query processing duration", - Buckets: prometheus.DefBuckets, + Buckets: DNSBuckets, }, []string{"source"}) // CacheOperations tracks L1/L2 cache hits and misses @@ -59,4 +175,135 @@ func init() { Name: "clouddns_bgp_announced", Help: "Binary indicator of BGP announcement status (1 = announcing, 0 = withdrawn)", }) + + // CacheHitRatio tracks the computed cache hit ratio + CacheHitRatio = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "clouddns_cache_hit_ratio", + Help: "L1 cache hit ratio (hits / total cache operations), computed every 30s", + }) + + // DNSSECKeysTotal tracks DNSSEC keys by zone, key type, and algorithm + DNSSECKeysTotal = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "clouddns_dnssec_keys_total", + Help: "Total number of active DNSSEC keys", + }, []string{"zone", "key_type", "algorithm"}) + + // DNSSECKeysAgeSeconds tracks age of active signing keys + DNSSECKeysAgeSeconds = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "clouddns_dnssec_keys_age_seconds", + Help: "Age of active DNSSEC signing keys in seconds", + }, []string{"zone", "key_type"}) + + // DNSSECZonesSigned tracks number of zones with DNSSEC active + DNSSECZonesSigned = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "clouddns_zones_signed", + Help: "Number of zones with DNSSEC enabled and active", + }) + + // NotifiesTotal tracks incoming NOTIFY requests + NotifiesTotal = promauto.NewCounterVec(prometheus.CounterOpts{ 
+ Name: "clouddns_notifies_total", + Help: "Total number of NOTIFY messages received", + }, []string{"zone", "result"}) + + // RateLimitedTotal tracks queries rejected by rate limiter + RateLimitedTotal = promauto.NewCounter(prometheus.CounterOpts{ + Name: "clouddns_rate_limited_total", + Help: "Total number of queries rejected by rate limiter", + }) + + // AXFRBytesTotal tracks bytes transferred via zone transfers + AXFRBytesTotal = promauto.NewCounter(prometheus.CounterOpts{ + Name: "clouddns_axfr_bytes_total", + Help: "Total bytes transferred via AXFR/IXFR", + }) + + // RecursiveResolutionsTotal tracks recursive resolution outcomes + RecursiveResolutionsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "clouddns_recursive_resolutions_total", + Help: "Total number of recursive resolution outcomes", + }, []string{"result"}) + + // ZonesTotal tracks total hosted zones + ZonesTotal = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "clouddns_zones_total", + Help: "Total number of hosted zones", + }) + + // RecordsTotal tracks total records across all zones + RecordsTotal = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "clouddns_records_total", + Help: "Total number of records across all zones", + }) +} + +// ZoneRecordCounter provides a way to update zone/record count metrics periodically. +type ZoneRecordCounter struct { + repo ZoneRecordRepo + interval time.Duration + stopCh chan struct{} + doneCh chan struct{} + wg sync.WaitGroup +} + +// ZoneRecordRepo is the interface for fetching zone and record counts. +type ZoneRecordRepo interface { + ListZones(ctx context.Context, tenantID string) ([]domain.Zone, error) + ListRecordsForZone(ctx context.Context, zoneID string, tenantID string) ([]domain.Record, error) +} + +// NewZoneRecordCounter creates a counter that updates zone/record metrics periodically. 
+func NewZoneRecordCounter(repo ZoneRecordRepo, interval time.Duration) *ZoneRecordCounter {
+	return &ZoneRecordCounter{
+		repo:     repo,
+		interval: interval,
+		stopCh:   make(chan struct{}),
+		doneCh:   make(chan struct{}),
+	}
+}
+
+// Start begins the periodic collection goroutine.
+// The provided ctx is used as the parent context for cancellation.
+func (c *ZoneRecordCounter) Start(ctx context.Context) {
+	c.wg.Add(1)
+	go func() {
+		defer c.wg.Done()
+		ticker := time.NewTicker(c.interval)
+		defer ticker.Stop()
+		// Run once immediately
+		c.collect(ctx)
+		for {
+			select {
+			case <-ctx.Done():
+				// Honor context cancellation as documented above; without
+				// this case only Stop() could end the loop.
+				return
+			case <-c.stopCh:
+				return
+			case <-ticker.C:
+				c.collect(ctx)
+			}
+		}
+	}()
+}
+
+// Stop gracefully stops the collector. It must be called at most once;
+// a second call would close already-closed channels and panic.
+func (c *ZoneRecordCounter) Stop() {
+	close(c.stopCh)
+	c.wg.Wait()
+	close(c.doneCh)
+}
+
+// collect refreshes ZonesTotal and RecordsTotal from the repository.
+// Per-zone record-list errors are skipped so one bad zone does not
+// zero out the totals.
+func (c *ZoneRecordCounter) collect(ctx context.Context) {
+	zones, err := c.repo.ListZones(ctx, "")
+	if err != nil {
+		return
+	}
+	ZonesTotal.Set(float64(len(zones)))
+
+	// Count records across all zones
+	var totalRecords int
+	for _, z := range zones {
+		records, err := c.repo.ListRecordsForZone(ctx, z.ID, "")
+		if err == nil {
+			totalRecords += len(records)
+		}
+	}
+	RecordsTotal.Set(float64(totalRecords))
+}
diff --git a/internal/infrastructure/metrics/metrics_test.go b/internal/infrastructure/metrics/metrics_test.go
index 7354ab7..6b645d9 100644
--- a/internal/infrastructure/metrics/metrics_test.go
+++ b/internal/infrastructure/metrics/metrics_test.go
@@ -1,7 +1,11 @@
 package metrics
 
 import (
+	"context"
 	"testing"
+	"time"
+
+	"github.com/poyrazK/cloudDNS/internal/core/domain"
 )
 
 func TestMetricsDeclarations(t *testing.T) {
@@ -9,12 +13,24 @@ func TestMetricsDeclarations(t *testing.T) {
 		name   string
 		metric interface{}
 	}{
+		// Original
 		{"QueriesTotal", QueriesTotal},
 		{"QueryDuration", QueryDuration},
 		{"CacheOperations", CacheOperations},
 		{"ActiveWorkers", ActiveWorkers},
 		{"DBConnectionsActive", DBConnectionsActive},
 		{"BGPAnnounced", BGPAnnounced},
+		// New
+		{"CacheHitRatio",
CacheHitRatio}, + {"DNSSECKeysTotal", DNSSECKeysTotal}, + {"DNSSECKeysAgeSeconds", DNSSECKeysAgeSeconds}, + {"DNSSECZonesSigned", DNSSECZonesSigned}, + {"NotifiesTotal", NotifiesTotal}, + {"RateLimitedTotal", RateLimitedTotal}, + {"AXFRBytesTotal", AXFRBytesTotal}, + {"RecursiveResolutionsTotal", RecursiveResolutionsTotal}, + {"ZonesTotal", ZonesTotal}, + {"RecordsTotal", RecordsTotal}, } for _, tt := range tests { @@ -25,3 +41,150 @@ func TestMetricsDeclarations(t *testing.T) { }) } } + +func TestDNSBuckets(t *testing.T) { + if len(DNSBuckets) == 0 { + t.Fatal("DNSBuckets is empty") + } + + // Verify minimum is sub-millisecond (50µs) + if DNSBuckets[0] != 0.00005 { + t.Errorf("DNSBuckets[0] = %v, want 0.00005 (50µs)", DNSBuckets[0]) + } + + // Verify maximum covers slow AXFR (10s) + if DNSBuckets[len(DNSBuckets)-1] != 10.0 { + t.Errorf("DNSBuckets last = %v, want 10.0", DNSBuckets[len(DNSBuckets)-1]) + } + + // Verify buckets are monotonically increasing + for i := 1; i < len(DNSBuckets); i++ { + if DNSBuckets[i] <= DNSBuckets[i-1] { + t.Errorf("DNSBuckets not monotonically increasing at index %d: %v <= %v", i, DNSBuckets[i], DNSBuckets[i-1]) + } + } +} + +func TestRecordCacheHitMiss(t *testing.T) { + // Save original values + origHit := cacheHitCount.Load() + origMiss := cacheMissCount.Load() + + // Reset for test + cacheHitCount.Store(0) + cacheMissCount.Store(0) + + RecordCacheHit() + RecordCacheHit() + RecordCacheMiss() + + if cacheHitCount.Load() != 2 { + t.Errorf("cacheHitCount = %d, want 2", cacheHitCount.Load()) + } + if cacheMissCount.Load() != 1 { + t.Errorf("cacheMissCount = %d, want 1", cacheMissCount.Load()) + } + + // Restore + cacheHitCount.Store(origHit) + cacheMissCount.Store(origMiss) +} + +func TestDerivedMetricCollector(t *testing.T) { + // Save and reset + origHit := cacheHitCount.Load() + origMiss := cacheMissCount.Load() + cacheHitCount.Store(0) + cacheMissCount.Store(0) + + // Simulate 80% hit rate + cacheHitCount.Store(80) + 
cacheMissCount.Store(20) + + collector := NewDerivedMetricCollector(50 * time.Millisecond) + collector.compute() + + // Restore + cacheHitCount.Store(origHit) + cacheMissCount.Store(origMiss) + collector.Stop() +} + +func TestDerivedMetricCollector_Stop(t *testing.T) { + collector := NewDerivedMetricCollector(time.Hour) + collector.Stop() + // Should not hang or panic +} + +// mockZoneRecordRepo is a mock implementation of ZoneRecordRepo for testing. +type mockZoneRecordRepo struct { + zones []domain.Zone + records map[string][]domain.Record // keyed by zoneID +} + +func (m *mockZoneRecordRepo) ListZones(_ context.Context, _ string) ([]domain.Zone, error) { + return m.zones, nil +} + +func (m *mockZoneRecordRepo) ListRecordsForZone(_ context.Context, zoneID string, _ string) ([]domain.Record, error) { + if recs, ok := m.records[zoneID]; ok { + return recs, nil + } + return nil, nil +} + +func TestZoneRecordCounter(t *testing.T) { + repo := &mockZoneRecordRepo{ + zones: []domain.Zone{ + {ID: "z1", Name: "example.com."}, + {ID: "z2", Name: "test.com."}, + }, + records: map[string][]domain.Record{ + "z1": { + {ID: "r1", ZoneID: "z1", Name: "www.example.com.", Type: "A"}, + {ID: "r2", ZoneID: "z1", Name: "www.example.com.", Type: "AAAA"}, + }, + "z2": { + {ID: "r3", ZoneID: "z2", Name: "test.com.", Type: "MX"}, + }, + }, + } + + counter := NewZoneRecordCounter(repo, 50*time.Millisecond) + ctx := context.Background() + + counter.Start(ctx) + + // Let it collect at least once + time.Sleep(100 * time.Millisecond) + counter.Stop() + // Should not hang or panic +} + +func TestZoneRecordCounter_EmptyZones(t *testing.T) { + repo := &mockZoneRecordRepo{ + zones: []domain.Zone{}, + records: map[string][]domain.Record{}, + } + + counter := NewZoneRecordCounter(repo, 50*time.Millisecond) + ctx := context.Background() + + counter.Start(ctx) + time.Sleep(100 * time.Millisecond) + counter.Stop() +} + +func TestZoneRecordCounter_ZonesWithNoRecords(t *testing.T) { + repo := 
&mockZoneRecordRepo{ + zones: []domain.Zone{{ID: "z1", Name: "empty.com."}}, + records: map[string][]domain.Record{}, // no records for z1 + } + + counter := NewZoneRecordCounter(repo, 50*time.Millisecond) + ctx := context.Background() + + counter.Start(ctx) + time.Sleep(100 * time.Millisecond) + counter.Stop() +}