diff --git a/conn.go b/conn.go index ef4a78bd..9990050f 100644 --- a/conn.go +++ b/conn.go @@ -451,12 +451,36 @@ func NewDialer(value *Dialer) *Dialer { return value.SetDefaults() } +// filterLocalAddrsByFamily filters localIPs to only include addresses matching the address family of targetIP. +// If targetIP is nil or no local IPs are provided, we return the original list. +// If no local IPs match the target's address family, we return an empty list. +func filterLocalAddrsByFamily(localIPs []net.IP, targetIP net.IP) []net.IP { + if targetIP == nil || len(localIPs) == 0 { + return localIPs + } + targetIsIPv4 := targetIP.To4() != nil + filtered := make([]net.IP, 0, len(localIPs)) + for _, ip := range localIPs { + ipIsIPv4 := ip.To4() != nil + if ipIsIPv4 == targetIsIPv4 { + filtered = append(filtered, ip) + } + } + return filtered +} + // SetRandomLocalAddr sets a random local address and port for the dialer. If either localIPs or localPorts are empty, // the IP or port, respectively, will be un-set and the system will choose. -func (d *Dialer) SetRandomLocalAddr(network string, localIPs []net.IP, localPorts []uint16) error { +// If targetIP is non-nil, localIPs are filtered to match the target's address family (IPv4 or IPv6) to prevent +// protocol mismatch errors when both IPv4 and IPv6 local addresses are configured. +func (d *Dialer) SetRandomLocalAddr(network string, localIPs []net.IP, localPorts []uint16, targetIP net.IP) error { var localIP net.IP if len(localIPs) != 0 { - localIP = localIPs[rand.Intn(len(localIPs))] + candidates := filterLocalAddrsByFamily(localIPs, targetIP) + if len(candidates) == 0 { + return fmt.Errorf("no selected local IPs %v match the address family of the target IP %s, so target would not be reachable", localIPs, targetIP.String()) + } + localIP = candidates[rand.Intn(len(candidates))] } var localPort int if len(localPorts) != 0 { diff --git a/conn_localaddr_test.go b/conn_localaddr_test.go new file mode 100644 index 00000000..4440f2c1 --- /dev/null +++ b/conn_localaddr_test.go @@ -0,0 +1,72 @@ +package zgrab2 + +import ( + "net" + "testing" +) + +var ( + testIPv4a = net.ParseIP("192.168.1.1") + testIPv4b = net.ParseIP("192.168.1.2") + testIPv6a = net.ParseIP("2001:db8::1") + testIPv6b = net.ParseIP("2001:db8::2") + testTargetIPv4 = net.ParseIP("192.168.1.100") + testTargetIPv6 = net.ParseIP("2001:db8::100") +) + +func TestFilterLocalAddrsByFamily_IPv4Target(t *testing.T) { + result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a, testIPv4b, testIPv6b}, testTargetIPv4) + if len(result) != 2 { + t.Fatalf("expected 2 IPv4 addresses, got %d: %v", len(result), result) + } + for _, ip := range result { + if ip.To4() == nil { + t.Errorf("expected only IPv4 addresses, got %s", ip) + } + } +} + +func TestFilterLocalAddrsByFamily_IPv6Target(t *testing.T) { + result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a, testIPv6b}, testTargetIPv6) + if len(result) != 2 { + t.Fatalf("expected 2 IPv6 addresses, got %d: %v", len(result), result) + } + for _, ip := range result { + if ip.To4() != nil { + t.Errorf("expected only IPv6 addresses, got %s", ip) + } + } +} + +func TestFilterLocalAddrsByFamily_NilTarget(t *testing.T) { + result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a}, nil) + if len(result) != 2 { + t.Fatalf("expected all addresses returned when target is nil, got %d", len(result)) + } +} + +func TestFilterLocalAddrsByFamily_NoMatchFallsBack(t *testing.T) { + result := filterLocalAddrsByFamily([]net.IP{testIPv6a, testIPv6b}, testTargetIPv4) + if len(result) != 0 { + t.Fatalf("expected empty list, got %d: %v", len(result), result) + } +} + +func TestFilterLocalAddrsByFamily_EmptyInput(t *testing.T) { + result := filterLocalAddrsByFamily(nil, testTargetIPv4) + if len(result) != 0 { + t.Fatalf("expected empty result for nil input, got %d", len(result)) + } + + result = filterLocalAddrsByFamily([]net.IP{}, testTargetIPv4) + if len(result) != 0 { + t.Fatalf("expected empty result for empty input, got %d", len(result)) + } +} + +func TestFilterLocalAddrsByFamily_AllSameFamily(t *testing.T) { + result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv4b}, testTargetIPv4) + if len(result) != 2 { + t.Fatalf("expected 2 addresses when all match, got %d", len(result)) + } +} diff --git a/processing.go b/processing.go index ca62f0b2..7fa5445f 100644 --- a/processing.go +++ b/processing.go @@ -90,7 +90,7 @@ func GetDefaultTCPDialer(flags *BaseFlags) func(ctx context.Context, t *ScanTarg } } } - err := dialer.SetRandomLocalAddr("tcp", config.localAddrs, config.localPorts) + err := dialer.SetRandomLocalAddr("tcp", config.localAddrs, config.localPorts, t.IP) if err != nil { return nil, fmt.Errorf("could not set random local address: %w", err) } @@ -155,7 +155,7 @@ func GetDefaultUDPDialer(flags *BaseFlags) func(ctx context.Context, t *ScanTarg // create dialer once and reuse it return func(ctx context.Context, t *ScanTarget, addr string) (net.Conn, error) { dialer := GetTimeoutConnectionDialer(flags.ConnectTimeout, flags.TargetTimeout) - err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts) + err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts, t.IP) if err != nil { return nil, fmt.Errorf("could not set random local address: %w", err) } @@ -216,7 +216,7 @@ func grabTarget(ctx context.Context, input ScanTarget, m *Monitor) *Grab { } // resolve the target's IP here once, so it doesn't need to be resolved in each module dialer := NewDialer(nil) - err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts) + err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts, nil) if err != nil { return onResolutionFailure(input, m, fmt.Errorf("could not set random local address: %w", err)) }