Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/dmsg/client_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
// Phase 0: Try cached route first (server that last successfully reached this destination).
if cachedSrvPK, ok := ce.getCachedRoute(addr.PK); ok {
if dSes, ok := ce.clientSession(ce.porter, cachedSrvPK); ok {
stream, err := dSes.DialStream(addr)
stream, err := dSes.DialStream(ctx, addr)
if err != nil {
ce.log.WithError(err).WithField("server", cachedSrvPK).
Debug("DialStream failed via cached route, evicting")
Expand All @@ -58,7 +58,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
// Sort by latency so the lowest-latency server is tried first.
delegatedSessions := ce.sortedDelegatedSessions(entry.Client.DelegatedServers)
for _, dSes := range delegatedSessions {
stream, err := dSes.DialStream(addr)
stream, err := dSes.DialStream(ctx, addr)
if err != nil {
ce.log.WithError(err).WithField("server", dSes.RemotePK()).
Debug("DialStream failed via existing session, trying next server")
Expand All @@ -73,7 +73,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
// Sorted by latency.
meshSessions := ce.sortedMeshSessions(entry.Client.DelegatedServers)
for _, ses := range meshSessions {
stream, err := ses.DialStream(addr)
stream, err := ses.DialStream(ctx, addr)
if err != nil {
ce.log.WithError(err).WithField("server", ses.RemotePK()).
Debug("DialStream failed via mesh, trying next server")
Expand All @@ -89,7 +89,7 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
if err != nil {
continue
}
stream, err := dSes.DialStream(addr)
stream, err := dSes.DialStream(ctx, addr)
if err != nil {
ce.log.WithError(err).WithField("server", srvPK).
Debug("DialStream failed via new session, trying next server")
Expand Down
24 changes: 22 additions & 2 deletions pkg/dmsg/client_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package dmsg

import (
"context"
"errors"
"net"
"time"
Expand All @@ -28,7 +29,9 @@ func makeClientSession(entity *EntityCommon, porter *netutil.Porter, conn net.Co
}

// DialStream attempts to dial a stream to a remote client via the dmsg server that this session is connected to.
func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) {
// The context aborts the dial when the caller cancels it or its deadline expires — this prevents ephemeral port
// leaks when many dials are attempted and the caller gives up before the handshake completes.
func (cs *ClientSession) DialStream(ctx context.Context, dst Addr) (dStr *Stream, err error) {
log := cs.log.
WithField("func", "ClientSession.DialStream").
WithField("dst_addr", dst)
Expand All @@ -37,7 +40,7 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) {
return nil, err
}

// Close stream on failure.
// Close stream on failure — this frees the reserved ephemeral port.
defer func() {
if err != nil {
log.WithError(err).
Expand All @@ -46,6 +49,23 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) {
}
}()

// If the caller's context is canceled, close the stream to interrupt
// any blocked read/write and free the ephemeral port immediately.
ctxDone := make(chan struct{})
go func() {
select {
case <-ctx.Done():
dStr.Close() //nolint:errcheck,gosec
case <-ctxDone:
}
}()
defer close(ctxDone)

// Check context before starting.
if ctx.Err() != nil {
return nil, ctx.Err()
}

// Prepare deadline.
if err = dStr.SetDeadline(time.Now().Add(HandshakeTimeout)); err != nil {
return nil, err
Expand Down
12 changes: 7 additions & 5 deletions pkg/dmsghttp/http_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,21 @@ func TestHTTPTransport_RoundTrip(t *testing.T) {
// Configure timeouts to prevent hanging on errors.
httpC1 := http.Client{
Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client1")),
Timeout: 10 * time.Second,
Timeout: 30 * time.Second,
}
httpC2 := http.Client{
Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client2")),
Timeout: 10 * time.Second,
Timeout: 30 * time.Second,
}
httpC3 := http.Client{
Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client3")),
Timeout: 10 * time.Second,
Timeout: 30 * time.Second,
}

// Allow time for dmsg sessions to stabilize on macOS
time.Sleep(200 * time.Millisecond)
// Allow time for dmsg sessions to stabilize across all platforms.
// CI runners are slower; 200ms was insufficient for noise handshakes
// to complete across 5 servers × 4 clients.
time.Sleep(2 * time.Second)

// Act: http clients send requests concurrently.
// - client1 sends "/index.html" requests.
Expand Down