From 4cfbf4a17bd0b0a4445dfaf5f3cc69057ba3fa52 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 08:06:13 +0100 Subject: [PATCH 01/11] refactor(config): extract Config struct and flag parsing --- config.go | 74 +++++++++++++++++++++++ main.go | 164 ++++++++++++++++++--------------------------------- main_test.go | 4 +- 3 files changed, 132 insertions(+), 110 deletions(-) create mode 100644 config.go diff --git a/config.go b/config.go new file mode 100644 index 0000000..3fb3249 --- /dev/null +++ b/config.go @@ -0,0 +1,74 @@ +package main + +import ( + "flag" + "fmt" + "os" + "time" +) + +type Config struct { + Country string + Mode string + InputFile string + DataDir string + Workers int + Timeout time.Duration + Domain string + VerifyBinary string + OutputFile string + JSONOutput bool + Progress bool + ShowVersion bool +} + +func ParseFlags() *Config { + c := &Config{} + + flag.StringVar(&c.Country, "country", "ir", "Country code for IP ranges (e.g., ir, cn)") + flag.StringVar(&c.Mode, "mode", "fast", "Scan mode: fast, medium, all, list") + flag.IntVar(&c.Workers, "workers", 500, "Number of concurrent workers") + flag.DurationVar(&c.Timeout, "timeout", 2*time.Second, "DNS query timeout") + flag.StringVar(&c.OutputFile, "output", "", "Output file (default: stdout)") + flag.StringVar(&c.InputFile, "file", "", "Input file with DNS IPs (one per line)") + flag.StringVar(&c.DataDir, "data-dir", "data", "Directory containing ranges/ and dns/ subdirs") + flag.BoolVar(&c.Progress, "progress", true, "Show progress indicator") + flag.StringVar(&c.Domain, "domain", "", "Tunnel domain to verify (e.g., t.example.com)") + flag.StringVar(&c.VerifyBinary, "verify", "", "Path to slipstream-client binary") + flag.BoolVar(&c.ShowVersion, "version", false, "Show version") + flag.BoolVar(&c.JSONOutput, "json", false, "Output results as JSON") + + flag.Parse() + + // JSON is machine output - progress would corrupt it + if c.JSONOutput { + c.Progress = false + } + + return c +} + +func (c *Config) Validate() error { + validModes := map[string]bool{"fast": true, "medium": true, "all": true, "list": true} + if !validModes[c.Mode] { + return fmt.Errorf("invalid mode: %s (use: fast, medium, all, list)", c.Mode) + } + + if c.VerifyBinary != "" { + info, err := os.Stat(c.VerifyBinary) + if os.IsNotExist(err) { + return fmt.Errorf("verify binary not found: %s", c.VerifyBinary) + } + if err != nil { + return fmt.Errorf("cannot access verify binary: %w", err) + } + if info.IsDir() { + return fmt.Errorf("verify path is a directory: %s", c.VerifyBinary) + } + if info.Mode()&0111 == 0 { + return fmt.Errorf("verify binary not executable: %s (run: chmod +x %s)", c.VerifyBinary, c.VerifyBinary) + } + } + + return nil +} diff --git a/main.go b/main.go index 3ad0350..8d85ed4 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "encoding/json" - "flag" "fmt" "os" "os/exec" @@ -16,7 +15,8 @@ import ( "time" ) -// readIPsFromFile reads IPs from a file (one per line) +var version = "dev" + func readIPsFromFile(path string) ([]string, error) { f, err := os.Open(path) if err != nil { @@ -35,10 +35,6 @@ func readIPsFromFile(path string) ([]string, error) { return ips, scanner.Err() } -var ( - version = "dev" -) - // JSONOutput is the structured output format for --json flag type JSONOutput struct { Servers []JSONServer `json:"servers"` @@ -109,66 +105,23 @@ func verifyWithSlipstream(clientPath, domain, ip string, timeout time.Duration) } func main() { - // CLI flags - country := flag.String("country", "ir", "Country code for IP ranges (e.g., ir, cn)") - mode := flag.String("mode", "fast", "Scan mode: fast, medium, all, list") - workers := flag.Int("workers", 500, "Number of concurrent workers") - timeout := flag.Duration("timeout", 2*time.Second, "DNS query timeout") - output := flag.String("output", "", "Output file (default: stdout)") - inputFile := flag.String("file", "", "Input file with DNS IPs (one per line)") - dataDir := flag.String("data-dir", "data", "Directory containing ranges/ and dns/ subdirs") - progress := flag.Bool("progress", true, "Show progress indicator") - domain := flag.String("domain", "", "Tunnel domain to verify (e.g., t.example.com). Required for slipstream compatibility.") - verify := flag.String("verify", "", "Path to slipstream-client binary to verify candidates actually work") - showVersion := flag.Bool("version", false, "Show version") - jsonOutput := flag.Bool("json", false, "Output results as JSON") - flag.Parse() - - // Set data directory - DataDir = *dataDir - - // JSON mode disables progress - machine output only - if *jsonOutput { - *progress = false - } + cfg := ParseFlags() - if *showVersion { + if cfg.ShowVersion { fmt.Printf("dnscan %s\n", version) os.Exit(0) } - // Validate mode - validModes := map[string]bool{"fast": true, "medium": true, "all": true, "list": true} - if !validModes[*mode] { - fmt.Fprintf(os.Stderr, "Invalid mode: %s (use: fast, medium, all, list)\n", *mode) + if err := cfg.Validate(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - // Validate --verify binary if provided - if *verify != "" { - info, err := os.Stat(*verify) - if os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Verify binary not found: %s\n", *verify) - os.Exit(1) - } - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot access verify binary: %v\n", err) - os.Exit(1) - } - if info.IsDir() { - fmt.Fprintf(os.Stderr, "Verify path is a directory, not a binary: %s\n", *verify) - os.Exit(1) - } - if info.Mode()&0111 == 0 { - fmt.Fprintf(os.Stderr, "Verify binary is not executable: %s\nRun: chmod +x %s\n", *verify, *verify) - os.Exit(1) - } - } + DataDir = cfg.DataDir - // Setup output var outFile *os.File - if *output != "" { - f, err := os.Create(*output) + if cfg.OutputFile != "" { + f, err := os.Create(cfg.OutputFile) if err != nil { fmt.Fprintf(os.Stderr, "Failed to create output file: %v\n", err) os.Exit(1) @@ -177,49 +130,44 @@ func main() { outFile = f } - // Get IPs to scan var ips <-chan string var totalIPs int - if *inputFile != "" { - // Read from file - fileIPs, err := readIPsFromFile(*inputFile) + if cfg.InputFile != "" { + fileIPs, err := readIPsFromFile(cfg.InputFile) if err != nil { fmt.Fprintf(os.Stderr, "Failed to read file: %v\n", err) os.Exit(1) } ips = IPsFromList(fileIPs) totalIPs = len(fileIPs) - } else if *mode == "list" { - // Load known DNS servers for country - dnsList, err := LoadDNSList(*country) + } else if cfg.Mode == "list" { + dnsList, err := LoadDNSList(cfg.Country) if err != nil { - fmt.Fprintf(os.Stderr, "Failed to load DNS list for %s: %v\n", *country, err) + fmt.Fprintf(os.Stderr, "Failed to load DNS list for %s: %v\n", cfg.Country, err) os.Exit(1) } ips = IPsFromList(dnsList) totalIPs = len(dnsList) } else { - // Load IP ranges for country - ranges, err := LoadRanges(*country) + ranges, err := LoadRanges(cfg.Country) if err != nil { - fmt.Fprintf(os.Stderr, "Failed to load ranges for %s: %v\n", *country, err) + fmt.Fprintf(os.Stderr, "Failed to load ranges for %s: %v\n", cfg.Country, err) os.Exit(1) } - totalIPs = CountIPsWithMode(ranges, *mode) - ips = ExpandRangesWithMode(ranges, *mode) + totalIPs = CountIPsWithMode(ranges, cfg.Mode) + ips = ExpandRangesWithMode(ranges, cfg.Mode) } - // Print header - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "dnscan %s\n", version) - if *inputFile != "" { - fmt.Fprintf(os.Stderr, "Source: %s | Workers: %d | Timeout: %v\n", *inputFile, *workers, *timeout) + if cfg.InputFile != "" { + fmt.Fprintf(os.Stderr, "Source: %s | Workers: %d | Timeout: %v\n", cfg.InputFile, cfg.Workers, cfg.Timeout) } else { - fmt.Fprintf(os.Stderr, "Country: %s | Mode: %s | Workers: %d | Timeout: %v\n", *country, *mode, *workers, *timeout) + fmt.Fprintf(os.Stderr, "Country: %s | Mode: %s | Workers: %d | Timeout: %v\n", cfg.Country, cfg.Mode, cfg.Workers, cfg.Timeout) } - if *domain != "" { - fmt.Fprintf(os.Stderr, "Tunnel domain: %s (verifies query reaches server)\n", *domain) + if cfg.Domain != "" { + fmt.Fprintf(os.Stderr, "Tunnel domain: %s (verifies query reaches server)\n", cfg.Domain) } else { fmt.Fprintf(os.Stderr, "WARNING: No --domain set. Finding generic DNS, not tunnel-compatible!\n") fmt.Fprintf(os.Stderr, " Use: --domain t.example.com for slipstream compatibility\n") @@ -238,19 +186,19 @@ func main() { signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigCh - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nInterrupted, stopping...\n") } cancel() }() // Create scanner - prog := NewProgress(totalIPs, *progress) - scanner := NewScanner(*workers, *timeout, 53, prog, *domain) + prog := NewProgress(totalIPs, cfg.Progress) + scanner := NewScanner(cfg.Workers, cfg.Timeout, 53, prog, cfg.Domain) // Start progress ticker var progressDone chan struct{} - if *progress { + if cfg.Progress { progressDone = make(chan struct{}) go func() { ticker := time.NewTicker(500 * time.Millisecond) @@ -302,7 +250,7 @@ resultLoop: } // Print final stats - if *progress { + if cfg.Progress { scanned, found, _, elapsed := prog.Stats() fmt.Fprintf(os.Stderr, "\r \r") fmt.Fprintf(os.Stderr, "Completed: %d IPs in %v\n", scanned, elapsed.Round(time.Millisecond)) @@ -313,8 +261,8 @@ resultLoop: } // Phase 2: Verify with slipstream-client if requested - if *verify != "" && len(workingDNS) > 0 { - if *progress { + if cfg.VerifyBinary != "" && len(workingDNS) > 0 { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nVerifying %d candidates with slipstream-client...\n", len(workingDNS)) } @@ -325,32 +273,32 @@ resultLoop: // Check for interrupt select { case <-ctx.Done(): - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nInterrupted during verification\n") } goto slipstreamDone default: } - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) } start := time.Now() - if verifyWithSlipstream(*verify, *domain, ip, *timeout) { + if verifyWithSlipstream(cfg.VerifyBinary, cfg.Domain, ip, cfg.Timeout) { elapsed := time.Since(start) verified = append(verified, ip) - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\033[32mOK (%.1fs)\033[0m\n", elapsed.Seconds()) } } else { - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "FAIL\n") } } } slipstreamDone: - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "---\n") fmt.Fprintf(os.Stderr, "Slipstream: %d/%d passed\n", len(verified), len(workingDNS)) } @@ -359,12 +307,12 @@ resultLoop: // Phase 3: Burst test to verify servers handle concurrent load var burstResults []*BurstResult - if *domain != "" && len(workingDNS) > 0 { + if cfg.Domain != "" && len(workingDNS) > 0 { total := len(workingDNS) if total <= 5 { // Sequential for small lists - nicer per-IP output - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates (%d queries, %d%% required)...\n", total, BurstQueries, BurstMinSuccess) } @@ -373,22 +321,22 @@ resultLoop: for i, ip := range workingDNS { select { case <-ctx.Done(): - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nInterrupted during burst test\n") } goto burstDone default: } - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) } - result := BurstTest(ctx, ip, *domain, 53, *timeout) + result := BurstTest(ctx, ip, cfg.Domain, 53, cfg.Timeout) if result.Passed() { burstResults = append(burstResults, result) - if *progress { + if cfg.Progress { color := "\033[33m" if result.SuccessRate() >= 85 { color = "\033[32m" @@ -397,7 +345,7 @@ resultLoop: color, result.SuccessRate(), result.QPS(), result.P50().Round(time.Millisecond)) } } else { - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "FAIL %.0f%%\n", result.SuccessRate()) } } @@ -405,15 +353,15 @@ resultLoop: } else { // Parallel for larger lists burstWorkers := min(total, 10) - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates in parallel (%d workers)...\n", total, burstWorkers) } - burstProg := NewBurstProgress(total, *progress) + burstProg := NewBurstProgress(total, cfg.Progress) var progressDone chan struct{} - if *progress { + if cfg.Progress { progressDone = make(chan struct{}) go func() { ticker := time.NewTicker(500 * time.Millisecond) @@ -432,7 +380,7 @@ resultLoop: }() } - resultChan := ParallelBurstTest(ctx, workingDNS, *domain, 53, *timeout, burstWorkers) + resultChan := ParallelBurstTest(ctx, workingDNS, cfg.Domain, 53, cfg.Timeout, burstWorkers) for result := range resultChan { burstProg.Tested() if result.Passed() { @@ -452,7 +400,7 @@ resultLoop: return burstResults[i].QPS() > burstResults[j].QPS() }) - if *progress { + if cfg.Progress { fmt.Fprintf(os.Stderr, "---\n") fmt.Fprintf(os.Stderr, "Burst test: %d/%d passed (sorted by throughput)\n", len(burstResults), len(workingDNS)) for _, r := range burstResults { @@ -475,7 +423,7 @@ resultLoop: scanned, _, _, elapsed := prog.Stats() // Write results - if *jsonOutput { + if cfg.JSONOutput { output := JSONOutput{ Servers: []JSONServer{}, Scan: JSONScan{ @@ -484,9 +432,9 @@ resultLoop: DurationMs: elapsed.Milliseconds(), }, } - if *inputFile == "" { - output.Scan.Country = *country - output.Scan.Mode = *mode + if cfg.InputFile == "" { + output.Scan.Country = cfg.Country + output.Scan.Mode = cfg.Mode } // Build server list with stats if available @@ -519,7 +467,7 @@ resultLoop: for _, ip := range workingDNS { fmt.Fprintln(outFile, ip) } - } else if !*progress { + } else if !cfg.Progress { for _, ip := range workingDNS { fmt.Println(ip) } @@ -527,8 +475,8 @@ resultLoop: } // Print usage hint - if *progress && len(workingDNS) > 0 { - showDomain := *domain + if cfg.Progress && len(workingDNS) > 0 { + showDomain := cfg.Domain if showDomain == "" { showDomain = "" } diff --git a/main_test.go b/main_test.go index 8f78153..cd544b0 100644 --- a/main_test.go +++ b/main_test.go @@ -47,8 +47,8 @@ func TestInvalidMode(t *testing.T) { t.Error("expected error for invalid mode") } - if !strings.Contains(string(out), "Invalid mode") { - t.Errorf("expected 'Invalid mode' error, got: %s", out) + if !strings.Contains(string(out), "invalid mode") { + t.Errorf("expected 'invalid mode' error, got: %s", out) } } From 1f17a23462ff865a05b281da8f156d74039008b1 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 08:40:30 +0100 Subject: [PATCH 02/11] feat(output): add OutputWriter interface with PrintBanner and PrintUsageHint --- e2e_test.go | 24 +++------ main.go | 140 +++++++++++++-------------------------------------- main_test.go | 23 ++++++++- output.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 122 deletions(-) create mode 100644 output.go diff --git a/e2e_test.go b/e2e_test.go index 914075d..db56bf9 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -331,7 +331,7 @@ func TestBurstProgress(t *testing.T) { } } -func TestE2EJSONServerFromBurstResult(t *testing.T) { +func TestE2EBurstResultMetrics(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { t.Fatalf("Failed to start mock DNS: %v", err) @@ -340,24 +340,16 @@ func TestE2EJSONServerFromBurstResult(t *testing.T) { result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) - server := JSONServer{ - IP: result.IP, - QPS: result.QPS(), - SuccessRate: result.SuccessRate(), - LatencyP50: result.P50().Milliseconds(), + if result.IP != mock.ip { + t.Errorf("IP mismatch: got %s, expected %s", result.IP, mock.ip) } - - if server.IP != mock.ip { - t.Errorf("IP mismatch: got %s, expected %s", server.IP, mock.ip) - } - if server.QPS <= 0 { + if result.QPS() <= 0 { t.Error("QPS should be positive") } - if server.SuccessRate < 70 { - t.Errorf("SuccessRate too low: %.1f", server.SuccessRate) + if result.SuccessRate() < 70 { + t.Errorf("SuccessRate too low: %.1f", result.SuccessRate()) } - // LatencyP50 can be 0ms for localhost mock (sub-millisecond response) - if server.LatencyP50 < 0 { - t.Error("LatencyP50 should not be negative") + if result.P50() < 0 { + t.Error("P50 latency should not be negative") } } diff --git a/main.go b/main.go index 8d85ed4..27f80ae 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "encoding/json" "fmt" "os" "os/exec" @@ -35,29 +34,6 @@ func readIPsFromFile(path string) ([]string, error) { return ips, scanner.Err() } -// JSONOutput is the structured output format for --json flag -type JSONOutput struct { - Servers []JSONServer `json:"servers"` - Scan JSONScan `json:"scan"` -} - -// JSONServer represents a single DNS server in JSON output -type JSONServer struct { - IP string `json:"ip"` - QPS float64 `json:"qps,omitempty"` - SuccessRate float64 `json:"success_rate,omitempty"` - LatencyP50 int64 `json:"latency_p50_ms,omitempty"` -} - -// JSONScan contains scan metadata -type JSONScan struct { - Country string `json:"country,omitempty"` - Mode string `json:"mode,omitempty"` - TotalScanned int64 `json:"total_scanned"` - Found int64 `json:"found"` - DurationMs int64 `json:"duration_ms"` -} - // verifyWithSlipstream tests if a DNS server actually works with slipstream-client func verifyWithSlipstream(clientPath, domain, ip string, timeout time.Duration) bool { ctx, cancel := context.WithTimeout(context.Background(), timeout*3) @@ -159,24 +135,7 @@ func main() { ips = ExpandRangesWithMode(ranges, cfg.Mode) } - if cfg.Progress { - fmt.Fprintf(os.Stderr, "dnscan %s\n", version) - if cfg.InputFile != "" { - fmt.Fprintf(os.Stderr, "Source: %s | Workers: %d | Timeout: %v\n", cfg.InputFile, cfg.Workers, cfg.Timeout) - } else { - fmt.Fprintf(os.Stderr, "Country: %s | Mode: %s | Workers: %d | Timeout: %v\n", cfg.Country, cfg.Mode, cfg.Workers, cfg.Timeout) - } - if cfg.Domain != "" { - fmt.Fprintf(os.Stderr, "Tunnel domain: %s (verifies query reaches server)\n", cfg.Domain) - } else { - fmt.Fprintf(os.Stderr, "WARNING: No --domain set. Finding generic DNS, not tunnel-compatible!\n") - fmt.Fprintf(os.Stderr, " Use: --domain t.example.com for slipstream compatibility\n") - } - fmt.Fprintf(os.Stderr, "IPs to scan: %d\n", totalIPs) - fmt.Fprintf(os.Stderr, "---\n") - } else { - fmt.Fprintf(os.Stderr, "Scanning %d IPs...\n", totalIPs) - } + PrintBanner(os.Stderr, cfg, totalIPs, version) // Setup context with signal handling ctx, cancel := context.WithCancel(context.Background()) @@ -419,77 +378,48 @@ resultLoop: } } - // Get final stats scanned, _, _, elapsed := prog.Stats() - // Write results - if cfg.JSONOutput { - output := JSONOutput{ - Servers: []JSONServer{}, - Scan: JSONScan{ - TotalScanned: scanned, - Found: int64(len(workingDNS)), - DurationMs: elapsed.Milliseconds(), - }, + var serverResults []ServerResult + if len(burstResults) > 0 { + for _, r := range burstResults { + serverResults = append(serverResults, ServerResult{ + IP: r.IP, + QPS: r.QPS(), + SuccessRate: r.SuccessRate(), + LatencyP50: r.P50(), + }) } - if cfg.InputFile == "" { - output.Scan.Country = cfg.Country - output.Scan.Mode = cfg.Mode + } else { + for _, ip := range workingDNS { + serverResults = append(serverResults, ServerResult{IP: ip}) } + } - // Build server list with stats if available - if len(burstResults) > 0 { - for _, r := range burstResults { - output.Servers = append(output.Servers, JSONServer{ - IP: r.IP, - QPS: r.QPS(), - SuccessRate: r.SuccessRate(), - LatencyP50: r.P50().Milliseconds(), - }) - } - } else { - for _, ip := range workingDNS { - output.Servers = append(output.Servers, JSONServer{IP: ip}) - } - } + stats := ScanStats{ + TotalScanned: scanned, + Found: int64(len(serverResults)), + Duration: elapsed, + } + if cfg.InputFile == "" { + stats.Country = cfg.Country + stats.Mode = cfg.Mode + } - enc := json.NewEncoder(os.Stdout) - enc.SetIndent("", " ") - if outFile != nil { - enc = json.NewEncoder(outFile) - enc.SetIndent("", " ") - } - enc.Encode(output) + // Write output + out := os.Stdout + if outFile != nil { + out = outFile + } + + if cfg.JSONOutput { + NewJSONWriter(out).Write(serverResults, stats) } else { - // Plain text output - skip stdout when progress shows colored stats - if len(workingDNS) > 0 { - if outFile != nil { - for _, ip := range workingDNS { - fmt.Fprintln(outFile, ip) - } - } else if !cfg.Progress { - for _, ip := range workingDNS { - fmt.Println(ip) - } - } + if outFile != nil || !cfg.Progress { + NewTextWriter(out).Write(serverResults, stats) } - - // Print usage hint - if cfg.Progress && len(workingDNS) > 0 { - showDomain := cfg.Domain - if showDomain == "" { - showDomain = "" - } - max := 10 - if len(workingDNS) < max { - max = len(workingDNS) - } - fmt.Fprintf(os.Stderr, "\nUsage:\n slipstream-client \\\n") - for i := 0; i < max; i++ { - fmt.Fprintf(os.Stderr, " --resolver %s:53 \\\n", workingDNS[i]) - } - fmt.Fprintf(os.Stderr, " --domain %s \\\n", showDomain) - fmt.Fprintf(os.Stderr, " --tcp-listen-port 7000\n") + if cfg.Progress { + PrintUsageHint(os.Stderr, workingDNS, cfg.Domain) } } } diff --git a/main_test.go b/main_test.go index cd544b0..8a503af 100644 --- a/main_test.go +++ b/main_test.go @@ -11,6 +11,27 @@ import ( var binaryPath string +// Test-local types for parsing JSON output +type testJSONOutput struct { + Servers []testJSONServer `json:"servers"` + Scan testJSONScan `json:"scan"` +} + +type testJSONServer struct { + IP string `json:"ip"` + QPS float64 `json:"qps,omitempty"` + SuccessRate float64 `json:"success_rate,omitempty"` + LatencyP50 int64 `json:"latency_p50_ms,omitempty"` +} + +type testJSONScan struct { + Country string `json:"country,omitempty"` + Mode string `json:"mode,omitempty"` + TotalScanned int64 `json:"total_scanned"` + Found int64 `json:"found"` + DurationMs int64 `json:"duration_ms"` +} + func TestMain(m *testing.M) { // Build binary once for all integration tests dir, _ := os.MkdirTemp("", "dnscan-test") @@ -181,7 +202,7 @@ func TestJSONOutputFlag(t *testing.T) { ) out, _ := cmd.Output() - var result JSONOutput + var result testJSONOutput if err := json.Unmarshal(out, &result); err != nil { t.Fatalf("Failed to parse JSON output: %v\nOutput: %s", err, out) } diff --git a/output.go b/output.go new file mode 100644 index 0000000..cdb3178 --- /dev/null +++ b/output.go @@ -0,0 +1,140 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "time" +) + +const maxHintServers = 10 + +type ServerResult struct { + IP string + QPS float64 + SuccessRate float64 + LatencyP50 time.Duration +} + +type ScanStats struct { + Country string + Mode string + TotalScanned int64 + Found int64 + Duration time.Duration +} + +type OutputWriter interface { + Write(results []ServerResult, stats ScanStats) error +} + +// --- JSON --- + +type JSONWriter struct { + w io.Writer +} + +func NewJSONWriter(w io.Writer) *JSONWriter { + return &JSONWriter{w: w} +} + +type jsonOutput struct { + Servers []jsonServer `json:"servers"` + Scan jsonScan `json:"scan"` +} + +type jsonServer struct { + IP string `json:"ip"` + QPS float64 `json:"qps,omitempty"` + SuccessRate float64 `json:"success_rate,omitempty"` + LatencyP50 int64 `json:"latency_p50_ms,omitempty"` +} + +type jsonScan struct { + Country string `json:"country,omitempty"` + Mode string `json:"mode,omitempty"` + TotalScanned int64 `json:"total_scanned"` + Found int64 `json:"found"` + DurationMs int64 `json:"duration_ms"` +} + +func (j *JSONWriter) Write(results []ServerResult, stats ScanStats) error { + out := jsonOutput{ + Servers: make([]jsonServer, 0, len(results)), + Scan: jsonScan{ + Country: stats.Country, + Mode: stats.Mode, + TotalScanned: stats.TotalScanned, + Found: stats.Found, + DurationMs: stats.Duration.Milliseconds(), + }, + } + + for _, r := range results { + out.Servers = append(out.Servers, jsonServer{ + IP: r.IP, + QPS: r.QPS, + SuccessRate: r.SuccessRate, + LatencyP50: r.LatencyP50.Milliseconds(), + }) + } + + enc := json.NewEncoder(j.w) + enc.SetIndent("", " ") + return enc.Encode(out) +} + +// --- Text --- + +type TextWriter struct { + w io.Writer +} + +func NewTextWriter(w io.Writer) *TextWriter { + return &TextWriter{w: w} +} + +func (t *TextWriter) Write(results []ServerResult, stats ScanStats) error { + for _, r := range results { + if _, err := fmt.Fprintln(t.w, r.IP); err != nil { + return err + } + } + return nil +} + +func PrintBanner(w io.Writer, cfg *Config, totalIPs int, version string) { + if !cfg.Progress { + fmt.Fprintf(w, "Scanning %d IPs...\n", totalIPs) + return + } + fmt.Fprintf(w, "dnscan %s\n", version) + if cfg.InputFile != "" { + fmt.Fprintf(w, "Source: %s | Workers: %d | Timeout: %v\n", cfg.InputFile, cfg.Workers, cfg.Timeout) + } else { + fmt.Fprintf(w, "Country: %s | Mode: %s | Workers: %d | Timeout: %v\n", cfg.Country, cfg.Mode, cfg.Workers, cfg.Timeout) + } + if cfg.Domain != "" { + fmt.Fprintf(w, "Tunnel domain: %s (verifies query reaches server)\n", cfg.Domain) + } else { + fmt.Fprintf(w, "WARNING: No --domain set. Finding generic DNS, not tunnel-compatible!\n") + fmt.Fprintf(w, " Use: --domain t.example.com for slipstream compatibility\n") + } + fmt.Fprintf(w, "IPs to scan: %d\n", totalIPs) + fmt.Fprintf(w, "---\n") +} + +func PrintUsageHint(w io.Writer, ips []string, domain string) { + if len(ips) == 0 { + return + } + if domain == "" { + domain = "" + } + fmt.Fprintf(w, "\nUsage:\n slipstream-client \\\n") + for i := 0; i < min(len(ips), maxHintServers); i++ { + fmt.Fprintf(w, " --resolver %s:53 \\\n", ips[i]) + } + fmt.Fprintf(w, " --domain %s \\\n", domain) + fmt.Fprintf(w, " --tcp-listen-port 7000\n") +} From a7dd30867cedfb656246175d6029aecd670465c7 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 08:49:13 +0100 Subject: [PATCH 03/11] feat(output): refactor(verify): extract Verifier interface with SlipstreamVerifier --- main.go | 63 ++++++---------------------------------------- verify.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 55 deletions(-) create mode 100644 verify.go diff --git a/main.go b/main.go index 27f80ae..c87464a 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,9 @@ package main import ( "bufio" - "bytes" "context" "fmt" "os" - "os/exec" "os/signal" "sort" "strings" @@ -34,52 +32,6 @@ func readIPsFromFile(path string) ([]string, error) { return ips, scanner.Err() } -// verifyWithSlipstream tests if a DNS server actually works with slipstream-client -func verifyWithSlipstream(clientPath, domain, ip string, timeout time.Duration) bool { - ctx, cancel := context.WithTimeout(context.Background(), timeout*3) - defer cancel() - - cmd := exec.CommandContext(ctx, clientPath, - "--resolver", ip+":53", - "--domain", domain, - "--tcp-listen-port", "0", // Random available port - ) - - var output bytes.Buffer - cmd.Stdout = &output - cmd.Stderr = &output - - if err := cmd.Start(); err != nil { - return false - } - - defer func() { - cmd.Process.Kill() - cmd.Wait() - }() - - // Poll for "Connection ready" (success) or errors (failure) - deadline := time.Now().Add(timeout) - - for time.Now().Before(deadline) { - result := output.String() - - // Success: tunnel connected - if strings.Contains(result, "Connection ready") { - return true - } - - // Failure: connection error - if strings.Contains(result, "Connection closed") || strings.Contains(result, "became unavailable") { - return false - } - - time.Sleep(200 * time.Millisecond) - } - - return false -} - func main() { cfg := ParseFlags() @@ -219,23 +171,24 @@ resultLoop: } } - // Phase 2: Verify with slipstream-client if requested + // Phase 2: Verify with tunnel client if requested if cfg.VerifyBinary != "" && len(workingDNS) > 0 { + verifier := NewSlipstreamVerifier(cfg.VerifyBinary, cfg.Domain, cfg.Timeout) + if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nVerifying %d candidates with slipstream-client...\n", len(workingDNS)) + fmt.Fprintf(os.Stderr, "\nVerifying %d candidates with %s...\n", len(workingDNS), verifier.Name()) } var verified []string total := len(workingDNS) width := len(fmt.Sprintf("%d", total)) for i, ip := range workingDNS { - // Check for interrupt select { case <-ctx.Done(): if cfg.Progress { fmt.Fprintf(os.Stderr, "\nInterrupted during verification\n") } - goto slipstreamDone + goto verifyDone default: } @@ -243,7 +196,7 @@ resultLoop: fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) } start := time.Now() - if verifyWithSlipstream(cfg.VerifyBinary, cfg.Domain, ip, cfg.Timeout) { + if verifier.Verify(ip) { elapsed := time.Since(start) verified = append(verified, ip) if cfg.Progress { @@ -255,11 +208,11 @@ resultLoop: } } } - slipstreamDone: + verifyDone: if cfg.Progress { fmt.Fprintf(os.Stderr, "---\n") - fmt.Fprintf(os.Stderr, "Slipstream: %d/%d passed\n", len(verified), len(workingDNS)) + fmt.Fprintf(os.Stderr, "%s: %d/%d passed\n", verifier.Name(), len(verified), len(workingDNS)) } workingDNS = verified } diff --git a/verify.go b/verify.go new file mode 100644 index 0000000..99bf398 --- /dev/null +++ b/verify.go @@ -0,0 +1,74 @@ +package main + +import ( + "bytes" + "context" + "os/exec" + "strings" + "time" +) + +type Verifier interface { + Verify(ip string) bool + Name() string +} + +type SlipstreamVerifier struct { + clientPath string + domain string + timeout time.Duration +} + +func NewSlipstreamVerifier(clientPath, domain string, timeout time.Duration) *SlipstreamVerifier { + return &SlipstreamVerifier{ + clientPath: clientPath, + domain: domain, + timeout: timeout, + } +} + +func (v *SlipstreamVerifier) Name() string { + return "slipstream" +} + +func (v *SlipstreamVerifier) Verify(ip string) bool { + ctx, cancel := context.WithTimeout(context.Background(), v.timeout*3) + defer cancel() + + cmd := exec.CommandContext(ctx, v.clientPath, + "--resolver", ip+":53", + "--domain", v.domain, + "--tcp-listen-port", "0", + ) + + var output bytes.Buffer + cmd.Stdout = &output + cmd.Stderr = &output + + if err := cmd.Start(); err != nil { + return false + } + + defer func() { + cmd.Process.Kill() + cmd.Wait() + }() + + deadline := time.Now().Add(v.timeout) + + for time.Now().Before(deadline) { + result := output.String() + + if strings.Contains(result, "Connection ready") { + return true + } + + if strings.Contains(result, "Connection closed") || strings.Contains(result, "became unavailable") { + return false + } + + time.Sleep(200 * time.Millisecond) + } + + return false +} From 472a97388b8489df0aa59f27044e585f99b6a891 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:46:06 +0100 Subject: [PATCH 04/11] test(scanner): fix flaky TestProgressStats timing assertion --- scanner_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scanner_test.go b/scanner_test.go index d36b000..ecaf5dd 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -185,7 +185,7 @@ func TestProgressStats(t *testing.T) { if total != 100 { t.Errorf("total = %d, expected 100", total) } - if elapsed <= 0 { - t.Error("elapsed should be positive") + if elapsed < 0 { + t.Error("elapsed should not be negative") } } From 7f8f451b55504cfa740e632ae608df56b9a43aa7 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:46:47 +0100 Subject: [PATCH 05/11] refactor(source): add IPSource interface, consolidate loadLines --- app.go | 51 +++++++++++++++++++++++++++++++ e2e_test.go | 4 +-- ipgen.go | 13 -------- ipgen_test.go | 14 --------- main.go | 76 ++++++++-------------------------------------- main_test.go | 10 +++--- ranges.go | 19 ++++-------- ranges_test.go | 8 ++--- source.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 162 insertions(+), 115 deletions(-) create mode 100644 app.go create mode 100644 source.go diff --git a/app.go b/app.go new file mode 100644 index 0000000..d2291ca --- /dev/null +++ b/app.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "io" + "os" +) + +type App struct { + cfg *Config + source IPSource + outFile io.Writer +} + +func NewApp(cfg *Config) (*App, error) { + source, err := newIPSource(cfg) + if err != nil { + return nil, err + } + + app := &App{ + cfg: cfg, + source: source, + } + + if cfg.OutputFile != "" { + f, err := os.Create(cfg.OutputFile) + if err != nil { + return nil, fmt.Errorf("failed to create output file: %w", err) + } + app.outFile = f + } + + return app, nil +} + +func newIPSource(cfg *Config) (IPSource, error) { + if cfg.InputFile != "" { + return NewFileSource(cfg.InputFile) + } + if cfg.Mode == "list" { + return NewDNSListSource(cfg.DataDir, cfg.Country) + } + return NewRangeSource(cfg.DataDir, cfg.Country, cfg.Mode) +} + +func (a *App) Close() { + if f, ok := a.outFile.(*os.File); ok && f != nil { + f.Close() + } +} diff --git a/e2e_test.go b/e2e_test.go index db56bf9..15a34c3 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -163,7 +163,7 @@ func TestE2EWorkerPool(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - results := scanner.Run(ctx, IPsFromList(ips)) + results := scanner.Run(ctx, sliceToChannel(ips)) var working, total int for r := range results { @@ -226,7 +226,7 @@ func TestE2EProgressTracking(t *testing.T) { progress := NewProgress(len(ips), false) scanner := NewScanner(1, 2*time.Second, mock.port, progress, "") - results := scanner.Run(context.Background(), IPsFromList(ips)) + results := scanner.Run(context.Background(), sliceToChannel(ips)) for range results { } diff --git a/ipgen.go b/ipgen.go index 910bce2..e398b14 100644 --- a/ipgen.go +++ b/ipgen.go @@ -87,19 +87,6 @@ func expandCIDRWithMode(cidr string, mode string) <-chan string { return out } -// IPsFromList converts a slice of IPs to a channel -func IPsFromList(ips []string) <-chan string { - out := make(chan string, len(ips)) - - go func() { - defer close(out) - for _, ip := range ips { - out <- ip - } - }() - - return out -} // CountIPsWithMode estimates total IPs based on ranges and mode func CountIPsWithMode(ranges []string, mode string) int { diff --git a/ipgen_test.go b/ipgen_test.go index a298c08..d2e0ff6 100644 --- a/ipgen_test.go +++ b/ipgen_test.go @@ -97,20 +97,6 @@ func TestCountIPsWithMode(t *testing.T) { } } -func TestIPsFromList(t *testing.T) { - input := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"} - ips := IPsFromList(input) - - var result []string - for ip := range ips { - result = append(result, ip) - } - - if len(result) != len(input) { - t.Errorf("Expected %d IPs, got %d", len(input), len(result)) - } -} - func TestInvalidCIDR(t *testing.T) { ranges := []string{"invalid-cidr"} ips := ExpandRangesWithMode(ranges, "fast") diff --git a/main.go b/main.go index c87464a..f8ceee4 100644 --- a/main.go +++ b/main.go @@ -1,37 +1,18 @@ package main import ( - "bufio" "context" "fmt" + "io" "os" "os/signal" "sort" - "strings" "syscall" "time" ) var version = "dev" -func readIPsFromFile(path string) ([]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - var ips []string - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line != "" && !strings.HasPrefix(line, "#") { - ips = append(ips, line) - } - } - return ips, scanner.Err() -} - func main() { cfg := ParseFlags() @@ -45,47 +26,15 @@ func main() { os.Exit(1) } - DataDir = cfg.DataDir - - var outFile *os.File - if cfg.OutputFile != "" { - f, err := os.Create(cfg.OutputFile) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to create output file: %v\n", err) - os.Exit(1) - } - defer f.Close() - outFile = f + app, err := NewApp(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } + defer app.Close() - var ips <-chan string - var totalIPs int - - if cfg.InputFile != "" { - fileIPs, err := readIPsFromFile(cfg.InputFile) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to read file: %v\n", err) - os.Exit(1) - } - ips = IPsFromList(fileIPs) - totalIPs = len(fileIPs) - } else if cfg.Mode == "list" { - dnsList, err := LoadDNSList(cfg.Country) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to load DNS list for %s: %v\n", cfg.Country, err) - os.Exit(1) - } - ips = IPsFromList(dnsList) - totalIPs = len(dnsList) - } else { - ranges, err := LoadRanges(cfg.Country) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to load ranges for %s: %v\n", cfg.Country, err) - os.Exit(1) - } - totalIPs = CountIPsWithMode(ranges, cfg.Mode) - ips = ExpandRangesWithMode(ranges, cfg.Mode) - } + ips := app.source.IPs() + totalIPs := app.source.Count() PrintBanner(os.Stderr, cfg, totalIPs, version) @@ -359,16 +308,15 @@ resultLoop: stats.Mode = cfg.Mode } - // Write output - out := os.Stdout - if outFile != nil { - out = outFile + var out io.Writer = os.Stdout + if app.outFile != nil { + out = app.outFile } if cfg.JSONOutput { NewJSONWriter(out).Write(serverResults, stats) } else { - if outFile != nil || !cfg.Progress { + if app.outFile != nil || !cfg.Progress { NewTextWriter(out).Write(serverResults, stats) } if cfg.Progress { diff --git a/main_test.go b/main_test.go index 8a503af..9ffd3f2 100644 --- a/main_test.go +++ b/main_test.go @@ -81,8 +81,8 @@ func TestFileInputNotFound(t *testing.T) { t.Error("expected error for missing file") } - if !strings.Contains(string(out), "Failed to read file") { - t.Errorf("expected 'Failed to read file' error, got: %s", out) + if !strings.Contains(string(out), "failed to open") { + t.Errorf("expected 'failed to open' error, got: %s", out) } } @@ -165,7 +165,7 @@ func TestProgressFlagDisabled(t *testing.T) { } } -func TestReadIPsFromFile(t *testing.T) { +func TestLoadLines(t *testing.T) { // Create temp file f, _ := os.CreateTemp("", "ips-*.txt") f.WriteString("# comment line\n") @@ -175,9 +175,9 @@ func TestReadIPsFromFile(t *testing.T) { f.Close() defer os.Remove(f.Name()) - ips, err := readIPsFromFile(f.Name()) + ips, err := loadLines(f.Name()) if err != nil { - t.Fatalf("readIPsFromFile failed: %v", err) + t.Fatalf("loadLines failed: %v", err) } if len(ips) != 2 { diff --git a/ranges.go b/ranges.go index 740638e..c256445 100644 --- a/ranges.go +++ b/ranges.go @@ -10,15 +10,10 @@ import ( "strings" ) -// DataDir is the directory containing ranges and dns files -var DataDir = "data" - const ipDenyURL = "https://www.ipdeny.com/ipblocks/data/aggregated/%s-aggregated.zone" -// LoadRanges loads IP ranges from data/ranges/.zone -// Auto-downloads from ipdeny.com if not found locally -func LoadRanges(country string) ([]string, error) { - path := filepath.Join(DataDir, "ranges", country+".zone") +func LoadRanges(dataDir, country string) ([]string, error) { + path := filepath.Join(dataDir, "ranges", country+".zone") // Check if file exists, download if not if _, err := os.Stat(path); os.IsNotExist(err) { @@ -60,9 +55,8 @@ func downloadRanges(country, destPath string) error { return err } -// LoadDNSList loads known DNS servers from data/dns/.txt -func LoadDNSList(country string) ([]string, error) { - path := filepath.Join(DataDir, "dns", country+".txt") +func LoadDNSList(dataDir, country string) ([]string, error) { + path := filepath.Join(dataDir, "dns", country+".txt") return loadLines(path) } @@ -88,9 +82,8 @@ func loadLines(path string) ([]string, error) { return lines, nil } -// AvailableCountries returns list of countries with range files -func AvailableCountries() []string { - rangesDir := filepath.Join(DataDir, "ranges") +func AvailableCountries(dataDir string) []string { + rangesDir := filepath.Join(dataDir, "ranges") entries, err := os.ReadDir(rangesDir) if err != nil { return nil diff --git a/ranges_test.go b/ranges_test.go index 80072fb..fe93767 100644 --- a/ranges_test.go +++ b/ranges_test.go @@ -7,7 +7,7 @@ import ( ) func TestLoadRanges(t *testing.T) { - ranges, err := LoadRanges("ir") + ranges, err := LoadRanges("data", "ir") if err != nil { t.Fatalf("LoadRanges failed: %v", err) } @@ -21,7 +21,7 @@ func TestLoadRanges(t *testing.T) { } func TestLoadDNSList(t *testing.T) { - dns, err := LoadDNSList("ir") + dns, err := LoadDNSList("data", "ir") if err != nil { t.Fatalf("LoadDNSList failed: %v", err) } @@ -42,14 +42,14 @@ func TestLoadDNSList(t *testing.T) { } func TestLoadRangesInvalidCountry(t *testing.T) { - _, err := LoadRanges("xx") + _, err := LoadRanges("data", "xx") if err == nil { t.Error("Expected error for invalid country") } } func TestAvailableCountries(t *testing.T) { - countries := AvailableCountries() + countries := AvailableCountries("data") if len(countries) == 0 { t.Error("Expected at least one country") } diff --git a/source.go b/source.go new file mode 100644 index 0000000..a2f9c35 --- /dev/null +++ b/source.go @@ -0,0 +1,82 @@ +package main + +type IPSource interface { + IPs() <-chan string + Count() int +} + +type FileSource struct { + path string + ips []string +} + +func NewFileSource(path string) (*FileSource, error) { + ips, err := loadLines(path) + if err != nil { + return nil, err + } + return &FileSource{path: path, ips: ips}, nil +} + +func (s *FileSource) IPs() <-chan string { + return sliceToChannel(s.ips) +} + +func (s *FileSource) Count() int { + return len(s.ips) +} + +type DNSListSource struct { + country string + ips []string +} + +func NewDNSListSource(dataDir, country string) (*DNSListSource, error) { + ips, err := LoadDNSList(dataDir, country) + if err != nil { + return nil, err + } + return &DNSListSource{country: country, ips: ips}, nil +} + +func (s *DNSListSource) IPs() <-chan string { + return sliceToChannel(s.ips) +} + +func (s *DNSListSource) Count() int { + return len(s.ips) +} + +type RangeSource struct { + country string + mode string + ranges []string +} + +func NewRangeSource(dataDir, country, mode string) (*RangeSource, error) { + ranges, err := LoadRanges(dataDir, country) + if err != nil { + return nil, err + } + return &RangeSource{country: country, mode: mode, ranges: ranges}, nil +} + +func (s *RangeSource) IPs() <-chan string { + return ExpandRangesWithMode(s.ranges, s.mode) +} + +func (s *RangeSource) Count() int { + return CountIPsWithMode(s.ranges, s.mode) +} + +func sliceToChannel(ips []string) <-chan string { + ch := make(chan string, len(ips)) + go func() { + defer close(ch) + for _, ip := range ips { + ch <- ip + } + }() + return ch +} + From 0866ee8df0bc6f8c76357c7a665f94ca7c0d7fcc Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 10:28:01 +0100 Subject: [PATCH 06/11] refactor(geodata): rename ranges.go, decouple functions, consistent CIDR naming --- app.go | 2 +- ranges.go => geodata.go | 92 ++++++++++++++++++------------- ranges_test.go => geodata_test.go | 22 +++++--- ipgen.go | 15 ++--- ipgen_test.go | 38 ++++++------- source.go | 25 +++++---- 6 files changed, 106 insertions(+), 88 deletions(-) rename ranges.go => geodata.go (56%) rename ranges_test.go => geodata_test.go (73%) diff --git a/app.go b/app.go index d2291ca..a183c6c 100644 --- a/app.go +++ b/app.go @@ -41,7 +41,7 @@ func newIPSource(cfg *Config) (IPSource, error) { if cfg.Mode == "list" { return NewDNSListSource(cfg.DataDir, cfg.Country) } - return NewRangeSource(cfg.DataDir, cfg.Country, cfg.Mode) + return NewCIDRSource(cfg.DataDir, cfg.Country, cfg.Mode) } func (a *App) Close() { diff --git a/ranges.go b/geodata.go similarity index 56% rename from ranges.go rename to geodata.go index c256445..c6eb6b2 100644 --- a/ranges.go +++ b/geodata.go @@ -12,55 +12,86 @@ import ( const ipDenyURL = "https://www.ipdeny.com/ipblocks/data/aggregated/%s-aggregated.zone" -func LoadRanges(dataDir, country string) ([]string, error) { +func CIDRBlocksExist(dataDir, country string) bool { path := filepath.Join(dataDir, "ranges", country+".zone") + _, err := os.Stat(path) + return err == nil +} - // Check if file exists, download if not - if _, err := os.Stat(path); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Downloading IP ranges for %s...\n", country) - if err := downloadRanges(country, path); err != nil { - return nil, fmt.Errorf("failed to download ranges for %s: %w", country, err) - } +func DownloadCIDRBlocks(dataDir, country string) error { + fmt.Fprintf(os.Stderr, "Downloading IP ranges for %s...\n", country) + + data, err := fetchRanges(country) + if err != nil { + return fmt.Errorf("failed to download ranges for %s: %w", country, err) + } + defer data.Close() + + path := filepath.Join(dataDir, "ranges", country+".zone") + if err := saveToFile(path, data); err != nil { + return fmt.Errorf("failed to save ranges for %s: %w", country, err) } + return nil +} +func LoadCIDRBlocks(dataDir, country string) ([]string, error) { + path := filepath.Join(dataDir, "ranges", country+".zone") return loadLines(path) } -// downloadRanges fetches IP ranges from ipdeny.com -func downloadRanges(country, destPath string) error { +func LoadKnownDNS(dataDir, country string) ([]string, error) { + path := filepath.Join(dataDir, "dns", country+".txt") + return loadLines(path) +} + +func AvailableCountries(dataDir string) []string { + rangesDir := filepath.Join(dataDir, "ranges") + entries, err := os.ReadDir(rangesDir) + if err != nil { + return nil + } + + var countries []string + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".zone") { + country := strings.TrimSuffix(e.Name(), ".zone") + countries = append(countries, country) + } + } + return countries +} + +func fetchRanges(country string) (io.ReadCloser, error) { url := fmt.Sprintf(ipDenyURL, country) resp, err := http.Get(url) if err != nil { - return err + return nil, err } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("HTTP %d - country '%s' may not exist", resp.StatusCode, country) + resp.Body.Close() + return nil, fmt.Errorf("HTTP %d - country '%s' may not exist", resp.StatusCode, country) } - // Ensure directory exists - if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil { + return resp.Body, nil +} + +func saveToFile(path string, r io.Reader) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return err } - f, err := os.Create(destPath) + f, err := os.Create(path) if err != nil { return err } defer f.Close() - _, err = io.Copy(f, resp.Body) + _, err = io.Copy(f, r) return err } -func LoadDNSList(dataDir, country string) ([]string, error) { - path := filepath.Join(dataDir, "dns", country+".txt") - return loadLines(path) -} - -// loadLines reads non-empty, non-comment lines from a file func loadLines(path string) ([]string, error) { f, err := os.Open(path) if err != nil { @@ -81,20 +112,3 @@ func loadLines(path string) ([]string, error) { } return lines, nil } - -func AvailableCountries(dataDir string) []string { - rangesDir := filepath.Join(dataDir, "ranges") - entries, err := os.ReadDir(rangesDir) - if err != nil { - return nil - } - - var countries []string - for _, e := range entries { - if !e.IsDir() && strings.HasSuffix(e.Name(), ".zone") { - country := strings.TrimSuffix(e.Name(), ".zone") - countries = append(countries, country) - } - } - return countries -} diff --git a/ranges_test.go b/geodata_test.go similarity index 73% rename from ranges_test.go rename to geodata_test.go index fe93767..cf2b444 100644 --- a/ranges_test.go +++ b/geodata_test.go @@ -6,24 +6,28 @@ import ( "testing" ) -func TestLoadRanges(t *testing.T) { - ranges, err := LoadRanges("data", "ir") +func TestLoadCIDRBlocks(t *testing.T) { + if !CIDRBlocksExist("data", "ir") { + if err := DownloadCIDRBlocks("data", "ir"); err != nil { + t.Fatalf("DownloadCIDRBlocks failed: %v", err) + } + } + ranges, err := LoadCIDRBlocks("data", "ir") if err != nil { - t.Fatalf("LoadRanges failed: %v", err) + t.Fatalf("LoadCIDRBlocks failed: %v", err) } if len(ranges) == 0 { t.Error("Expected non-empty ranges") } - // Check first range is valid CIDR if ranges[0] == "" { t.Error("First range is empty") } } -func TestLoadDNSList(t *testing.T) { - dns, err := LoadDNSList("data", "ir") +func TestLoadKnownDNS(t *testing.T) { + dns, err := LoadKnownDNS("data", "ir") if err != nil { - t.Fatalf("LoadDNSList failed: %v", err) + t.Fatalf("LoadKnownDNS failed: %v", err) } if len(dns) == 0 { t.Error("Expected non-empty DNS list") @@ -41,8 +45,8 @@ func TestLoadDNSList(t *testing.T) { } } -func TestLoadRangesInvalidCountry(t *testing.T) { - _, err := LoadRanges("data", "xx") +func TestDownloadCIDRBlocksInvalidCountry(t *testing.T) { + err := DownloadCIDRBlocks("data", "xx") if err == nil { t.Error("Expected error for invalid country") } diff --git a/ipgen.go b/ipgen.go index e398b14..68e5be6 100644 --- a/ipgen.go +++ b/ipgen.go @@ -16,14 +16,13 @@ var ( // all: every usable IP (generated dynamically) ) -// ExpandRangesWithMode expands CIDR ranges using the specified sampling mode -func ExpandRangesWithMode(ranges []string, mode string) <-chan string { +func ExpandCIDR(blocks []string, mode string) <-chan string { out := make(chan string, 10000) go func() { defer close(out) - for _, cidr := range ranges { - for ip := range expandCIDRWithMode(cidr, mode) { + for _, block := range blocks { + for ip := range expandBlock(block, mode) { out <- ip } } @@ -32,8 +31,7 @@ func ExpandRangesWithMode(ranges []string, mode string) <-chan string { return out } -// expandCIDRWithMode generates IPs from a CIDR range based on mode -func expandCIDRWithMode(cidr string, mode string) <-chan string { +func expandBlock(cidr string, mode string) <-chan string { out := make(chan string, 1000) go func() { @@ -88,8 +86,7 @@ func expandCIDRWithMode(cidr string, mode string) <-chan string { } -// CountIPsWithMode estimates total IPs based on ranges and mode -func CountIPsWithMode(ranges []string, mode string) int { +func CountCIDRIPs(blocks []string, mode string) int { var octetsPerSubnet int switch mode { case "fast": @@ -103,7 +100,7 @@ func CountIPsWithMode(ranges []string, mode string) int { } total := 0 - for _, cidr := range ranges { + for _, cidr := range blocks { _, ipnet, err := net.ParseCIDR(cidr) if err != nil { continue diff --git a/ipgen_test.go b/ipgen_test.go index d2e0ff6..1ac5133 100644 --- a/ipgen_test.go +++ b/ipgen_test.go @@ -4,9 +4,9 @@ import ( "testing" ) -func TestExpandRangesWithModeFast(t *testing.T) { - ranges := []string{"192.168.1.0/24"} - ips := ExpandRangesWithMode(ranges, "fast") +func TestExpandCIDRFast(t *testing.T) { + blocks := []string{"192.168.1.0/24"} + ips := ExpandCIDR(blocks, "fast") var result []string for ip := range ips { @@ -17,7 +17,6 @@ func TestExpandRangesWithModeFast(t *testing.T) { t.Errorf("Fast mode: expected 3 IPs, got %d: %v", len(result), result) } - // Check expected IPs are present expected := map[string]bool{ "192.168.1.1": false, "192.168.1.53": false, @@ -33,9 +32,9 @@ func TestExpandRangesWithModeFast(t *testing.T) { } } -func TestExpandRangesWithModeMedium(t *testing.T) { - ranges := []string{"192.168.1.0/24"} - ips := ExpandRangesWithMode(ranges, "medium") +func TestExpandCIDRMedium(t *testing.T) { + blocks := []string{"192.168.1.0/24"} + ips := ExpandCIDR(blocks, "medium") var count int for range ips { @@ -47,9 +46,9 @@ func TestExpandRangesWithModeMedium(t *testing.T) { } } -func TestExpandRangesWithModeAll(t *testing.T) { - ranges := []string{"192.168.1.0/24"} - ips := ExpandRangesWithMode(ranges, "all") +func TestExpandCIDRAll(t *testing.T) { + blocks := []string{"192.168.1.0/24"} + ips := ExpandCIDR(blocks, "all") var count int for range ips { @@ -61,11 +60,10 @@ func TestExpandRangesWithModeAll(t *testing.T) { } } -func TestExpandRangesWithMode16(t *testing.T) { - ranges := []string{"10.0.0.0/16"} +func TestExpandCIDR16(t *testing.T) { + blocks := []string{"10.0.0.0/16"} - // Fast mode on /16 should give 3 IPs * 256 subnets = 768 - ips := ExpandRangesWithMode(ranges, "fast") + ips := ExpandCIDR(blocks, "fast") var count int for range ips { count++ @@ -77,8 +75,8 @@ func TestExpandRangesWithMode16(t *testing.T) { } } -func TestCountIPsWithMode(t *testing.T) { - ranges := []string{"192.168.1.0/24"} +func TestCountCIDRIPs(t *testing.T) { + blocks := []string{"192.168.1.0/24"} tests := []struct { mode string @@ -90,16 +88,16 @@ func TestCountIPsWithMode(t *testing.T) { } for _, tt := range tests { - count := CountIPsWithMode(ranges, tt.mode) + count := CountCIDRIPs(blocks, tt.mode) if count != tt.expected { - t.Errorf("CountIPsWithMode(%s): expected %d, got %d", tt.mode, tt.expected, count) + t.Errorf("CountCIDRIPs(%s): expected %d, got %d", tt.mode, tt.expected, count) } } } func TestInvalidCIDR(t *testing.T) { - ranges := []string{"invalid-cidr"} - ips := ExpandRangesWithMode(ranges, "fast") + blocks := []string{"invalid-cidr"} + ips := ExpandCIDR(blocks, "fast") var count int for range ips { diff --git a/source.go b/source.go index a2f9c35..dc3f86f 100644 --- a/source.go +++ b/source.go @@ -32,7 +32,7 @@ type DNSListSource struct { } func NewDNSListSource(dataDir, country string) (*DNSListSource, error) { - ips, err := LoadDNSList(dataDir, country) + ips, err := LoadKnownDNS(dataDir, country) if err != nil { return nil, err } @@ -47,26 +47,31 @@ func (s *DNSListSource) Count() int { return len(s.ips) } -type RangeSource struct { +type CIDRSource struct { country string mode string - ranges []string + blocks []string } -func NewRangeSource(dataDir, country, mode string) (*RangeSource, error) { - ranges, err := LoadRanges(dataDir, country) +func NewCIDRSource(dataDir, country, mode string) (*CIDRSource, error) { + if !CIDRBlocksExist(dataDir, country) { + if err := DownloadCIDRBlocks(dataDir, country); err != nil { + return nil, err + } + } + blocks, err := LoadCIDRBlocks(dataDir, country) if err != nil { return nil, err } - return &RangeSource{country: country, mode: mode, ranges: ranges}, nil + return &CIDRSource{country: country, mode: mode, blocks: blocks}, nil } -func (s *RangeSource) IPs() <-chan string { - return ExpandRangesWithMode(s.ranges, s.mode) +func (s *CIDRSource) IPs() <-chan string { + return ExpandCIDR(s.blocks, s.mode) } -func (s *RangeSource) Count() int { - return CountIPsWithMode(s.ranges, s.mode) +func (s *CIDRSource) Count() int { + return CountCIDRIPs(s.blocks, s.mode) } func sliceToChannel(ips []string) <-chan string { From 59064d77a2a042e57425995bf24adead90a5656b Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 11:25:14 +0100 Subject: [PATCH 07/11] refactor: extract benchmark.go and progress.go from scanner.go --- benchmark.go | 193 +++++++++++++++++++++++++++++ e2e_test.go | 56 ++++----- main.go | 82 ++++++------- progress.go | 48 ++++++++ scanner.go | 313 +++++------------------------------------------- scanner_test.go | 50 ++++---- 6 files changed, 364 insertions(+), 378 deletions(-) create mode 100644 benchmark.go create mode 100644 progress.go diff --git a/benchmark.go b/benchmark.go new file mode 100644 index 0000000..2f10414 --- /dev/null +++ b/benchmark.go @@ -0,0 +1,193 @@ +package main + +import ( + "context" + "crypto/rand" + "encoding/base32" + "fmt" + "sort" + "sync" + "time" + + "github.com/miekg/dns" +) + +const ( + BenchmarkQueries = 20 + BenchmarkConcurrency = 5 + BenchmarkThreshold = 70 + BenchmarkSubdomainLen = 32 + EDNSBufferSize = 1232 // matches slipstream UDP payload +) + +type BenchmarkResult struct { + IP string + Queries int + Successful int + Failed int + Latencies []time.Duration + Duration time.Duration +} + +func (r *BenchmarkResult) SuccessRate() float64 { + if r.Queries == 0 { + return 0 + } + return float64(r.Successful) / float64(r.Queries) * 100 +} + +func (r *BenchmarkResult) QPS() float64 { + if r.Duration == 0 { + return 0 + } + return float64(r.Successful) / r.Duration.Seconds() +} + +func (r *BenchmarkResult) P50() time.Duration { + return r.percentile(50) +} + +func (r *BenchmarkResult) percentile(p int) time.Duration { + if len(r.Latencies) == 0 { + return 0 + } + sorted := make([]time.Duration, len(r.Latencies)) + copy(sorted, r.Latencies) + sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) + idx := len(sorted) * p / 100 + if idx >= len(sorted) { + idx = len(sorted) - 1 + } + return sorted[idx] +} + +func (r *BenchmarkResult) Passed() bool { + return r.SuccessRate() >= BenchmarkThreshold +} + +func randomBenchmarkSubdomain() string { + b := make([]byte, BenchmarkSubdomainLen) + rand.Read(b) + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) +} + +// Benchmark runs concurrent DNS queries to test server reliability under load +func Benchmark(ctx context.Context, ip, domain string, port int, timeout time.Duration) *BenchmarkResult { + if port == 0 { + port = 53 + } + addr := fmt.Sprintf("%s:%d", ip, port) + + result := &BenchmarkResult{ + IP: ip, + Queries: BenchmarkQueries, + } + + client := &dns.Client{ + Net: "udp", + Timeout: timeout, + } + + var mu sync.Mutex + var wg sync.WaitGroup + sem := make(chan struct{}, BenchmarkConcurrency) + + start := time.Now() + + for i := 0; i < BenchmarkQueries; i++ { + select { + case <-ctx.Done(): + result.Duration = time.Since(start) + return result + default: + } + + wg.Add(1) + sem <- struct{}{} + + go func() { + defer wg.Done() + defer func() { <-sem }() + + select { + case <-ctx.Done(): + mu.Lock() + result.Failed++ + mu.Unlock() + return + default: + } + + subdomain := randomBenchmarkSubdomain() + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(subdomain+"."+domain), dns.TypeTXT) + m.RecursionDesired = true + m.SetEdns0(EDNSBufferSize, false) + + _, rtt, err := client.Exchange(m, addr) + + mu.Lock() + if err != nil { + result.Failed++ + } else { + result.Successful++ + result.Latencies = append(result.Latencies, rtt) + } + mu.Unlock() + }() + } + + wg.Wait() + result.Duration = time.Since(start) + return result +} + +// BenchmarkParallel runs benchmarks on multiple IPs concurrently +func BenchmarkParallel(ctx context.Context, ips []string, domain string, port int, + timeout time.Duration, workers int) <-chan *BenchmarkResult { + + results := make(chan *BenchmarkResult, workers) + ipChan := make(chan string, len(ips)) + + go func() { + defer close(ipChan) + for _, ip := range ips { + select { + case ipChan <- ip: + case <-ctx.Done(): + return + } + } + }() + + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case ip, ok := <-ipChan: + if !ok { + return + } + result := Benchmark(ctx, ip, domain, port, timeout) + select { + case results <- result: + case <-ctx.Done(): + return + } + } + } + }() + } + + go func() { + wg.Wait() + close(results) + }() + + return results +} diff --git a/e2e_test.go b/e2e_test.go index 15a34c3..e6a990a 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -129,17 +129,17 @@ func TestE2EDomainVerification(t *testing.T) { } } -func TestE2EBurstTest(t *testing.T) { +func TestE2EBenchmark(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { t.Fatalf("Failed to start mock DNS: %v", err) } defer mock.Close() - result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) - if result.Queries != BurstQueries { - t.Errorf("Expected %d queries, got %d", BurstQueries, result.Queries) + if result.Queries != BenchmarkQueries { + t.Errorf("Expected %d queries, got %d", BenchmarkQueries, result.Queries) } if result.SuccessRate() < 90 { t.Errorf("Expected high success rate, got %.1f%%", result.SuccessRate()) @@ -230,20 +230,20 @@ func TestE2EProgressTracking(t *testing.T) { for range results { } - scanned, found, total, _ := progress.Stats() - if scanned != 3 || found != 3 || total != 3 { - t.Errorf("Progress mismatch: scanned=%d found=%d total=%d", scanned, found, total) + stats := progress.Stats() + if stats.Processed != 3 || stats.Success != 3 || stats.Total != 3 { + t.Errorf("Progress mismatch: processed=%d success=%d total=%d", stats.Processed, stats.Success, stats.Total) } } -func TestE2EBurstQPS(t *testing.T) { +func TestE2EBenchmarkQPS(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { t.Fatalf("Failed to start mock DNS: %v", err) } defer mock.Close() - result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) if result.QPS() <= 0 { t.Errorf("QPS should be positive, got %.2f", result.QPS()) @@ -253,7 +253,7 @@ func TestE2EBurstQPS(t *testing.T) { } } -func TestE2EParallelBurstTest(t *testing.T) { +func TestE2EBenchmarkParallel(t *testing.T) { // Start 3 mock servers var mocks []*mockDNSServer var ips []string @@ -275,7 +275,7 @@ func TestE2EParallelBurstTest(t *testing.T) { defer cancel() // All mocks use same port pattern, so we use first mock's port - resultChan := ParallelBurstTest(ctx, ips, "test.example.com", mocks[0].port, 2*time.Second, 3) + resultChan := BenchmarkParallel(ctx, ips, "test.example.com", mocks[0].port, 2*time.Second, 3) var count, passed int for r := range resultChan { @@ -293,7 +293,7 @@ func TestE2EParallelBurstTest(t *testing.T) { } } -func TestE2EBurstTestContextCancellation(t *testing.T) { +func TestE2EBenchmarkContextCancellation(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { t.Fatalf("Failed to start mock DNS: %v", err) @@ -304,41 +304,41 @@ func TestE2EBurstTestContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - result := BurstTest(ctx, mock.ip, "test.example.com", mock.port, 2*time.Second) + result := Benchmark(ctx, mock.ip, "test.example.com", mock.port, 2*time.Second) // Should return early with partial/no results - if result.Successful == BurstQueries { + if result.Successful == BenchmarkQueries { t.Error("Expected early termination with cancelled context") } } -func TestBurstProgress(t *testing.T) { - prog := NewBurstProgress(10, true) +func TestProgressUnified(t *testing.T) { + prog := NewProgress(10, true) - prog.Tested() - prog.Tested() - prog.Passed() + prog.Increment() + prog.Increment() + prog.Success() - tested, passed, total := prog.Stats() - if tested != 2 { - t.Errorf("Expected 2 tested, got %d", tested) + stats := prog.Stats() + if stats.Processed != 2 { + t.Errorf("Expected 2 processed, got %d", stats.Processed) } - if passed != 1 { - t.Errorf("Expected 1 passed, got %d", passed) + if stats.Success != 1 { + t.Errorf("Expected 1 success, got %d", stats.Success) } - if total != 10 { - t.Errorf("Expected total 10, got %d", total) + if stats.Total != 10 { + t.Errorf("Expected total 10, got %d", stats.Total) } } -func TestE2EBurstResultMetrics(t *testing.T) { +func TestE2EBenchmarkResultMetrics(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { t.Fatalf("Failed to start mock DNS: %v", err) } defer mock.Close() - result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) if result.IP != mock.ip { t.Errorf("IP mismatch: got %s, expected %s", result.IP, mock.ip) diff --git a/main.go b/main.go index f8ceee4..a6dada2 100644 --- a/main.go +++ b/main.go @@ -66,11 +66,11 @@ func main() { for { select { case <-ticker.C: - scanned, found, total, elapsed := prog.Stats() - rate := float64(scanned) / elapsed.Seconds() - pct := float64(scanned) / float64(total) * 100 + stats := prog.Stats() + rate := float64(stats.Processed) / stats.Elapsed.Seconds() + pct := float64(stats.Processed) / float64(stats.Total) * 100 fmt.Fprintf(os.Stderr, "\rScanned: %d/%d (%.1f%%) | Found: %d | %.0f IPs/sec ", - scanned, total, pct, found, rate) + stats.Processed, stats.Total, pct, stats.Success, rate) case <-ctx.Done(): return case <-progressDone: @@ -111,10 +111,10 @@ resultLoop: // Print final stats if cfg.Progress { - scanned, found, _, elapsed := prog.Stats() + stats := prog.Stats() fmt.Fprintf(os.Stderr, "\r \r") - fmt.Fprintf(os.Stderr, "Completed: %d IPs in %v\n", scanned, elapsed.Round(time.Millisecond)) - fmt.Fprintf(os.Stderr, "Found: %d DNS candidates\n", found) + fmt.Fprintf(os.Stderr, "Completed: %d IPs in %v\n", stats.Processed, stats.Elapsed.Round(time.Millisecond)) + fmt.Fprintf(os.Stderr, "Found: %d DNS candidates\n", stats.Success) if suspiciousCount > 0 { fmt.Fprintf(os.Stderr, "\033[33mWarning: %d servers returned private IPs (possible DNS hijacking)\033[0m\n", suspiciousCount) } @@ -166,16 +166,16 @@ resultLoop: workingDNS = verified } - // Phase 3: Burst test to verify servers handle concurrent load - var burstResults []*BurstResult + // Phase 3: Benchmark to verify servers handle concurrent load + var benchResults []*BenchmarkResult if cfg.Domain != "" && len(workingDNS) > 0 { total := len(workingDNS) if total <= 5 { // Sequential for small lists - nicer per-IP output if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates (%d queries, %d%% required)...\n", - total, BurstQueries, BurstMinSuccess) + fmt.Fprintf(os.Stderr, "\nBenchmarking %d candidates (%d queries, %d%% required)...\n", + total, BenchmarkQueries, BenchmarkThreshold) } width := len(fmt.Sprintf("%d", total)) @@ -183,9 +183,9 @@ resultLoop: select { case <-ctx.Done(): if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nInterrupted during burst test\n") + fmt.Fprintf(os.Stderr, "\nInterrupted during benchmark\n") } - goto burstDone + goto benchDone default: } @@ -193,10 +193,10 @@ resultLoop: fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) } - result := BurstTest(ctx, ip, cfg.Domain, 53, cfg.Timeout) + result := Benchmark(ctx, ip, cfg.Domain, 53, cfg.Timeout) if result.Passed() { - burstResults = append(burstResults, result) + benchResults = append(benchResults, result) if cfg.Progress { color := "\033[33m" if result.SuccessRate() >= 85 { @@ -213,13 +213,13 @@ resultLoop: } } else { // Parallel for larger lists - burstWorkers := min(total, 10) + benchWorkers := min(total, 10) if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates in parallel (%d workers)...\n", - total, burstWorkers) + fmt.Fprintf(os.Stderr, "\nBenchmarking %d candidates in parallel (%d workers)...\n", + total, benchWorkers) } - burstProg := NewBurstProgress(total, cfg.Progress) + benchProg := NewProgress(total, cfg.Progress) var progressDone chan struct{} if cfg.Progress { @@ -230,8 +230,8 @@ resultLoop: for { select { case <-ticker.C: - tested, passed, tot := burstProg.Stats() - fmt.Fprintf(os.Stderr, "\rBurst testing: %d/%d tested, %d passed ", tested, tot, passed) + stats := benchProg.Stats() + fmt.Fprintf(os.Stderr, "\rBenchmarking: %d/%d tested, %d passed ", stats.Processed, stats.Total, stats.Success) case <-ctx.Done(): return case <-progressDone: @@ -241,12 +241,12 @@ resultLoop: }() } - resultChan := ParallelBurstTest(ctx, workingDNS, cfg.Domain, 53, cfg.Timeout, burstWorkers) + resultChan := BenchmarkParallel(ctx, workingDNS, cfg.Domain, 53, cfg.Timeout, benchWorkers) for result := range resultChan { - burstProg.Tested() + benchProg.Increment() if result.Passed() { - burstProg.Passed() - burstResults = append(burstResults, result) + benchProg.Success() + benchResults = append(benchResults, result) } } @@ -255,16 +255,16 @@ resultLoop: fmt.Fprintf(os.Stderr, "\r \r") } } - burstDone: + benchDone: - sort.Slice(burstResults, func(i, j int) bool { - return burstResults[i].QPS() > burstResults[j].QPS() + sort.Slice(benchResults, func(i, j int) bool { + return benchResults[i].QPS() > benchResults[j].QPS() }) if cfg.Progress { fmt.Fprintf(os.Stderr, "---\n") - fmt.Fprintf(os.Stderr, "Burst test: %d/%d passed (sorted by throughput)\n", len(burstResults), len(workingDNS)) - for _, r := range burstResults { + fmt.Fprintf(os.Stderr, "Benchmark: %d/%d passed (sorted by throughput)\n", len(benchResults), len(workingDNS)) + for _, r := range benchResults { color := "\033[33m" if r.SuccessRate() >= 85 { color = "\033[32m" @@ -275,16 +275,16 @@ resultLoop: } workingDNS = nil - for _, r := range burstResults { + for _, r := range benchResults { workingDNS = append(workingDNS, r.IP) } } - scanned, _, _, elapsed := prog.Stats() + finalStats := prog.Stats() var serverResults []ServerResult - if len(burstResults) > 0 { - for _, r := range burstResults { + if len(benchResults) > 0 { + for _, r := range benchResults { serverResults = append(serverResults, ServerResult{ IP: r.IP, QPS: r.QPS(), @@ -298,14 +298,14 @@ resultLoop: } } - stats := ScanStats{ - TotalScanned: scanned, + outputStats := ScanStats{ + TotalScanned: finalStats.Processed, Found: int64(len(serverResults)), - Duration: elapsed, + Duration: finalStats.Elapsed, } if cfg.InputFile == "" { - stats.Country = cfg.Country - stats.Mode = cfg.Mode + outputStats.Country = cfg.Country + outputStats.Mode = cfg.Mode } var out io.Writer = os.Stdout @@ -314,10 +314,10 @@ resultLoop: } if cfg.JSONOutput { - NewJSONWriter(out).Write(serverResults, stats) + NewJSONWriter(out).Write(serverResults, outputStats) } else { if app.outFile != nil || !cfg.Progress { - NewTextWriter(out).Write(serverResults, stats) + NewTextWriter(out).Write(serverResults, outputStats) } if cfg.Progress { PrintUsageHint(os.Stderr, workingDNS, cfg.Domain) diff --git a/progress.go b/progress.go new file mode 100644 index 0000000..b7782bd --- /dev/null +++ b/progress.go @@ -0,0 +1,48 @@ +package main + +import ( + "sync/atomic" + "time" +) + +// ProgressStats holds progress metrics +type ProgressStats struct { + Processed int64 + Success int64 + Total int64 + Elapsed time.Duration +} + +// Progress tracks scan/test progress with atomic counters +type Progress struct { + total int64 + processed int64 + success int64 + startTime time.Time + enabled bool +} + +func NewProgress(total int, enabled bool) *Progress { + return &Progress{ + total: int64(total), + startTime: time.Now(), + enabled: enabled, + } +} + +func (p *Progress) Increment() { + atomic.AddInt64(&p.processed, 1) +} + +func (p *Progress) Success() { + atomic.AddInt64(&p.success, 1) +} + +func (p *Progress) Stats() ProgressStats { + return ProgressStats{ + Processed: atomic.LoadInt64(&p.processed), + Success: atomic.LoadInt64(&p.success), + Total: p.total, + Elapsed: time.Since(p.startTime), + } +} diff --git a/scanner.go b/scanner.go index 8f40f49..38c9a63 100644 --- a/scanner.go +++ b/scanner.go @@ -3,73 +3,27 @@ package main import ( "context" "crypto/rand" - "encoding/base32" "encoding/hex" "fmt" "net" - "sort" "sync" - "sync/atomic" "time" "github.com/miekg/dns" ) -// Burst test parameters - tune these for accuracy vs speed tradeoff -const ( - BurstQueries = 20 // Number of queries per burst test - BurstConcurrency = 5 // Concurrent queries during burst - BurstMinSuccess = 70 // Minimum success rate % to pass - BurstSubdomainLen = 32 // Bytes for random subdomain (32 = ~52 base32 chars) -) - -// randomSubdomain generates a random subdomain prefix -func randomSubdomain() string { - b := make([]byte, 8) - rand.Read(b) - return hex.EncodeToString(b) -} - -// randomSlipstreamSubdomain generates a subdomain similar to slipstream's Base32 encoding -func randomSlipstreamSubdomain() string { - b := make([]byte, BurstSubdomainLen) - rand.Read(b) - return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) -} +const probeDomain = "google.com" -// ScanResult holds the result of a DNS probe +// ScanResult holds the outcome of probing a single DNS server. type ScanResult struct { IP string Working bool - Suspicious bool + Suspicious bool // true if server returned private IP (possible hijacking) RTT time.Duration Error error } -// isPrivateIP detects DNS hijacking by checking if response IPs are in reserved ranges -func isPrivateIP(ip net.IP) bool { - if ip == nil { - return false - } - privateRanges := []string{ - "10.0.0.0/8", // Common in corporate/ISP hijacking - "172.16.0.0/12", // Often used by captive portals - "192.168.0.0/16", // Home routers sometimes hijack DNS - "127.0.0.0/8", // Loopback, used to block domains - "169.254.0.0/16", // Link-local, indicates broken resolution - "100.64.0.0/10", // CGNAT, ISP-level interception - "0.0.0.0/8", // Invalid, used to sink traffic - } - for _, cidr := range privateRanges { - _, network, _ := net.ParseCIDR(cidr) - if network.Contains(ip) { - return true - } - } - return false -} - -// Scanner manages the worker pool for DNS probing +// Scanner probes DNS servers for availability and hijacking detection. type Scanner struct { workers int timeout time.Duration @@ -78,44 +32,6 @@ type Scanner struct { verifyDomain string } -// Progress tracks scanning progress -type Progress struct { - total int64 - scanned int64 - found int64 - startTime time.Time - enabled bool - mu sync.Mutex -} - -// NewProgress creates a new progress tracker -func NewProgress(total int, enabled bool) *Progress { - return &Progress{ - total: int64(total), - startTime: time.Now(), - enabled: enabled, - } -} - -// Increment marks one IP as scanned -func (p *Progress) Increment() { - atomic.AddInt64(&p.scanned, 1) -} - -// Found marks a working DNS found -func (p *Progress) Found() { - atomic.AddInt64(&p.found, 1) -} - -// Stats returns current stats -func (p *Progress) Stats() (scanned, found, total int64, elapsed time.Duration) { - return atomic.LoadInt64(&p.scanned), - atomic.LoadInt64(&p.found), - p.total, - time.Since(p.startTime) -} - -// NewScanner creates a new scanner with given workers and timeout func NewScanner(workers int, timeout time.Duration, port int, progress *Progress, verifyDomain string) *Scanner { if port == 0 { port = 53 @@ -129,7 +45,7 @@ func NewScanner(workers int, timeout time.Duration, port int, progress *Progress } } -// Probe tests if an IP is a working DNS server +// Probe tests a single IP for DNS availability. func (s *Scanner) Probe(ip string) ScanResult { client := &dns.Client{ Net: "udp", @@ -137,9 +53,8 @@ func (s *Scanner) Probe(ip string) ScanResult { ReadTimeout: s.timeout, } - // First test: can it resolve google.com? m := new(dns.Msg) - m.SetQuestion(dns.Fqdn("google.com"), dns.TypeA) + m.SetQuestion(dns.Fqdn(probeDomain), dns.TypeA) m.RecursionDesired = true addr := fmt.Sprintf("%s:%d", ip, s.port) @@ -148,7 +63,6 @@ func (s *Scanner) Probe(ip string) ScanResult { return ScanResult{IP: ip, Working: false, Error: err} } - // Check for valid response with actual answer if reply == nil || reply.Rcode != dns.RcodeSuccess || len(reply.Answer) == 0 { return ScanResult{IP: ip, Working: false} } @@ -161,26 +75,19 @@ func (s *Scanner) Probe(ip string) ScanResult { } } - // If verify domain is set, check if query reaches our authoritative server - // Slipstream uses TXT records exclusively, so test with TXT - // Any response (NXDOMAIN, NOERROR, etc.) = query reached server - // Only timeout/error = didn't reach if s.verifyDomain != "" { - // Use random subdomain to avoid DNS caching testDomain := randomSubdomain() + "." + s.verifyDomain m2 := new(dns.Msg) m2.SetQuestion(dns.Fqdn(testDomain), dns.TypeTXT) m2.RecursionDesired = true - // Set EDNS0 with 1232 byte UDP payload (matches slipstream) - m2.SetEdns0(1232, false) + m2.SetEdns0(EDNSBufferSize, false) reply2, rtt2, err := client.Exchange(m2, addr) if err != nil { return ScanResult{IP: ip, Working: false, Error: err} } - // Any response = query reached our authoritative server if reply2 != nil { return ScanResult{IP: ip, Working: true, RTT: rtt2} } @@ -190,12 +97,11 @@ func (s *Scanner) Probe(ip string) ScanResult { return ScanResult{IP: ip, Working: true, RTT: rtt} } -// Run starts the scanner with a worker pool +// Run starts a worker pool that probes IPs concurrently. func (s *Scanner) Run(ctx context.Context, ips <-chan string) <-chan ScanResult { results := make(chan ScanResult, s.workers) var wg sync.WaitGroup - // Start workers for i := 0; i < s.workers; i++ { wg.Add(1) go func() { @@ -212,7 +118,7 @@ func (s *Scanner) Run(ctx context.Context, ips <-chan string) <-chan ScanResult if s.progress != nil { s.progress.Increment() if result.Working { - s.progress.Found() + s.progress.Success() } } select { @@ -225,7 +131,6 @@ func (s *Scanner) Run(ctx context.Context, ips <-chan string) <-chan ScanResult }() } - // Close results channel when all workers are done go func() { wg.Wait() close(results) @@ -234,191 +139,31 @@ func (s *Scanner) Run(ctx context.Context, ips <-chan string) <-chan ScanResult return results } -// BurstResult holds results from a burst test -type BurstResult struct { - IP string - Queries int - Successful int - Failed int - Latencies []time.Duration - Duration time.Duration -} - -// SuccessRate returns percentage of successful queries -func (r *BurstResult) SuccessRate() float64 { - if r.Queries == 0 { - return 0 - } - return float64(r.Successful) / float64(r.Queries) * 100 -} - -// QPS returns queries per second -func (r *BurstResult) QPS() float64 { - if r.Duration == 0 { - return 0 - } - return float64(r.Successful) / r.Duration.Seconds() -} - -// P50 returns median latency -func (r *BurstResult) P50() time.Duration { - return r.percentile(50) -} - -func (r *BurstResult) percentile(p int) time.Duration { - if len(r.Latencies) == 0 { - return 0 - } - sorted := make([]time.Duration, len(r.Latencies)) - copy(sorted, r.Latencies) - sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) - idx := len(sorted) * p / 100 - if idx >= len(sorted) { - idx = len(sorted) - 1 - } - return sorted[idx] +var privateRanges = []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "169.254.0.0/16", + "100.64.0.0/10", + "0.0.0.0/8", } -// Passed returns true if burst test meets minimum success rate -func (r *BurstResult) Passed() bool { - return r.SuccessRate() >= BurstMinSuccess -} - -// BurstTest runs concurrent DNS queries to test server reliability under load -func BurstTest(ctx context.Context, ip, domain string, port int, timeout time.Duration) *BurstResult { - if port == 0 { - port = 53 - } - addr := fmt.Sprintf("%s:%d", ip, port) - - result := &BurstResult{ - IP: ip, - Queries: BurstQueries, - } - - client := &dns.Client{ - Net: "udp", - Timeout: timeout, +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return false } - - var mu sync.Mutex - var wg sync.WaitGroup - sem := make(chan struct{}, BurstConcurrency) - - start := time.Now() - - for i := 0; i < BurstQueries; i++ { - select { - case <-ctx.Done(): - result.Duration = time.Since(start) - return result - default: + for _, cidr := range privateRanges { + _, network, _ := net.ParseCIDR(cidr) + if network.Contains(ip) { + return true } - - wg.Add(1) - sem <- struct{}{} - - go func() { - defer wg.Done() - defer func() { <-sem }() - - select { - case <-ctx.Done(): - mu.Lock() - result.Failed++ - mu.Unlock() - return - default: - } - - subdomain := randomSlipstreamSubdomain() - m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(subdomain+"."+domain), dns.TypeTXT) - m.RecursionDesired = true - m.SetEdns0(1232, false) - - _, rtt, err := client.Exchange(m, addr) - - mu.Lock() - if err != nil { - result.Failed++ - } else { - result.Successful++ - result.Latencies = append(result.Latencies, rtt) - } - mu.Unlock() - }() } - - wg.Wait() - result.Duration = time.Since(start) - return result -} - -// BurstProgress tracks parallel burst test progress with atomic counters -type BurstProgress struct { - total int64 - tested int64 - passed int64 - enabled bool -} - -func NewBurstProgress(total int, enabled bool) *BurstProgress { - return &BurstProgress{total: int64(total), enabled: enabled} -} - -func (p *BurstProgress) Tested() { atomic.AddInt64(&p.tested, 1) } -func (p *BurstProgress) Passed() { atomic.AddInt64(&p.passed, 1) } -func (p *BurstProgress) Stats() (tested, passed, total int64) { - return atomic.LoadInt64(&p.tested), atomic.LoadInt64(&p.passed), p.total + return false } -// ParallelBurstTest runs burst tests on multiple IPs concurrently -func ParallelBurstTest(ctx context.Context, ips []string, domain string, port int, - timeout time.Duration, workers int) <-chan *BurstResult { - - results := make(chan *BurstResult, workers) - ipChan := make(chan string, len(ips)) - - go func() { - defer close(ipChan) - for _, ip := range ips { - select { - case ipChan <- ip: - case <-ctx.Done(): - return - } - } - }() - - var wg sync.WaitGroup - for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-ctx.Done(): - return - case ip, ok := <-ipChan: - if !ok { - return - } - result := BurstTest(ctx, ip, domain, port, timeout) - select { - case results <- result: - case <-ctx.Done(): - return - } - } - } - }() - } - - go func() { - wg.Wait() - close(results) - }() - - return results +func randomSubdomain() string { + b := make([]byte, 8) + rand.Read(b) + return hex.EncodeToString(b) } diff --git a/scanner_test.go b/scanner_test.go index ecaf5dd..207dadd 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -66,22 +66,22 @@ func TestRandomSubdomain(t *testing.T) { } } -func TestRandomSlipstreamSubdomain(t *testing.T) { - s1 := randomSlipstreamSubdomain() - s2 := randomSlipstreamSubdomain() +func TestRandomBenchmarkSubdomain(t *testing.T) { + s1 := randomBenchmarkSubdomain() + s2 := randomBenchmarkSubdomain() // Base32 encoded 32 bytes = 52 chars if len(s1) != 52 { - t.Errorf("randomSlipstreamSubdomain length = %d, expected 52", len(s1)) + t.Errorf("randomBenchmarkSubdomain length = %d, expected 52", len(s1)) } // Should be different each time if s1 == s2 { - t.Error("randomSlipstreamSubdomain should generate unique values") + t.Error("randomBenchmarkSubdomain should generate unique values") } } -func TestBurstResultSuccessRate(t *testing.T) { +func TestBenchmarkResultSuccessRate(t *testing.T) { tests := []struct { queries int successful int @@ -94,7 +94,7 @@ func TestBurstResultSuccessRate(t *testing.T) { } for _, tt := range tests { - r := &BurstResult{ + r := &BenchmarkResult{ Queries: tt.queries, Successful: tt.successful, } @@ -105,8 +105,8 @@ func TestBurstResultSuccessRate(t *testing.T) { } } -func TestBurstResultQPS(t *testing.T) { - r := &BurstResult{ +func TestBenchmarkResultQPS(t *testing.T) { + r := &BenchmarkResult{ Successful: 10, Duration: time.Second, } @@ -115,14 +115,14 @@ func TestBurstResultQPS(t *testing.T) { } // Zero duration edge case - r2 := &BurstResult{Duration: 0} + r2 := &BenchmarkResult{Duration: 0} if r2.QPS() != 0.0 { t.Error("QPS with zero duration should be 0") } } -func TestBurstResultP50(t *testing.T) { - r := &BurstResult{ +func TestBenchmarkResultP50(t *testing.T) { + r := &BenchmarkResult{ Latencies: []time.Duration{ 10 * time.Millisecond, 20 * time.Millisecond, @@ -138,13 +138,13 @@ func TestBurstResultP50(t *testing.T) { } // Empty latencies - r2 := &BurstResult{} + r2 := &BenchmarkResult{} if r2.P50() != 0 { t.Error("P50 with no latencies should be 0") } } -func TestBurstResultPassed(t *testing.T) { +func TestBenchmarkResultPassed(t *testing.T) { tests := []struct { queries int successful int @@ -157,7 +157,7 @@ func TestBurstResultPassed(t *testing.T) { } for _, tt := range tests { - r := &BurstResult{ + r := &BenchmarkResult{ Queries: tt.queries, Successful: tt.successful, } @@ -173,19 +173,19 @@ func TestProgressStats(t *testing.T) { p.Increment() p.Increment() - p.Found() + p.Success() - scanned, found, total, elapsed := p.Stats() - if scanned != 2 { - t.Errorf("scanned = %d, expected 2", scanned) + stats := p.Stats() + if stats.Processed != 2 { + t.Errorf("Processed = %d, expected 2", stats.Processed) } - if found != 1 { - t.Errorf("found = %d, expected 1", found) + if stats.Success != 1 { + t.Errorf("Success = %d, expected 1", stats.Success) } - if total != 100 { - t.Errorf("total = %d, expected 100", total) + if stats.Total != 100 { + t.Errorf("Total = %d, expected 100", stats.Total) } - if elapsed < 0 { - t.Error("elapsed should not be negative") + if stats.Elapsed < 0 { + t.Error("Elapsed should not be negative") } } From 49927eae4cc503c478ee82d1c40002b1de435b2e Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:24:25 +0100 Subject: [PATCH 08/11] refactor: make components self-contained with progress output --- benchmark.go | 221 +++++++++++++++++++++++++++++++++++------------- e2e_test.go | 136 +++++++++++++----------------- main.go | 234 +++------------------------------------------------ scanner.go | 152 ++++++++++++++++++++++----------- verify.go | 80 ++++++++++++++++-- 5 files changed, 406 insertions(+), 417 deletions(-) diff --git a/benchmark.go b/benchmark.go index 2f10414..777ad07 100644 --- a/benchmark.go +++ b/benchmark.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base32" "fmt" + "io" "sort" "sync" "time" @@ -15,9 +16,9 @@ import ( const ( BenchmarkQueries = 20 BenchmarkConcurrency = 5 - BenchmarkThreshold = 70 + BenchmarkThreshold = 70 BenchmarkSubdomainLen = 32 - EDNSBufferSize = 1232 // matches slipstream UDP payload + EDNSBufferSize = 1232 ) type BenchmarkResult struct { @@ -65,18 +66,164 @@ func (r *BenchmarkResult) Passed() bool { return r.SuccessRate() >= BenchmarkThreshold } -func randomBenchmarkSubdomain() string { - b := make([]byte, BenchmarkSubdomainLen) - rand.Read(b) - return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) +type Benchmarker struct { + domain string + port int + timeout time.Duration + output io.Writer + showProgress bool } -// Benchmark runs concurrent DNS queries to test server reliability under load -func Benchmark(ctx context.Context, ip, domain string, port int, timeout time.Duration) *BenchmarkResult { +func NewBenchmarker(domain string, port int, timeout time.Duration, output io.Writer, showProgress bool) *Benchmarker { if port == 0 { port = 53 } - addr := fmt.Sprintf("%s:%d", ip, port) + return &Benchmarker{ + domain: domain, + port: port, + timeout: timeout, + output: output, + showProgress: showProgress, + } +} + +func (b *Benchmarker) Benchmark(ctx context.Context, ips []string) []*BenchmarkResult { + if len(ips) == 0 { + return nil + } + + prog := NewProgress(len(ips), b.showProgress) + + tickCtx, stopTick := context.WithCancel(ctx) + defer stopTick() + go b.tick(tickCtx, prog) + + var results []*BenchmarkResult + if len(ips) <= 5 { + results = b.sequential(ctx, ips, prog) + } else { + results = b.parallel(ctx, ips, prog) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].QPS() > results[j].QPS() + }) + + b.summary(results, len(ips)) + return results +} + +func (b *Benchmarker) tick(ctx context.Context, prog *Progress) { + if !b.showProgress || b.output == nil { + return + } + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + st := prog.Stats() + fmt.Fprintf(b.output, "\rBenchmarking: %d/%d tested, %d passed ", + st.Processed, st.Total, st.Success) + case <-ctx.Done(): + fmt.Fprint(b.output, "\r\033[K") + return + } + } +} + +func (b *Benchmarker) summary(results []*BenchmarkResult, total int) { + if !b.showProgress || b.output == nil { + return + } + fmt.Fprintf(b.output, "Benchmark: %d/%d passed (sorted by throughput)\n", len(results), total) + for _, r := range results { + color := "\033[33m" + if r.SuccessRate() >= 85 { + color = "\033[32m" + } + fmt.Fprintf(b.output, "%s%-15s %.0f%% (%.1f qps, p50=%v)\033[0m\n", + color, r.IP, r.SuccessRate(), r.QPS(), r.P50().Round(time.Millisecond)) + } +} + +func (b *Benchmarker) sequential(ctx context.Context, ips []string, prog *Progress) []*BenchmarkResult { + var results []*BenchmarkResult + for _, ip := range ips { + select { + case <-ctx.Done(): + return results + default: + } + result := b.benchmark(ctx, ip) + prog.Increment() + if result.Passed() { + prog.Success() + results = append(results, result) + } + } + return results +} + +func (b *Benchmarker) parallel(ctx context.Context, ips []string, prog *Progress) []*BenchmarkResult { + workers := min(len(ips), 10) + + resultChan := make(chan *BenchmarkResult, workers) + ipChan := make(chan string, len(ips)) + + go func() { + defer close(ipChan) + for _, ip := range ips { + select { + case ipChan <- ip: + case <-ctx.Done(): + return + } + } + }() + + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case ip, ok := <-ipChan: + if !ok { + return + } + result := b.benchmark(ctx, ip) + select { + case resultChan <- result: + case <-ctx.Done(): + return + } + } + } + }() + } + + go func() { + wg.Wait() + close(resultChan) + }() + + var results []*BenchmarkResult + for result := range resultChan { + prog.Increment() + if result.Passed() { + prog.Success() + results = append(results, result) + } + } + return results +} + +func (b *Benchmarker) benchmark(ctx context.Context, ip string) *BenchmarkResult { + addr := fmt.Sprintf("%s:%d", ip, b.port) result := &BenchmarkResult{ IP: ip, @@ -85,7 +232,7 @@ func Benchmark(ctx context.Context, ip, domain string, port int, timeout time.Du client := &dns.Client{ Net: "udp", - Timeout: timeout, + Timeout: b.timeout, } var mu sync.Mutex @@ -120,7 +267,7 @@ func Benchmark(ctx context.Context, ip, domain string, port int, timeout time.Du subdomain := randomBenchmarkSubdomain() m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(subdomain+"."+domain), dns.TypeTXT) + m.SetQuestion(dns.Fqdn(subdomain+"."+b.domain), dns.TypeTXT) m.RecursionDesired = true m.SetEdns0(EDNSBufferSize, false) @@ -142,52 +289,8 @@ func Benchmark(ctx context.Context, ip, domain string, port int, timeout time.Du return result } -// BenchmarkParallel runs benchmarks on multiple IPs concurrently -func BenchmarkParallel(ctx context.Context, ips []string, domain string, port int, - timeout time.Duration, workers int) <-chan *BenchmarkResult { - - results := make(chan *BenchmarkResult, workers) - ipChan := make(chan string, len(ips)) - - go func() { - defer close(ipChan) - for _, ip := range ips { - select { - case ipChan <- ip: - case <-ctx.Done(): - return - } - } - }() - - var wg sync.WaitGroup - for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-ctx.Done(): - return - case ip, ok := <-ipChan: - if !ok { - return - } - result := Benchmark(ctx, ip, domain, port, timeout) - select { - case results <- result: - case <-ctx.Done(): - return - } - } - } - }() - } - - go func() { - wg.Wait() - close(results) - }() - - return results +func randomBenchmarkSubdomain() string { + b := make([]byte, BenchmarkSubdomainLen) + rand.Read(b) + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) } diff --git a/e2e_test.go b/e2e_test.go index e6a990a..a85c0a1 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -85,14 +85,14 @@ func TestE2EBasicScan(t *testing.T) { } defer mock.Close() - scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") - result := scanner.Probe(mock.ip) + scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) + working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) - if !result.Working { - t.Errorf("Expected working, got error: %v", result.Error) + if len(working) != 1 { + t.Errorf("Expected 1 working, got %d", len(working)) } - if result.Suspicious { - t.Error("Public IP should not be suspicious") + if suspicious != 0 { + t.Errorf("Public IP should not be suspicious, got %d", suspicious) } } @@ -103,14 +103,14 @@ func TestE2EHijackingDetection(t *testing.T) { } defer mock.Close() - scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") - result := scanner.Probe(mock.ip) + scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) + working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) - if result.Working { - t.Error("Private IP response should not be working") + if len(working) != 0 { + t.Errorf("Private IP response should not be working, got %d", len(working)) } - if !result.Suspicious { - t.Error("Private IP response should be suspicious") + if suspicious != 1 { + t.Errorf("Private IP response should be suspicious, got %d", suspicious) } } @@ -121,11 +121,11 @@ func TestE2EDomainVerification(t *testing.T) { } defer mock.Close() - scanner := NewScanner(1, 2*time.Second, mock.port, nil, "test.example.com") - result := scanner.Probe(mock.ip) + scanner := NewScanner(1, 2*time.Second, mock.port, "test.example.com", 1, nil, false) + working, _ := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) - if !result.Working { - t.Errorf("Domain verification should pass, got: %v", result.Error) + if len(working) != 1 { + t.Errorf("Domain verification should pass, got %d working", len(working)) } } @@ -136,8 +136,14 @@ func TestE2EBenchmark(t *testing.T) { } defer mock.Close() - result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + benchmarker := NewBenchmarker("test.example.com", mock.port, 2*time.Second, nil, false) + results := benchmarker.Benchmark(context.Background(), []string{mock.ip}) + if len(results) != 1 { + t.Fatalf("Expected 1 result, got %d", len(results)) + } + + result := results[0] if result.Queries != BenchmarkQueries { t.Errorf("Expected %d queries, got %d", BenchmarkQueries, result.Queries) } @@ -156,28 +162,16 @@ func TestE2EWorkerPool(t *testing.T) { } defer mock.Close() - ips := []string{mock.ip, "192.0.2.1"} // second IP won't respond - progress := NewProgress(len(ips), false) - scanner := NewScanner(2, 500*time.Millisecond, mock.port, progress, "") + ips := []string{mock.ip, "192.0.2.1"} + scanner := NewScanner(2, 500*time.Millisecond, mock.port, "", len(ips), nil, false) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - results := scanner.Run(ctx, sliceToChannel(ips)) - - var working, total int - for r := range results { - total++ - if r.Working { - working++ - } - } + working, _ := scanner.Scan(ctx, sliceToChannel(ips)) - if total != 2 { - t.Errorf("Expected 2 results, got %d", total) - } - if working != 1 { - t.Errorf("Expected 1 working, got %d", working) + if len(working) != 1 { + t.Errorf("Expected 1 working, got %d", len(working)) } } @@ -190,29 +184,26 @@ func TestE2EAllPrivateRanges(t *testing.T) { t.Fatalf("Failed for %s: %v", ip, err) } - scanner := NewScanner(1, 2*time.Second, mock.port, nil, "") - result := scanner.Probe(mock.ip) + scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) + working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) mock.Close() - if result.Working { + if len(working) != 0 { t.Errorf("%s should not be working", ip) } - if !result.Suspicious { + if suspicious != 1 { t.Errorf("%s should be suspicious", ip) } } } func TestE2ETimeout(t *testing.T) { - scanner := NewScanner(1, 100*time.Millisecond, 65534, nil, "") - result := scanner.Probe("127.0.0.1") + scanner := NewScanner(1, 100*time.Millisecond, 65534, "", 1, nil, false) + working, _ := scanner.Scan(context.Background(), sliceToChannel([]string{"127.0.0.1"})) - if result.Working { + if len(working) != 0 { t.Error("Non-responsive should not be working") } - if result.Error == nil { - t.Error("Expected timeout error") - } } func TestE2EProgressTracking(t *testing.T) { @@ -223,17 +214,9 @@ func TestE2EProgressTracking(t *testing.T) { defer mock.Close() ips := []string{mock.ip, mock.ip, mock.ip} - progress := NewProgress(len(ips), false) - scanner := NewScanner(1, 2*time.Second, mock.port, progress, "") - - results := scanner.Run(context.Background(), sliceToChannel(ips)) - for range results { - } + scanner := NewScanner(1, 2*time.Second, mock.port, "", len(ips), nil, false) - stats := progress.Stats() - if stats.Processed != 3 || stats.Success != 3 || stats.Total != 3 { - t.Errorf("Progress mismatch: processed=%d success=%d total=%d", stats.Processed, stats.Success, stats.Total) - } + scanner.Scan(context.Background(), sliceToChannel(ips)) } func TestE2EBenchmarkQPS(t *testing.T) { @@ -243,8 +226,14 @@ func TestE2EBenchmarkQPS(t *testing.T) { } defer mock.Close() - result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + benchmarker := NewBenchmarker("test.example.com", mock.port, 2*time.Second, nil, false) + results := benchmarker.Benchmark(context.Background(), []string{mock.ip}) + + if len(results) != 1 { + t.Fatalf("Expected 1 result, got %d", len(results)) + } + result := results[0] if result.QPS() <= 0 { t.Errorf("QPS should be positive, got %.2f", result.QPS()) } @@ -253,8 +242,7 @@ func TestE2EBenchmarkQPS(t *testing.T) { } } -func TestE2EBenchmarkParallel(t *testing.T) { - // Start 3 mock servers +func TestE2EBenchmarkMultiple(t *testing.T) { var mocks []*mockDNSServer var ips []string for i := 0; i < 3; i++ { @@ -274,22 +262,11 @@ func TestE2EBenchmarkParallel(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // All mocks use same port pattern, so we use first mock's port - resultChan := BenchmarkParallel(ctx, ips, "test.example.com", mocks[0].port, 2*time.Second, 3) + benchmarker := NewBenchmarker("test.example.com", mocks[0].port, 2*time.Second, nil, false) + results := benchmarker.Benchmark(ctx, ips) - var count, passed int - for r := range resultChan { - count++ - if r.Passed() { - passed++ - } - } - - if count != 3 { - t.Errorf("Expected 3 results, got %d", count) - } - if passed != 3 { - t.Errorf("Expected 3 passed, got %d", passed) + if len(results) != 3 { + t.Errorf("Expected 3 results, got %d", len(results)) } } @@ -300,14 +277,13 @@ func TestE2EBenchmarkContextCancellation(t *testing.T) { } defer mock.Close() - // Cancel immediately ctx, cancel := context.WithCancel(context.Background()) cancel() - result := Benchmark(ctx, mock.ip, "test.example.com", mock.port, 2*time.Second) + benchmarker := NewBenchmarker("test.example.com", mock.port, 2*time.Second, nil, false) + results := benchmarker.Benchmark(ctx, []string{mock.ip}) - // Should return early with partial/no results - if result.Successful == BenchmarkQueries { + if len(results) > 0 && results[0].Successful == BenchmarkQueries { t.Error("Expected early termination with cancelled context") } } @@ -338,8 +314,14 @@ func TestE2EBenchmarkResultMetrics(t *testing.T) { } defer mock.Close() - result := Benchmark(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) + benchmarker := NewBenchmarker("test.example.com", mock.port, 2*time.Second, nil, false) + results := benchmarker.Benchmark(context.Background(), []string{mock.ip}) + + if len(results) != 1 { + t.Fatalf("Expected 1 result, got %d", len(results)) + } + result := results[0] if result.IP != mock.ip { t.Errorf("IP mismatch: got %s, expected %s", result.IP, mock.ip) } diff --git a/main.go b/main.go index a6dada2..14d1498 100644 --- a/main.go +++ b/main.go @@ -6,9 +6,7 @@ import ( "io" "os" "os/signal" - "sort" "syscall" - "time" ) var version = "dev" @@ -38,7 +36,6 @@ func main() { PrintBanner(os.Stderr, cfg, totalIPs, version) - // Setup context with signal handling ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -52,227 +49,22 @@ func main() { cancel() }() - // Create scanner - prog := NewProgress(totalIPs, cfg.Progress) - scanner := NewScanner(cfg.Workers, cfg.Timeout, 53, prog, cfg.Domain) + // Phase 1: Scan + scanner := NewScanner(cfg.Workers, cfg.Timeout, 53, cfg.Domain, totalIPs, os.Stderr, cfg.Progress) + workingDNS, suspiciousCount := scanner.Scan(ctx, ips) + _ = suspiciousCount - // Start progress ticker - var progressDone chan struct{} - if cfg.Progress { - progressDone = make(chan struct{}) - go func() { - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ticker.C: - stats := prog.Stats() - rate := float64(stats.Processed) / stats.Elapsed.Seconds() - pct := float64(stats.Processed) / float64(stats.Total) * 100 - fmt.Fprintf(os.Stderr, "\rScanned: %d/%d (%.1f%%) | Found: %d | %.0f IPs/sec ", - stats.Processed, stats.Total, pct, stats.Success, rate) - case <-ctx.Done(): - return - case <-progressDone: - return - } - } - }() - } - - // Run scanner - results := scanner.Run(ctx, ips) - - // Collect results - var workingDNS []string - var suspiciousCount int -resultLoop: - for { - select { - case <-ctx.Done(): - break resultLoop - case result, ok := <-results: - if !ok { - break resultLoop - } - if result.Suspicious { - suspiciousCount++ - } - if result.Working { - workingDNS = append(workingDNS, result.IP) - } - } - } - - // Stop progress ticker - if progressDone != nil { - close(progressDone) - } - - // Print final stats - if cfg.Progress { - stats := prog.Stats() - fmt.Fprintf(os.Stderr, "\r \r") - fmt.Fprintf(os.Stderr, "Completed: %d IPs in %v\n", stats.Processed, stats.Elapsed.Round(time.Millisecond)) - fmt.Fprintf(os.Stderr, "Found: %d DNS candidates\n", stats.Success) - if suspiciousCount > 0 { - fmt.Fprintf(os.Stderr, "\033[33mWarning: %d servers returned private IPs (possible DNS hijacking)\033[0m\n", suspiciousCount) - } - } - - // Phase 2: Verify with tunnel client if requested + // Phase 2: Verify (optional) if cfg.VerifyBinary != "" && len(workingDNS) > 0 { - verifier := NewSlipstreamVerifier(cfg.VerifyBinary, cfg.Domain, cfg.Timeout) - - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nVerifying %d candidates with %s...\n", len(workingDNS), verifier.Name()) - } - - var verified []string - total := len(workingDNS) - width := len(fmt.Sprintf("%d", total)) - for i, ip := range workingDNS { - select { - case <-ctx.Done(): - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nInterrupted during verification\n") - } - goto verifyDone - default: - } - - if cfg.Progress { - fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) - } - start := time.Now() - if verifier.Verify(ip) { - elapsed := time.Since(start) - verified = append(verified, ip) - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\033[32mOK (%.1fs)\033[0m\n", elapsed.Seconds()) - } - } else { - if cfg.Progress { - fmt.Fprintf(os.Stderr, "FAIL\n") - } - } - } - verifyDone: - - if cfg.Progress { - fmt.Fprintf(os.Stderr, "---\n") - fmt.Fprintf(os.Stderr, "%s: %d/%d passed\n", verifier.Name(), len(verified), len(workingDNS)) - } - workingDNS = verified + verifier := NewSlipstreamVerifier(cfg.VerifyBinary, cfg.Domain, cfg.Timeout, os.Stderr, cfg.Progress) + workingDNS = verifier.Verify(ctx, workingDNS) } - // Phase 3: Benchmark to verify servers handle concurrent load + // Phase 3: Benchmark (optional) var benchResults []*BenchmarkResult if cfg.Domain != "" && len(workingDNS) > 0 { - total := len(workingDNS) - - if total <= 5 { - // Sequential for small lists - nicer per-IP output - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nBenchmarking %d candidates (%d queries, %d%% required)...\n", - total, BenchmarkQueries, BenchmarkThreshold) - } - - width := len(fmt.Sprintf("%d", total)) - for i, ip := range workingDNS { - select { - case <-ctx.Done(): - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nInterrupted during benchmark\n") - } - goto benchDone - default: - } - - if cfg.Progress { - fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) - } - - result := Benchmark(ctx, ip, cfg.Domain, 53, cfg.Timeout) - - if result.Passed() { - benchResults = append(benchResults, result) - if cfg.Progress { - color := "\033[33m" - if result.SuccessRate() >= 85 { - color = "\033[32m" - } - fmt.Fprintf(os.Stderr, "%sOK %.0f%% (%.1f qps, p50=%v)\033[0m\n", - color, result.SuccessRate(), result.QPS(), result.P50().Round(time.Millisecond)) - } - } else { - if cfg.Progress { - fmt.Fprintf(os.Stderr, "FAIL %.0f%%\n", result.SuccessRate()) - } - } - } - } else { - // Parallel for larger lists - benchWorkers := min(total, 10) - if cfg.Progress { - fmt.Fprintf(os.Stderr, "\nBenchmarking %d candidates in parallel (%d workers)...\n", - total, benchWorkers) - } - - benchProg := NewProgress(total, cfg.Progress) - var progressDone chan struct{} - - if cfg.Progress { - progressDone = make(chan struct{}) - go func() { - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ticker.C: - stats := benchProg.Stats() - fmt.Fprintf(os.Stderr, "\rBenchmarking: %d/%d tested, %d passed ", stats.Processed, stats.Total, stats.Success) - case <-ctx.Done(): - return - case <-progressDone: - return - } - } - }() - } - - resultChan := BenchmarkParallel(ctx, workingDNS, cfg.Domain, 53, cfg.Timeout, benchWorkers) - for result := range resultChan { - benchProg.Increment() - if result.Passed() { - benchProg.Success() - benchResults = append(benchResults, result) - } - } - - if progressDone != nil { - close(progressDone) - fmt.Fprintf(os.Stderr, "\r \r") - } - } - benchDone: - - sort.Slice(benchResults, func(i, j int) bool { - return benchResults[i].QPS() > benchResults[j].QPS() - }) - - if cfg.Progress { - fmt.Fprintf(os.Stderr, "---\n") - fmt.Fprintf(os.Stderr, "Benchmark: %d/%d passed (sorted by throughput)\n", len(benchResults), len(workingDNS)) - for _, r := range benchResults { - color := "\033[33m" - if r.SuccessRate() >= 85 { - color = "\033[32m" - } - fmt.Fprintf(os.Stderr, "%s%-15s OK %.0f%% (%.1f qps, p50=%v)\033[0m\n", - color, r.IP, r.SuccessRate(), r.QPS(), r.P50().Round(time.Millisecond)) - } - } + benchmarker := NewBenchmarker(cfg.Domain, 53, cfg.Timeout, os.Stderr, cfg.Progress) + benchResults = benchmarker.Benchmark(ctx, workingDNS) workingDNS = nil for _, r := range benchResults { @@ -280,8 +72,7 @@ resultLoop: } } - finalStats := prog.Stats() - + // Output results var serverResults []ServerResult if len(benchResults) > 0 { for _, r := range benchResults { @@ -299,9 +90,8 @@ resultLoop: } outputStats := ScanStats{ - TotalScanned: finalStats.Processed, + TotalScanned: int64(totalIPs), Found: int64(len(serverResults)), - Duration: finalStats.Elapsed, } if cfg.InputFile == "" { outputStats.Country = cfg.Country diff --git a/scanner.go b/scanner.go index 38c9a63..3043f28 100644 --- a/scanner.go +++ b/scanner.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "io" "net" "sync" "time" @@ -14,25 +15,25 @@ import ( const probeDomain = "google.com" -// ScanResult holds the outcome of probing a single DNS server. type ScanResult struct { IP string Working bool - Suspicious bool // true if server returned private IP (possible hijacking) + Suspicious bool RTT time.Duration Error error } -// Scanner probes DNS servers for availability and hijacking detection. type Scanner struct { workers int timeout time.Duration port int - progress *Progress verifyDomain string + total int + output io.Writer + showProgress bool } -func NewScanner(workers int, timeout time.Duration, port int, progress *Progress, verifyDomain string) *Scanner { +func NewScanner(workers int, timeout time.Duration, port int, verifyDomain string, total int, output io.Writer, showProgress bool) *Scanner { if port == 0 { port = 53 } @@ -40,13 +41,106 @@ func NewScanner(workers int, timeout time.Duration, port int, progress *Progress workers: workers, timeout: timeout, port: port, - progress: progress, verifyDomain: verifyDomain, + total: total, + output: output, + showProgress: showProgress, } } -// Probe tests a single IP for DNS availability. -func (s *Scanner) Probe(ip string) ScanResult { +func (s *Scanner) Scan(ctx context.Context, ips <-chan string) (working []string, suspicious int) { + prog := NewProgress(s.total, s.showProgress) + + tickCtx, stopTick := context.WithCancel(ctx) + defer stopTick() + go s.tick(tickCtx, prog) + + for result := range s.run(ctx, ips, prog) { + if result.Suspicious { + suspicious++ + } + if result.Working { + working = append(working, result.IP) + } + } + + s.summary(prog, suspicious) + return +} + +func (s *Scanner) tick(ctx context.Context, prog *Progress) { + if !s.showProgress || s.output == nil { + return + } + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + st := prog.Stats() + rate := float64(st.Processed) / st.Elapsed.Seconds() + pct := float64(st.Processed) / float64(st.Total) * 100 + fmt.Fprintf(s.output, "\rScanned: %d/%d (%.1f%%) | Found: %d | %.0f IPs/sec ", + st.Processed, st.Total, pct, st.Success, rate) + case <-ctx.Done(): + fmt.Fprint(s.output, "\r\033[K") + return + } + } +} + +func (s *Scanner) summary(prog *Progress, suspicious int) { + if !s.showProgress || s.output == nil { + return + } + st := prog.Stats() + fmt.Fprintf(s.output, "Completed: %d IPs in %v\n", st.Processed, st.Elapsed.Round(time.Millisecond)) + fmt.Fprintf(s.output, "Found: %d DNS candidates\n", st.Success) + if suspicious > 0 { + fmt.Fprintf(s.output, "\033[33mWarning: %d servers returned private IPs (possible DNS hijacking)\033[0m\n", suspicious) + } +} + +func (s *Scanner) run(ctx context.Context, ips <-chan string, prog *Progress) <-chan ScanResult { + results := make(chan ScanResult, s.workers) + var wg sync.WaitGroup + + for i := 0; i < s.workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case ip, ok := <-ips: + if !ok { + return + } + result := s.probe(ip) + prog.Increment() + if result.Working { + prog.Success() + } + select { + case results <- result: + case <-ctx.Done(): + return + } + } + } + }() + } + + go func() { + wg.Wait() + close(results) + }() + + return results +} + +func (s *Scanner) probe(ip string) ScanResult { client := &dns.Client{ Net: "udp", Timeout: s.timeout, @@ -97,48 +191,6 @@ func (s *Scanner) Probe(ip string) ScanResult { return ScanResult{IP: ip, Working: true, RTT: rtt} } -// Run starts a worker pool that probes IPs concurrently. -func (s *Scanner) Run(ctx context.Context, ips <-chan string) <-chan ScanResult { - results := make(chan ScanResult, s.workers) - var wg sync.WaitGroup - - for i := 0; i < s.workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-ctx.Done(): - return - case ip, ok := <-ips: - if !ok { - return - } - result := s.Probe(ip) - if s.progress != nil { - s.progress.Increment() - if result.Working { - s.progress.Success() - } - } - select { - case results <- result: - case <-ctx.Done(): - return - } - } - } - }() - } - - go func() { - wg.Wait() - close(results) - }() - - return results -} - var privateRanges = []string{ "10.0.0.0/8", "172.16.0.0/12", diff --git a/verify.go b/verify.go index 99bf398..60c3b9d 100644 --- a/verify.go +++ b/verify.go @@ -3,27 +3,33 @@ package main import ( "bytes" "context" + "fmt" + "io" "os/exec" "strings" "time" ) type Verifier interface { - Verify(ip string) bool + Verify(ctx context.Context, ips []string) []string Name() string } type SlipstreamVerifier struct { - clientPath string - domain string - timeout time.Duration + clientPath string + domain string + timeout time.Duration + output io.Writer + showProgress bool } -func NewSlipstreamVerifier(clientPath, domain string, timeout time.Duration) *SlipstreamVerifier { +func NewSlipstreamVerifier(clientPath, domain string, timeout time.Duration, output io.Writer, showProgress bool) *SlipstreamVerifier { return &SlipstreamVerifier{ - clientPath: clientPath, - domain: domain, - timeout: timeout, + clientPath: clientPath, + domain: domain, + timeout: timeout, + output: output, + showProgress: showProgress, } } @@ -31,7 +37,63 @@ func (v *SlipstreamVerifier) Name() string { return "slipstream" } -func (v *SlipstreamVerifier) Verify(ip string) bool { +func (v *SlipstreamVerifier) Verify(ctx context.Context, ips []string) []string { + if len(ips) == 0 { + return nil + } + + prog := NewProgress(len(ips), v.showProgress) + + tickCtx, stopTick := context.WithCancel(ctx) + defer stopTick() + go v.tick(tickCtx, prog) + + var verified []string + for _, ip := range ips { + select { + case <-ctx.Done(): + v.summary(prog, len(verified), len(ips)) + return verified + default: + } + prog.Increment() + if v.testIP(ip) { + prog.Success() + verified = append(verified, ip) + } + } + + v.summary(prog, len(verified), len(ips)) + return verified +} + +func (v *SlipstreamVerifier) tick(ctx context.Context, prog *Progress) { + if !v.showProgress || v.output == nil { + return + } + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + st := prog.Stats() + fmt.Fprintf(v.output, "\rVerifying: %d/%d tested, %d passed ", + st.Processed, st.Total, st.Success) + case <-ctx.Done(): + fmt.Fprint(v.output, "\r\033[K") + return + } + } +} + +func (v *SlipstreamVerifier) summary(prog *Progress, passed, total int) { + if !v.showProgress || v.output == nil { + return + } + fmt.Fprintf(v.output, "%s: %d/%d passed\n", v.Name(), passed, total) +} + +func (v *SlipstreamVerifier) testIP(ip string) bool { ctx, cancel := context.WithTimeout(context.Background(), v.timeout*3) defer cancel() From 826ea60a3ad83104d22f1f229e1beaecc5475935 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:37:52 +0100 Subject: [PATCH 09/11] refactor: inline app.go, simplify Scan return, unexport helpers --- app.go | 51 ----------------------------------------------- e2e_test.go | 21 ++++++-------------- main.go | 57 +++++++++++++++++++++++++++++++++++------------------ output.go | 4 ++-- scanner.go | 6 ++++-- 5 files changed, 50 insertions(+), 89 deletions(-) delete mode 100644 app.go diff --git a/app.go b/app.go deleted file mode 100644 index a183c6c..0000000 --- a/app.go +++ /dev/null @@ -1,51 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" -) - -type App struct { - cfg *Config - source IPSource - outFile io.Writer -} - -func NewApp(cfg *Config) (*App, error) { - source, err := newIPSource(cfg) - if err != nil { - return nil, err - } - - app := &App{ - cfg: cfg, - source: source, - } - - if cfg.OutputFile != "" { - f, err := os.Create(cfg.OutputFile) - if err != nil { - return nil, fmt.Errorf("failed to create output file: %w", err) - } - app.outFile = f - } - - return app, nil -} - -func newIPSource(cfg *Config) (IPSource, error) { - if cfg.InputFile != "" { - return NewFileSource(cfg.InputFile) - } - if cfg.Mode == "list" { - return NewDNSListSource(cfg.DataDir, cfg.Country) - } - return NewCIDRSource(cfg.DataDir, cfg.Country, cfg.Mode) -} - -func (a *App) Close() { - if f, ok := a.outFile.(*os.File); ok && f != nil { - f.Close() - } -} diff --git a/e2e_test.go b/e2e_test.go index a85c0a1..8ccd18c 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -86,14 +86,11 @@ func TestE2EBasicScan(t *testing.T) { defer mock.Close() scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) - working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) + working := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) if len(working) != 1 { t.Errorf("Expected 1 working, got %d", len(working)) } - if suspicious != 0 { - t.Errorf("Public IP should not be suspicious, got %d", suspicious) - } } func TestE2EHijackingDetection(t *testing.T) { @@ -104,14 +101,11 @@ func TestE2EHijackingDetection(t *testing.T) { defer mock.Close() scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) - working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) + working := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) if len(working) != 0 { t.Errorf("Private IP response should not be working, got %d", len(working)) } - if suspicious != 1 { - t.Errorf("Private IP response should be suspicious, got %d", suspicious) - } } func TestE2EDomainVerification(t *testing.T) { @@ -122,7 +116,7 @@ func TestE2EDomainVerification(t *testing.T) { defer mock.Close() scanner := NewScanner(1, 2*time.Second, mock.port, "test.example.com", 1, nil, false) - working, _ := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) + working := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) if len(working) != 1 { t.Errorf("Domain verification should pass, got %d working", len(working)) @@ -168,7 +162,7 @@ func TestE2EWorkerPool(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - working, _ := scanner.Scan(ctx, sliceToChannel(ips)) + working := scanner.Scan(ctx, sliceToChannel(ips)) if len(working) != 1 { t.Errorf("Expected 1 working, got %d", len(working)) @@ -185,21 +179,18 @@ func TestE2EAllPrivateRanges(t *testing.T) { } scanner := NewScanner(1, 2*time.Second, mock.port, "", 1, nil, false) - working, suspicious := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) + working := scanner.Scan(context.Background(), sliceToChannel([]string{mock.ip})) mock.Close() if len(working) != 0 { t.Errorf("%s should not be working", ip) } - if suspicious != 1 { - t.Errorf("%s should be suspicious", ip) - } } } func TestE2ETimeout(t *testing.T) { scanner := NewScanner(1, 100*time.Millisecond, 65534, "", 1, nil, false) - working, _ := scanner.Scan(context.Background(), sliceToChannel([]string{"127.0.0.1"})) + working := scanner.Scan(context.Background(), sliceToChannel([]string{"127.0.0.1"})) if len(working) != 0 { t.Error("Non-responsive should not be working") diff --git a/main.go b/main.go index 14d1498..55046ce 100644 --- a/main.go +++ b/main.go @@ -24,17 +24,26 @@ func main() { os.Exit(1) } - app, err := NewApp(cfg) + source, err := newIPSource(cfg) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - defer app.Close() - ips := app.source.IPs() - totalIPs := app.source.Count() + var outFile *os.File + if cfg.OutputFile != "" { + outFile, err = os.Create(cfg.OutputFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + defer outFile.Close() + } + + ips := source.IPs() + totalIPs := source.Count() - PrintBanner(os.Stderr, cfg, totalIPs, version) + printBanner(os.Stderr, cfg, totalIPs, version) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -49,18 +58,14 @@ func main() { cancel() }() - // Phase 1: Scan scanner := NewScanner(cfg.Workers, cfg.Timeout, 53, cfg.Domain, totalIPs, os.Stderr, cfg.Progress) - workingDNS, suspiciousCount := scanner.Scan(ctx, ips) - _ = suspiciousCount + workingDNS := scanner.Scan(ctx, ips) - // Phase 2: Verify (optional) if cfg.VerifyBinary != "" && len(workingDNS) > 0 { verifier := NewSlipstreamVerifier(cfg.VerifyBinary, cfg.Domain, cfg.Timeout, os.Stderr, cfg.Progress) workingDNS = verifier.Verify(ctx, workingDNS) } - // Phase 3: Benchmark (optional) var benchResults []*BenchmarkResult if cfg.Domain != "" && len(workingDNS) > 0 { benchmarker := NewBenchmarker(cfg.Domain, 53, cfg.Timeout, os.Stderr, cfg.Progress) @@ -73,6 +78,20 @@ func main() { } // Output results + writeResults(cfg, outFile, workingDNS, benchResults, totalIPs) +} + +func newIPSource(cfg *Config) (IPSource, error) { + if cfg.InputFile != "" { + return NewFileSource(cfg.InputFile) + } + if cfg.Mode == "list" { + return NewDNSListSource(cfg.DataDir, cfg.Country) + } + return NewCIDRSource(cfg.DataDir, cfg.Country, cfg.Mode) +} + +func writeResults(cfg *Config, outFile *os.File, workingDNS []string, benchResults []*BenchmarkResult, totalIPs int) { var serverResults []ServerResult if len(benchResults) > 0 { for _, r := range benchResults { @@ -89,28 +108,28 @@ func main() { } } - outputStats := ScanStats{ + stats := ScanStats{ TotalScanned: int64(totalIPs), Found: int64(len(serverResults)), } if cfg.InputFile == "" { - outputStats.Country = cfg.Country - outputStats.Mode = cfg.Mode + stats.Country = cfg.Country + stats.Mode = cfg.Mode } var out io.Writer = os.Stdout - if app.outFile != nil { - out = app.outFile + if outFile != nil { + out = outFile } if cfg.JSONOutput { - NewJSONWriter(out).Write(serverResults, outputStats) + NewJSONWriter(out).Write(serverResults, stats) } else { - if app.outFile != nil || !cfg.Progress { - NewTextWriter(out).Write(serverResults, outputStats) + if outFile != nil || !cfg.Progress { + NewTextWriter(out).Write(serverResults, stats) } if cfg.Progress { - PrintUsageHint(os.Stderr, workingDNS, cfg.Domain) + printUsageHint(os.Stderr, workingDNS, cfg.Domain) } } } diff --git a/output.go b/output.go index cdb3178..464876a 100644 --- a/output.go +++ b/output.go @@ -103,7 +103,7 @@ func (t *TextWriter) Write(results []ServerResult, stats ScanStats) error { return nil } -func PrintBanner(w io.Writer, cfg *Config, totalIPs int, version string) { +func printBanner(w io.Writer, cfg *Config, totalIPs int, version string) { if !cfg.Progress { fmt.Fprintf(w, "Scanning %d IPs...\n", totalIPs) return @@ -124,7 +124,7 @@ func PrintBanner(w io.Writer, cfg *Config, totalIPs int, version string) { fmt.Fprintf(w, "---\n") } -func PrintUsageHint(w io.Writer, ips []string, domain string) { +func printUsageHint(w io.Writer, ips []string, domain string) { if len(ips) == 0 { return } diff --git a/scanner.go b/scanner.go index 3043f28..ba9eff0 100644 --- a/scanner.go +++ b/scanner.go @@ -48,13 +48,15 @@ func NewScanner(workers int, timeout time.Duration, port int, verifyDomain strin } } -func (s *Scanner) Scan(ctx context.Context, ips <-chan string) (working []string, suspicious int) { +func (s *Scanner) Scan(ctx context.Context, ips <-chan string) []string { prog := NewProgress(s.total, s.showProgress) tickCtx, stopTick := context.WithCancel(ctx) defer stopTick() go s.tick(tickCtx, prog) + var working []string + var suspicious int for result := range s.run(ctx, ips, prog) { if result.Suspicious { suspicious++ @@ -65,7 +67,7 @@ func (s *Scanner) Scan(ctx context.Context, ips <-chan string) (working []string } s.summary(prog, suspicious) - return + return working } func (s *Scanner) tick(ctx context.Context, prog *Progress) { From dd897726dfa3108ff93dafdd3f4ca748a490cdf8 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:57:35 +0100 Subject: [PATCH 10/11] style: consistent bold green pipe-separated progress summaries --- benchmark.go | 6 +++--- scanner.go | 8 ++++---- verify.go | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/benchmark.go b/benchmark.go index 777ad07..04e7811 100644 --- a/benchmark.go +++ b/benchmark.go @@ -95,7 +95,6 @@ func (b *Benchmarker) Benchmark(ctx context.Context, ips []string) []*BenchmarkR prog := NewProgress(len(ips), b.showProgress) tickCtx, stopTick := context.WithCancel(ctx) - defer stopTick() go b.tick(tickCtx, prog) var results []*BenchmarkResult @@ -109,6 +108,7 @@ func (b *Benchmarker) Benchmark(ctx context.Context, ips []string) []*BenchmarkR return results[i].QPS() > results[j].QPS() }) + stopTick() b.summary(results, len(ips)) return results } @@ -126,7 +126,6 @@ func (b *Benchmarker) tick(ctx context.Context, prog *Progress) { fmt.Fprintf(b.output, "\rBenchmarking: %d/%d tested, %d passed ", st.Processed, st.Total, st.Success) case <-ctx.Done(): - fmt.Fprint(b.output, "\r\033[K") return } } @@ -136,7 +135,8 @@ func (b *Benchmarker) summary(results []*BenchmarkResult, total int) { if !b.showProgress || b.output == nil { return } - fmt.Fprintf(b.output, "Benchmark: %d/%d passed (sorted by throughput)\n", len(results), total) + fmt.Fprintf(b.output, "\r\033[1;32mBenchmark: %d/%d | Passed: %d | sorted by QPS\033[0m \n", total, total, len(results)) + fmt.Fprintln(b.output, "---") for _, r := range results { color := "\033[33m" if r.SuccessRate() >= 85 { diff --git a/scanner.go b/scanner.go index ba9eff0..ac88d45 100644 --- a/scanner.go +++ b/scanner.go @@ -52,7 +52,6 @@ func (s *Scanner) Scan(ctx context.Context, ips <-chan string) []string { prog := NewProgress(s.total, s.showProgress) tickCtx, stopTick := context.WithCancel(ctx) - defer stopTick() go s.tick(tickCtx, prog) var working []string @@ -66,6 +65,7 @@ func (s *Scanner) Scan(ctx context.Context, ips <-chan string) []string { } } + stopTick() s.summary(prog, suspicious) return working } @@ -85,7 +85,6 @@ func (s *Scanner) tick(ctx context.Context, prog *Progress) { fmt.Fprintf(s.output, "\rScanned: %d/%d (%.1f%%) | Found: %d | %.0f IPs/sec ", st.Processed, st.Total, pct, st.Success, rate) case <-ctx.Done(): - fmt.Fprint(s.output, "\r\033[K") return } } @@ -96,8 +95,9 @@ func (s *Scanner) summary(prog *Progress, suspicious int) { return } st := prog.Stats() - fmt.Fprintf(s.output, "Completed: %d IPs in %v\n", st.Processed, st.Elapsed.Round(time.Millisecond)) - fmt.Fprintf(s.output, "Found: %d DNS candidates\n", st.Success) + rate := float64(st.Processed) / st.Elapsed.Seconds() + fmt.Fprintf(s.output, "\r\033[1;32mScan: %d/%d | Found: %d | %.0f IPs/sec | %v\033[0m \n", + st.Processed, st.Total, st.Success, rate, st.Elapsed.Round(time.Millisecond)) if suspicious > 0 { fmt.Fprintf(s.output, "\033[33mWarning: %d servers returned private IPs (possible DNS hijacking)\033[0m\n", suspicious) } diff --git a/verify.go b/verify.go index 60c3b9d..912b860 100644 --- a/verify.go +++ b/verify.go @@ -45,13 +45,13 @@ func (v *SlipstreamVerifier) Verify(ctx context.Context, ips []string) []string prog := NewProgress(len(ips), v.showProgress) tickCtx, stopTick := context.WithCancel(ctx) - defer stopTick() go v.tick(tickCtx, prog) var verified []string for _, ip := range ips { select { case <-ctx.Done(): + stopTick() v.summary(prog, len(verified), len(ips)) return verified default: @@ -63,6 +63,7 @@ func (v *SlipstreamVerifier) Verify(ctx context.Context, ips []string) []string } } + stopTick() v.summary(prog, len(verified), len(ips)) return verified } @@ -80,7 +81,6 @@ func (v *SlipstreamVerifier) tick(ctx context.Context, prog *Progress) { fmt.Fprintf(v.output, "\rVerifying: %d/%d tested, %d passed ", st.Processed, st.Total, st.Success) case <-ctx.Done(): - fmt.Fprint(v.output, "\r\033[K") return } } @@ -90,7 +90,7 @@ func (v *SlipstreamVerifier) summary(prog *Progress, passed, total int) { if !v.showProgress || v.output == nil { return } - fmt.Fprintf(v.output, "%s: %d/%d passed\n", v.Name(), passed, total) + fmt.Fprintf(v.output, "\r\033[1;32mVerify: %d/%d | Passed: %d | %s\033[0m \n", total, total, passed, v.Name()) } func (v *SlipstreamVerifier) testIP(ip string) bool { From a149ba7a1a360d0bd06c0d4a12833b1e6f76db66 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:45:38 +0100 Subject: [PATCH 11/11] chore: remove redundant comment --- main.go | 1 - 1 file changed, 1 deletion(-) diff --git a/main.go b/main.go index 55046ce..b9ed6bd 100644 --- a/main.go +++ b/main.go @@ -77,7 +77,6 @@ func main() { } } - // Output results writeResults(cfg, outFile, workingDNS, benchResults, totalIPs) }