Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"strconv"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -326,3 +327,250 @@ func TestE2EBenchmarkResultMetrics(t *testing.T) {
t.Error("P50 latency should not be negative")
}
}

// flakyDNSServer drops the first N requests, then responds normally.
type flakyDNSServer struct {
server *dns.Server
addr string
ip string
port int
response net.IP
queries atomic.Int64
dropN int64
}

func newFlakyDNSServer(responseIP string, dropFirst int64) (*flakyDNSServer, error) {
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
return nil, err
}

addr := conn.LocalAddr().String()
host, portStr, _ := net.SplitHostPort(addr)
port, _ := strconv.Atoi(portStr)

mock := &flakyDNSServer{
addr: addr,
ip: host,
port: port,
response: net.ParseIP(responseIP),
dropN: dropFirst,
}

mux := dns.NewServeMux()
mux.HandleFunc(".", mock.handleQuery)

mock.server = &dns.Server{
PacketConn: conn,
Handler: mux,
}

go mock.server.ActivateAndServe()
time.Sleep(50 * time.Millisecond)

return mock, nil
}

func (m *flakyDNSServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
n := m.queries.Add(1)
if n <= m.dropN {
return // silently drop → causes timeout on client
}

msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true

for _, q := range r.Question {
switch q.Qtype {
case dns.TypeA:
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: m.response,
})
case dns.TypeTXT:
msg.Answer = append(msg.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{"mock"},
})
}
}

w.WriteMsg(msg)
}

func (m *flakyDNSServer) Close() {
if m.server != nil {
m.server.Shutdown()
}
}

func TestE2ERetryOnTimeout(t *testing.T) {
// Drop first request, respond to second — retry should recover
flaky, err := newFlakyDNSServer("93.184.216.34", 1)
if err != nil {
t.Fatalf("Failed to start flaky DNS: %v", err)
}
defer flaky.Close()

scanner := NewScanner(1, 500*time.Millisecond, flaky.port, "", 1, nil, false)
working := scanner.Scan(context.Background(), sliceToChannel([]string{flaky.ip}))

if len(working) != 1 {
t.Errorf("Expected 1 working (retry should recover), got %d", len(working))
}
if flaky.queries.Load() < 2 {
t.Errorf("Expected at least 2 queries (1 drop + 1 success), got %d", flaky.queries.Load())
}
}

func TestE2ERetryExhausted(t *testing.T) {
// Drop more than ScanRetries+1 requests — should fail
flaky, err := newFlakyDNSServer("93.184.216.34", int64(ScanRetries+1))
if err != nil {
t.Fatalf("Failed to start flaky DNS: %v", err)
}
defer flaky.Close()

scanner := NewScanner(1, 300*time.Millisecond, flaky.port, "", 1, nil, false)
working := scanner.Scan(context.Background(), sliceToChannel([]string{flaky.ip}))

if len(working) != 0 {
t.Errorf("Expected 0 working (retries exhausted), got %d", len(working))
}
}

func TestE2ERetryDomainVerification(t *testing.T) {
// Drop first request (A-record timeout), retry succeeds, TXT passes normally
flaky, err := newFlakyDNSServer("93.184.216.34", 1)
if err != nil {
t.Fatalf("Failed to start flaky DNS: %v", err)
}
defer flaky.Close()

scanner := NewScanner(1, 500*time.Millisecond, flaky.port, "test.example.com", 1, nil, false)
working := scanner.Scan(context.Background(), sliceToChannel([]string{flaky.ip}))

if len(working) != 1 {
t.Errorf("Expected 1 working (domain verify retry should recover), got %d", len(working))
}
if flaky.queries.Load() < 3 {
t.Errorf("Expected at least 3 queries (1 drop + 1 A success + 1 TXT success), got %d", flaky.queries.Load())
}
}

func TestE2ERetryDomainVerificationTXTTimeout(t *testing.T) {
// Use alternating-drop server: drops odd-numbered queries (1st, 3rd, ...)
// Query 1 (A): dropped → Query 2 (A retry): OK → Query 3 (TXT): dropped → Query 4 (TXT retry): OK
flaky, err := newAlternatingDropServer("93.184.216.34")
if err != nil {
t.Fatalf("Failed to start flaky DNS: %v", err)
}
defer flaky.Close()

scanner := NewScanner(1, 500*time.Millisecond, flaky.port, "test.example.com", 1, nil, false)
working := scanner.Scan(context.Background(), sliceToChannel([]string{flaky.ip}))

if len(working) != 1 {
t.Errorf("Expected 1 working (both A and TXT retries should recover), got %d", len(working))
}
if flaky.queries.Load() < 4 {
t.Errorf("Expected at least 4 queries (2 drops + 2 successes), got %d", flaky.queries.Load())
}
}

func TestIsTimeout(t *testing.T) {
if isTimeout(nil) {
t.Error("nil should not be timeout")
}

timeoutErr := &net.OpError{
Op: "read",
Err: &timeoutError{},
}
if !isTimeout(timeoutErr) {
t.Error("net timeout error should be detected")
}
}

// alternatingDropServer drops odd-numbered queries (1st, 3rd, 5th, ...).
type alternatingDropServer struct {
server *dns.Server
addr string
ip string
port int
response net.IP
queries atomic.Int64
}

func newAlternatingDropServer(responseIP string) (*alternatingDropServer, error) {
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
return nil, err
}

addr := conn.LocalAddr().String()
host, portStr, _ := net.SplitHostPort(addr)
port, _ := strconv.Atoi(portStr)

mock := &alternatingDropServer{
addr: addr,
ip: host,
port: port,
response: net.ParseIP(responseIP),
}

mux := dns.NewServeMux()
mux.HandleFunc(".", mock.handleQuery)

mock.server = &dns.Server{
PacketConn: conn,
Handler: mux,
}

go mock.server.ActivateAndServe()
time.Sleep(50 * time.Millisecond)

return mock, nil
}

func (m *alternatingDropServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
n := m.queries.Add(1)
if n%2 == 1 {
return // drop odd-numbered queries
}

msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true

for _, q := range r.Question {
switch q.Qtype {
case dns.TypeA:
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: m.response,
})
case dns.TypeTXT:
msg.Answer = append(msg.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{"mock"},
})
}
}

w.WriteMsg(msg)
}

func (m *alternatingDropServer) Close() {
if m.server != nil {
m.server.Shutdown()
}
}

// timeoutError implements net.Error with Timeout() = true
type timeoutError struct{}

func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
43 changes: 35 additions & 8 deletions scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"

"github.com/miekg/dns"
)

const ScanRetries = 1

const probeDomain = "google.com"

type ScanResult struct {
Expand Down Expand Up @@ -154,9 +157,18 @@ func (s *Scanner) probe(ip string) ScanResult {
m.RecursionDesired = true

addr := fmt.Sprintf("%s:%d", ip, s.port)
reply, rtt, err := client.Exchange(m, addr)
if err != nil {
return ScanResult{IP: ip, Working: false, Error: err}

var reply *dns.Msg
var rtt time.Duration
var err error
for attempt := 0; attempt <= ScanRetries; attempt++ {
reply, rtt, err = client.Exchange(m, addr)
if err == nil {
break
}
if !isTimeout(err) || attempt == ScanRetries {
return ScanResult{IP: ip, Working: false, Error: err}
}
}

if reply == nil || reply.Rcode != dns.RcodeSuccess || len(reply.Answer) == 0 {
Expand All @@ -179,20 +191,35 @@ func (s *Scanner) probe(ip string) ScanResult {
m2.RecursionDesired = true
m2.SetEdns0(EDNSBufferSize, false)

reply2, rtt2, err := client.Exchange(m2, addr)
if err != nil {
return ScanResult{IP: ip, Working: false, Error: err}
for attempt := 0; attempt <= ScanRetries; attempt++ {
reply, rtt, err = client.Exchange(m2, addr)
if err == nil {
break
}
if !isTimeout(err) || attempt == ScanRetries {
return ScanResult{IP: ip, Working: false, Error: err}
}
}

if reply2 != nil {
return ScanResult{IP: ip, Working: true, RTT: rtt2}
if reply != nil {
return ScanResult{IP: ip, Working: true, RTT: rtt}
}
return ScanResult{IP: ip, Working: false}
}

return ScanResult{IP: ip, Working: true, RTT: rtt}
}

func isTimeout(err error) bool {
if err == nil {
return false
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return true
}
return strings.Contains(err.Error(), "i/o timeout")
}

var privateRanges = []string{
"10.0.0.0/8",
"172.16.0.0/12",
Expand Down