diff --git a/pkg/dmsg/client_dial.go b/pkg/dmsg/client_dial.go index 88c437b3..5d54c57b 100644 --- a/pkg/dmsg/client_dial.go +++ b/pkg/dmsg/client_dial.go @@ -40,7 +40,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { // Phase 0: Try cached route first (server that last successfully reached this destination). if cachedSrvPK, ok := ce.getCachedRoute(addr.PK); ok { if dSes, ok := ce.clientSession(ce.porter, cachedSrvPK); ok { - stream, err := dSes.DialStream(addr) + stream, err := dSes.DialStream(ctx, addr) if err != nil { ce.log.WithError(err).WithField("server", cachedSrvPK). Debug("DialStream failed via cached route, evicting") @@ -58,7 +58,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { // Sort by latency so the lowest-latency server is tried first. delegatedSessions := ce.sortedDelegatedSessions(entry.Client.DelegatedServers) for _, dSes := range delegatedSessions { - stream, err := dSes.DialStream(addr) + stream, err := dSes.DialStream(ctx, addr) if err != nil { ce.log.WithError(err).WithField("server", dSes.RemotePK()). Debug("DialStream failed via existing session, trying next server") @@ -73,7 +73,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { // Sorted by latency. meshSessions := ce.sortedMeshSessions(entry.Client.DelegatedServers) for _, ses := range meshSessions { - stream, err := ses.DialStream(addr) + stream, err := ses.DialStream(ctx, addr) if err != nil { ce.log.WithError(err).WithField("server", ses.RemotePK()). Debug("DialStream failed via mesh, trying next server") @@ -89,7 +89,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { if err != nil { continue } - stream, err := dSes.DialStream(addr) + stream, err := dSes.DialStream(ctx, addr) if err != nil { ce.log.WithError(err).WithField("server", srvPK). Debug("DialStream failed via new session, trying next server") diff --git a/pkg/dmsg/client_session.go b/pkg/dmsg/client_session.go index bfe8f24d..7e961f75 100644 --- a/pkg/dmsg/client_session.go +++ b/pkg/dmsg/client_session.go @@ -2,6 +2,7 @@ package dmsg import ( + "context" "errors" "net" "time" @@ -28,7 +29,9 @@ func makeClientSession(entity *EntityCommon, porter *netutil.Porter, conn net.Co } // DialStream attempts to dial a stream to a remote client via the dmsg server that this session is connected to. -func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) { +// The context is used to cancel the dial if the caller's deadline expires — this prevents ephemeral port +// leaks when many dials are attempted and the caller gives up before the handshake completes. +func (cs *ClientSession) DialStream(ctx context.Context, dst Addr) (dStr *Stream, err error) { log := cs.log. WithField("func", "ClientSession.DialStream"). WithField("dst_addr", dst) @@ -37,7 +40,7 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) { return nil, err } - // Close stream on failure. + // Close stream on failure — this frees the reserved ephemeral port. defer func() { if err != nil { log.WithError(err). @@ -46,6 +49,23 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) { } }() + // If the caller's context is canceled, close the stream to interrupt + // any blocked read/write and free the ephemeral port immediately. + ctxDone := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + dStr.Close() //nolint:errcheck,gosec + case <-ctxDone: + } + }() + defer close(ctxDone) + + // Check context before starting. + if ctx.Err() != nil { + return nil, ctx.Err() + } + // Prepare deadline. if err = dStr.SetDeadline(time.Now().Add(HandshakeTimeout)); err != nil { return nil, err diff --git a/pkg/dmsghttp/http_transport_test.go b/pkg/dmsghttp/http_transport_test.go index d0b40158..3401473b 100644 --- a/pkg/dmsghttp/http_transport_test.go +++ b/pkg/dmsghttp/http_transport_test.go @@ -70,19 +70,21 @@ func TestHTTPTransport_RoundTrip(t *testing.T) { // Configure timeouts to prevent hanging on errors. httpC1 := http.Client{ Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client1")), - Timeout: 10 * time.Second, + Timeout: 30 * time.Second, } httpC2 := http.Client{ Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client2")), - Timeout: 10 * time.Second, + Timeout: 30 * time.Second, } httpC3 := http.Client{ Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client3")), - Timeout: 10 * time.Second, + Timeout: 30 * time.Second, } - // Allow time for dmsg sessions to stabilize on macOS - time.Sleep(200 * time.Millisecond) + // Allow time for dmsg sessions to stabilize across all platforms. + // CI runners are slower; 200ms was insufficient for noise handshakes + // to complete across 5 servers × 4 clients. + time.Sleep(2 * time.Second) // Act: http clients send requests concurrently. // - client1 sends "/index.html" requests.