Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,36 @@ func NewDialer(value *Dialer) *Dialer {
return value.SetDefaults()
}

// filterLocalAddrsByFamily filters localIPs to only include addresses matching the address family of targetIP.
// If targetIP is nil or no local IPs are provided, we return the original list.
// If no local IPs match the target's address family, we return an empty list.
func filterLocalAddrsByFamily(localIPs []net.IP, targetIP net.IP) []net.IP {
if targetIP == nil || len(localIPs) == 0 {
return localIPs
}
targetIsIPv4 := targetIP.To4() != nil
filtered := make([]net.IP, 0, len(localIPs))
for _, ip := range localIPs {
ipIsIPv4 := ip.To4() != nil
if ipIsIPv4 == targetIsIPv4 {
filtered = append(filtered, ip)
}
}
return filtered
}

// SetRandomLocalAddr sets a random local address and port for the dialer. If either localIPs or localPorts are empty,
// the IP or port, respectively, will be un-set and the system will choose.
func (d *Dialer) SetRandomLocalAddr(network string, localIPs []net.IP, localPorts []uint16) error {
// If targetIP is non-nil, localIPs are filtered to match the target's address family (IPv4 or IPv6) to prevent
// protocol mismatch errors when both IPv4 and IPv6 local addresses are configured.
func (d *Dialer) SetRandomLocalAddr(network string, localIPs []net.IP, localPorts []uint16, targetIP net.IP) error {
var localIP net.IP
if len(localIPs) != 0 {
localIP = localIPs[rand.Intn(len(localIPs))]
candidates := filterLocalAddrsByFamily(localIPs, targetIP)
if len(candidates) == 0 {
return fmt.Errorf("no selected local IPs %v match the address family of the target IP %s, so target would not be reachable", localIPs, targetIP.String())
}
localIP = candidates[rand.Intn(len(candidates))]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rand.Intn will panic of called with n <= 0. I'd suggest we add a check to be sure the list's size is > 0. (Note - this is really only an issue when combined with my above suggestion, since as you've written it I don't think we'd hit this case)

}
var localPort int
if len(localPorts) != 0 {
Expand Down
72 changes: 72 additions & 0 deletions conn_localaddr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package zgrab2

import (
"net"
"testing"
)

var (
testIPv4a = net.ParseIP("192.168.1.1")
testIPv4b = net.ParseIP("192.168.1.2")
testIPv6a = net.ParseIP("2001:db8::1")
testIPv6b = net.ParseIP("2001:db8::2")
testTargetIPv4 = net.ParseIP("192.168.1.100")
testTargetIPv6 = net.ParseIP("2001:db8::100")
)

func TestFilterLocalAddrsByFamily_IPv4Target(t *testing.T) {
result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a, testIPv4b, testIPv6b}, testTargetIPv4)
if len(result) != 2 {
t.Fatalf("expected 2 IPv4 addresses, got %d: %v", len(result), result)
}
for _, ip := range result {
if ip.To4() == nil {
t.Errorf("expected only IPv4 addresses, got %s", ip)
}
}
}

func TestFilterLocalAddrsByFamily_IPv6Target(t *testing.T) {
result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a, testIPv6b}, testTargetIPv6)
if len(result) != 2 {
t.Fatalf("expected 2 IPv6 addresses, got %d: %v", len(result), result)
}
for _, ip := range result {
if ip.To4() != nil {
t.Errorf("expected only IPv6 addresses, got %s", ip)
}
}
}

func TestFilterLocalAddrsByFamily_NilTarget(t *testing.T) {
result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv6a}, nil)
if len(result) != 2 {
t.Fatalf("expected all addresses returned when target is nil, got %d", len(result))
}
}

func TestFilterLocalAddrsByFamily_NoMatchFallsBack(t *testing.T) {
result := filterLocalAddrsByFamily([]net.IP{testIPv6a, testIPv6b}, testTargetIPv4)
if len(result) != 0 {
t.Fatalf("expected empty list, got %d: %v", len(result), result)
}
}

func TestFilterLocalAddrsByFamily_EmptyInput(t *testing.T) {
result := filterLocalAddrsByFamily(nil, testTargetIPv4)
if len(result) != 0 {
t.Fatalf("expected empty result for nil input, got %d", len(result))
}

result = filterLocalAddrsByFamily([]net.IP{}, testTargetIPv4)
if len(result) != 0 {
t.Fatalf("expected empty result for empty input, got %d", len(result))
}
}

func TestFilterLocalAddrsByFamily_AllSameFamily(t *testing.T) {
result := filterLocalAddrsByFamily([]net.IP{testIPv4a, testIPv4b}, testTargetIPv4)
if len(result) != 2 {
t.Fatalf("expected 2 addresses when all match, got %d", len(result))
}
}
6 changes: 3 additions & 3 deletions processing.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func GetDefaultTCPDialer(flags *BaseFlags) func(ctx context.Context, t *ScanTarg
}
}
}
err := dialer.SetRandomLocalAddr("tcp", config.localAddrs, config.localPorts)
err := dialer.SetRandomLocalAddr("tcp", config.localAddrs, config.localPorts, t.IP)
if err != nil {
return nil, fmt.Errorf("could not set random local address: %w", err)
}
Expand Down Expand Up @@ -155,7 +155,7 @@ func GetDefaultUDPDialer(flags *BaseFlags) func(ctx context.Context, t *ScanTarg
// create dialer once and reuse it
return func(ctx context.Context, t *ScanTarget, addr string) (net.Conn, error) {
dialer := GetTimeoutConnectionDialer(flags.ConnectTimeout, flags.TargetTimeout)
err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts)
err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts, t.IP)
if err != nil {
return nil, fmt.Errorf("could not set random local address: %w", err)
}
Expand Down Expand Up @@ -216,7 +216,7 @@ func grabTarget(ctx context.Context, input ScanTarget, m *Monitor) *Grab {
}
// resolve the target's IP here once, so it doesn't need to be resolved in each module
dialer := NewDialer(nil)
err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts)
err := dialer.SetRandomLocalAddr("udp", config.localAddrs, config.localPorts, nil)
if err != nil {
return onResolutionFailure(input, m, fmt.Errorf("could not set random local address: %w", err))
}
Expand Down