diff --git a/e2e_test.go b/e2e_test.go new file mode 100644 index 0000000..2e716ca --- /dev/null +++ b/e2e_test.go @@ -0,0 +1,254 @@ +package main + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "github.com/miekg/dns" +) + +type mockDNSServer struct { + server *dns.Server + addr string + ip string + port int + response net.IP +} + +func newMockDNSServer(responseIP string) (*mockDNSServer, error) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + return nil, err + } + + addr := conn.LocalAddr().String() + host, portStr, _ := net.SplitHostPort(addr) + port, _ := strconv.Atoi(portStr) + + mock := &mockDNSServer{ + addr: addr, + ip: host, + port: port, + response: net.ParseIP(responseIP), + } + + mux := dns.NewServeMux() + mux.HandleFunc(".", mock.handleQuery) + + mock.server = &dns.Server{ + PacketConn: conn, + Handler: mux, + } + + go mock.server.ActivateAndServe() + time.Sleep(50 * time.Millisecond) + + return mock, nil +} + +func (m *mockDNSServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, q := range r.Question { + switch q.Qtype { + case dns.TypeA: + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: m.response, + }) + case dns.TypeTXT: + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60}, + Txt: []string{"mock"}, + }) + } + } + + w.WriteMsg(msg) +} + +func (m *mockDNSServer) Close() { + if m.server != nil { + m.server.Shutdown() + } +} + +func TestE2EBasicScan(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") + result := scanner.Probe(mock.ip) + + if !result.Working { + t.Errorf("Expected working, got error: %v", result.Error) + } + if result.Suspicious { + t.Error("Public IP should not be suspicious") + } +} + +func TestE2EHijackingDetection(t *testing.T) { + mock, err := newMockDNSServer("10.10.34.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") + result := scanner.Probe(mock.ip) + + if result.Working { + t.Error("Private IP response should not be working") + } + if !result.Suspicious { + t.Error("Private IP response should be suspicious") + } +} + +func TestE2EDomainVerification(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + scanner := NewScanner(1, 2*time.Second, mock.port, nil, "test.example.com") + result := scanner.Probe(mock.ip) + + if !result.Working { + t.Errorf("Domain verification should pass, got: %v", result.Error) + } +} + +func TestE2EBurstTest(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + result := BurstTest(mock.ip, "test.example.com", mock.port, 2*time.Second) + + if result.Queries != BurstQueries { + t.Errorf("Expected %d queries, got %d", BurstQueries, result.Queries) + } + if result.SuccessRate() < 90 { + t.Errorf("Expected high success rate, got %.1f%%", result.SuccessRate()) + } + if !result.Passed() { + t.Errorf("Should pass burst test, got %.1f%%", result.SuccessRate()) + } +} + +func TestE2EWorkerPool(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + ips := []string{mock.ip, "192.0.2.1"} // second IP won't respond + progress := NewProgress(len(ips), false) + scanner := NewScanner(2, 500*time.Millisecond, mock.port, progress, "") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + results := scanner.Run(ctx, IPsFromList(ips)) + + var working, total int + for r := range results { + total++ + if r.Working { + working++ + } + } + + if total != 2 { + t.Errorf("Expected 2 results, got %d", total) + } + if working != 1 { + t.Errorf("Expected 1 working, got %d", working) + } +} + +func TestE2EAllPrivateRanges(t *testing.T) { + privateIPs := []string{"10.0.0.1", "172.16.0.1", "192.168.1.1", "127.0.0.1", "169.254.1.1", "100.64.0.1"} + + for _, ip := range privateIPs { + mock, err := newMockDNSServer(ip) + if err != nil { + t.Fatalf("Failed for %s: %v", ip, err) + } + + scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") + result := scanner.Probe(mock.ip) + mock.Close() + + if result.Working { + t.Errorf("%s should not be working", ip) + } + if !result.Suspicious { + t.Errorf("%s should be suspicious", ip) + } + } +} + +func TestE2ETimeout(t *testing.T) { + scanner := NewScanner(1, 100*time.Millisecond, 65534, nil, "") + result := scanner.Probe("127.0.0.1") + + if result.Working { + t.Error("Non-responsive should not be working") + } + if result.Error == nil { + t.Error("Expected timeout error") + } +} + +func TestE2EProgressTracking(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + ips := []string{mock.ip, mock.ip, mock.ip} + progress := NewProgress(len(ips), false) + scanner := NewScanner(1, 2*time.Second, mock.port, progress, "") + + results := scanner.Run(context.Background(), IPsFromList(ips)) + for range results { + } + + scanned, found, total, _ := progress.Stats() + if scanned != 3 || found != 3 || total != 3 { + t.Errorf("Progress mismatch: scanned=%d found=%d total=%d", scanned, found, total) + } +} + +func TestE2EBurstQPS(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + result := BurstTest(mock.ip, "test.example.com", mock.port, 2*time.Second) + + if result.QPS() <= 0 { + t.Errorf("QPS should be positive, got %.2f", result.QPS()) + } + if result.P50() <= 0 { + t.Errorf("P50 should be positive, got %v", result.P50()) + } +} diff --git a/main.go b/main.go index 6b9b42f..43ad5ea 100644 --- a/main.go +++ b/main.go @@ -214,7 +214,7 @@ func main() { // Create scanner prog := NewProgress(totalIPs, *progress) - scanner := NewScanner(*workers, *timeout, prog, *domain) + scanner := NewScanner(*workers, *timeout, 53, prog, *domain) // Start progress ticker var progressDone chan struct{} @@ -351,7 +351,7 @@ resultLoop: fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) } - result := BurstTest(ip, *domain, *timeout) + result := BurstTest(ip, *domain, 53, *timeout) if result.Passed() { burstResults = append(burstResults, result) diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..2657886 --- /dev/null +++ b/main_test.go @@ -0,0 +1,168 @@ +package main + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +var binaryPath string + +func TestMain(m *testing.M) { + // Build binary once for all integration tests + dir, _ := os.MkdirTemp("", "dnscan-test") + binaryPath = filepath.Join(dir, "dnscan") + + cmd := exec.Command("go", "build", "-o", binaryPath, ".") + if err := cmd.Run(); err != nil { + os.Exit(1) + } + + code := m.Run() + + os.RemoveAll(dir) + os.Exit(code) +} + +func TestVersionFlag(t *testing.T) { + cmd := exec.Command(binaryPath, "--version") + out, err := cmd.Output() + if err != nil { + t.Fatalf("--version failed: %v", err) + } + + if !strings.Contains(string(out), "dnscan") { + t.Errorf("--version output missing 'dnscan': %s", out) + } +} + +func TestInvalidMode(t *testing.T) { + cmd := exec.Command(binaryPath, "--mode", "invalid", "--progress=false") + out, err := cmd.CombinedOutput() + + if err == nil { + t.Error("expected error for invalid mode") + } + + if !strings.Contains(string(out), "Invalid mode") { + t.Errorf("expected 'Invalid mode' error, got: %s", out) + } +} + +func TestFileInputNotFound(t *testing.T) { + cmd := exec.Command(binaryPath, "--file", "/nonexistent/file.txt", "--progress=false") + out, err := cmd.CombinedOutput() + + if err == nil { + t.Error("expected error for missing file") + } + + if !strings.Contains(string(out), "Failed to read file") { + t.Errorf("expected 'Failed to read file' error, got: %s", out) + } +} + +func TestVerifyBinaryNotFound(t *testing.T) { + cmd := exec.Command(binaryPath, "--verify", "/nonexistent/binary", "--progress=false") + out, err := cmd.CombinedOutput() + + if err == nil { + t.Error("expected error for missing verify binary") + } + + if !strings.Contains(string(out), "not found") { + t.Errorf("expected 'not found' error, got: %s", out) + } +} + +func TestOutputToFile(t *testing.T) { + // Create temp file for custom IP list + inputFile, _ := os.CreateTemp("", "input-*.txt") + inputFile.WriteString("# comment\n8.8.8.8\n") + inputFile.Close() + defer os.Remove(inputFile.Name()) + + // Create temp output file + outputFile, _ := os.CreateTemp("", "output-*.txt") + outputFile.Close() + defer os.Remove(outputFile.Name()) + + // Run with very short timeout (will likely fail DNS but tests output mechanism) + cmd := exec.Command(binaryPath, + "--file", inputFile.Name(), + "--output", outputFile.Name(), + "--timeout", "100ms", + "--progress=false", + ) + cmd.Run() // Ignore error - DNS may fail + + // Check output file was created + if _, err := os.Stat(outputFile.Name()); os.IsNotExist(err) { + t.Error("output file was not created") + } +} + +func TestModeListWithCountry(t *testing.T) { + cmd := exec.Command(binaryPath, + "--country", "ir", + "--mode", "list", + "--timeout", "100ms", + "--progress=false", + ) + _, err := cmd.CombinedOutput() + + // Should start without error (actual DNS queries may fail) + if err != nil { + exitErr, ok := err.(*exec.ExitError) + if ok && exitErr.ExitCode() == 1 { + // Exit code 1 from no results is OK + return + } + t.Errorf("unexpected error: %v", err) + } +} + +func TestProgressFlagDisabled(t *testing.T) { + inputFile, _ := os.CreateTemp("", "input-*.txt") + inputFile.WriteString("8.8.8.8\n") + inputFile.Close() + defer os.Remove(inputFile.Name()) + + cmd := exec.Command(binaryPath, + "--file", inputFile.Name(), + "--timeout", "100ms", + "--progress=false", + ) + out, _ := cmd.CombinedOutput() + + // With progress disabled, should not see "dnscan" header + if strings.Contains(string(out), "Country:") { + t.Error("progress=false should not show header") + } +} + +func TestReadIPsFromFile(t *testing.T) { + // Create temp file + f, _ := os.CreateTemp("", "ips-*.txt") + f.WriteString("# comment line\n") + f.WriteString("192.168.1.1\n") + f.WriteString("\n") // empty line + f.WriteString("10.0.0.1\n") + f.Close() + defer os.Remove(f.Name()) + + ips, err := readIPsFromFile(f.Name()) + if err != nil { + t.Fatalf("readIPsFromFile failed: %v", err) + } + + if len(ips) != 2 { + t.Errorf("expected 2 IPs, got %d", len(ips)) + } + + if ips[0] != "192.168.1.1" || ips[1] != "10.0.0.1" { + t.Errorf("unexpected IPs: %v", ips) + } +} diff --git a/scanner.go b/scanner.go index 6cc810e..7df53d4 100644 --- a/scanner.go +++ b/scanner.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base32" "encoding/hex" + "fmt" "net" "sort" "sync" @@ -72,8 +73,9 @@ func isPrivateIP(ip net.IP) bool { type Scanner struct { workers int timeout time.Duration + port int progress *Progress - verifyDomain string // If set, verify this domain can be resolved (for slipstream) + verifyDomain string } // Progress tracks scanning progress @@ -114,10 +116,14 @@ func (p *Progress) Stats() (scanned, found, total int64, elapsed time.Duration) } // NewScanner creates a new scanner with given workers and timeout -func NewScanner(workers int, timeout time.Duration, progress *Progress, verifyDomain string) *Scanner { +func NewScanner(workers int, timeout time.Duration, port int, progress *Progress, verifyDomain string) *Scanner { + if port == 0 { + port = 53 + } return &Scanner{ workers: workers, timeout: timeout, + port: port, progress: progress, verifyDomain: verifyDomain, } @@ -136,7 +142,8 @@ func (s *Scanner) Probe(ip string) ScanResult { m.SetQuestion(dns.Fqdn("google.com"), dns.TypeA) m.RecursionDesired = true - reply, rtt, err := client.Exchange(m, ip+":53") + addr := fmt.Sprintf("%s:%d", ip, s.port) + reply, rtt, err := client.Exchange(m, addr) if err != nil { return ScanResult{IP: ip, Working: false, Error: err} } @@ -168,7 +175,7 @@ func (s *Scanner) Probe(ip string) ScanResult { // Set EDNS0 with 1232 byte UDP payload (matches slipstream) m2.SetEdns0(1232, false) - reply2, rtt2, err := client.Exchange(m2, ip+":53") + reply2, rtt2, err := client.Exchange(m2, addr) if err != nil { return ScanResult{IP: ip, Working: false, Error: err} } @@ -278,7 +285,12 @@ func (r *BurstResult) Passed() bool { } // BurstTest runs concurrent DNS queries to test server reliability under load -func BurstTest(ip, domain string, timeout time.Duration) *BurstResult { +func BurstTest(ip, domain string, port int, timeout time.Duration) *BurstResult { + if port == 0 { + port = 53 + } + addr := fmt.Sprintf("%s:%d", ip, port) + result := &BurstResult{ IP: ip, Queries: BurstQueries, @@ -309,7 +321,7 @@ func BurstTest(ip, domain string, timeout time.Duration) *BurstResult { m.RecursionDesired = true m.SetEdns0(1232, false) - _, rtt, err := client.Exchange(m, ip+":53") + _, rtt, err := client.Exchange(m, addr) mu.Lock() if err != nil { diff --git a/scanner_test.go b/scanner_test.go new file mode 100644 index 0000000..d36b000 --- /dev/null +++ b/scanner_test.go @@ -0,0 +1,191 @@ +package main + +import ( + "net" + "testing" + "time" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + ip string + expected bool + desc string + }{ + // Private ranges - should be detected + {"10.0.0.1", true, "10.x.x.x (RFC 1918)"}, + {"10.255.255.255", true, "10.x.x.x upper bound"}, + {"172.16.0.1", true, "172.16.x.x (RFC 1918)"}, + {"172.31.255.255", true, "172.31.x.x upper bound"}, + {"192.168.0.1", true, "192.168.x.x (RFC 1918)"}, + {"192.168.255.255", true, "192.168.x.x upper bound"}, + {"127.0.0.1", true, "loopback"}, + {"169.254.1.1", true, "link-local"}, + {"100.64.0.1", true, "CGNAT"}, + {"100.127.255.255", true, "CGNAT upper bound"}, + {"0.0.0.0", true, "zero address"}, + + // Public IPs - should not be detected + {"8.8.8.8", false, "Google DNS"}, + {"1.1.1.1", false, "Cloudflare DNS"}, + {"185.8.174.140", false, "Iranian DNS"}, + {"172.15.255.255", false, "just below 172.16.0.0"}, + {"172.32.0.0", false, "just above 172.31.255.255"}, + {"100.63.255.255", false, "just below CGNAT"}, + {"100.128.0.0", false, "just above CGNAT"}, + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + result := isPrivateIP(ip) + if result != tt.expected { + t.Errorf("isPrivateIP(%s) = %v, expected %v (%s)", + tt.ip, result, tt.expected, tt.desc) + } + } +} + +func TestIsPrivateIPNil(t *testing.T) { + if isPrivateIP(nil) { + t.Error("isPrivateIP(nil) should return false") + } +} + +func TestRandomSubdomain(t *testing.T) { + s1 := randomSubdomain() + s2 := randomSubdomain() + + // Should be hex encoded (16 chars for 8 bytes) + if len(s1) != 16 { + t.Errorf("randomSubdomain length = %d, expected 16", len(s1)) + } + + // Should be different each time + if s1 == s2 { + t.Error("randomSubdomain should generate unique values") + } +} + +func TestRandomSlipstreamSubdomain(t *testing.T) { + s1 := randomSlipstreamSubdomain() + s2 := randomSlipstreamSubdomain() + + // Base32 encoded 32 bytes = 52 chars + if len(s1) != 52 { + t.Errorf("randomSlipstreamSubdomain length = %d, expected 52", len(s1)) + } + + // Should be different each time + if s1 == s2 { + t.Error("randomSlipstreamSubdomain should generate unique values") + } +} + +func TestBurstResultSuccessRate(t *testing.T) { + tests := []struct { + queries int + successful int + expected float64 + }{ + {20, 20, 100.0}, + {20, 10, 50.0}, + {20, 0, 0.0}, + {0, 0, 0.0}, + } + + for _, tt := range tests { + r := &BurstResult{ + Queries: tt.queries, + Successful: tt.successful, + } + if r.SuccessRate() != tt.expected { + t.Errorf("SuccessRate(%d/%d) = %.1f, expected %.1f", + tt.successful, tt.queries, r.SuccessRate(), tt.expected) + } + } +} + +func TestBurstResultQPS(t *testing.T) { + r := &BurstResult{ + Successful: 10, + Duration: time.Second, + } + if r.QPS() != 10.0 { + t.Errorf("QPS = %.1f, expected 10.0", r.QPS()) + } + + // Zero duration edge case + r2 := &BurstResult{Duration: 0} + if r2.QPS() != 0.0 { + t.Error("QPS with zero duration should be 0") + } +} + +func TestBurstResultP50(t *testing.T) { + r := &BurstResult{ + Latencies: []time.Duration{ + 10 * time.Millisecond, + 20 * time.Millisecond, + 30 * time.Millisecond, + 40 * time.Millisecond, + 50 * time.Millisecond, + }, + } + + p50 := r.P50() + if p50 != 30*time.Millisecond { + t.Errorf("P50 = %v, expected 30ms", p50) + } + + // Empty latencies + r2 := &BurstResult{} + if r2.P50() != 0 { + t.Error("P50 with no latencies should be 0") + } +} + +func TestBurstResultPassed(t *testing.T) { + tests := []struct { + queries int + successful int + expected bool + }{ + {20, 14, true}, // 70% exactly + {20, 15, true}, // 75% + {20, 13, false}, // 65% + {20, 0, false}, // 0% + } + + for _, tt := range tests { + r := &BurstResult{ + Queries: tt.queries, + Successful: tt.successful, + } + if r.Passed() != tt.expected { + t.Errorf("Passed(%d/%d = %.0f%%) = %v, expected %v", + tt.successful, tt.queries, r.SuccessRate(), r.Passed(), tt.expected) + } + } +} + +func TestProgressStats(t *testing.T) { + p := NewProgress(100, true) + + p.Increment() + p.Increment() + p.Found() + + scanned, found, total, elapsed := p.Stats() + if scanned != 2 { + t.Errorf("scanned = %d, expected 2", scanned) + } + if found != 1 { + t.Errorf("found = %d, expected 1", found) + } + if total != 100 { + t.Errorf("total = %d, expected 100", total) + } + if elapsed <= 0 { + t.Error("elapsed should be positive") + } +}