diff --git a/cmd/dmsgweb/commands/dmsgweb.go b/cmd/dmsgweb/commands/dmsgweb.go index c505065c..a1dcb6c1 100644 --- a/cmd/dmsgweb/commands/dmsgweb.go +++ b/cmd/dmsgweb/commands/dmsgweb.go @@ -11,7 +11,7 @@ import ( "regexp" "strconv" "strings" - "sync" + "time" "github.com/chen3feng/safecast" "github.com/confiant-inc/go-socks5" @@ -288,20 +288,20 @@ dmsgweb conf file detected: ` + dwcfg if len(resolveDmsgAddr) == 0 && len(webPort) == 1 { if len(rawTCP) > 0 && rawTCP[0] { dlog.Debug("proxyTCPConn(-1)") - proxyTCPConn(-1) + proxyTCPConn(ctx, -1) } else { dlog.Debug("proxyHTTPConn(-1)") - proxyHTTPConn(-1) + proxyHTTPConn(ctx, -1) } } else { for i := range resolveDmsgAddr { wg.Add(1) if rawTCP[i] { dlog.Debug("proxyTCPConn(" + fmt.Sprintf("%v", i) + ")") - go proxyTCPConn(i) + go proxyTCPConn(ctx, i) } else { dlog.Debug("proxyHTTPConn(" + fmt.Sprintf("%v", i) + ")") - go proxyHTTPConn(i) + go proxyHTTPConn(ctx, i) } } } @@ -309,7 +309,7 @@ dmsgweb conf file detected: ` + dwcfg }, } -func proxyTCPConn(n int) { +func proxyTCPConn(ctx context.Context, n int) { //nolint:unparam var thiswebport uint if n == -1 { thiswebport = webPort[0] @@ -337,42 +337,45 @@ func proxyTCPConn(n int) { defer ioutil.CloseQuietly(conn, dlog) dp, ok := safecast.To[uint16](dmsgPorts[n]) if !ok { - dlog.Fatal("uint16 overflow when converting dmsg port") + dlog.WithError(fmt.Errorf("uint16 overflow for port %v", dmsgPorts[n])). + Warn("Failed to convert dmsg port") + return } dlog.Debug(fmt.Sprintf("Dialing %v:%v", dialPK[n].String(), dp)) - dmsgConn, err := dmsgC.DialStream(context.Background(), dmsg.Addr{PK: dialPK[n], Port: dp}) //nolint + dmsgConn, err := dmsgC.DialStream(ctx, dmsg.Addr{PK: dialPK[n], Port: dp}) if err != nil { dlog.WithError(err).Warn(fmt.Sprintf("Failed to dial dmsg address %v port %v", dialPK[n].String(), dmsgPorts[n])) return } - defer ioutil.CloseQuietly(dmsgConn, dlog) - var wg sync.WaitGroup - wg.Add(2) - + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) _, err := io.Copy(dmsgConn, conn) if err != nil { - dlog.WithError(err).Warn("Error on io.Copy(dmsgConn, conn)") + dlog.WithError(err).Debug("io.Copy(dmsgConn, conn) ended") } }() - go func() { - defer wg.Done() - _, err := io.Copy(conn, dmsgConn) - if err != nil { - dlog.WithError(err).Warn("Error on io.Copy(conn, dmsgConn)") - } - }() + _, err = io.Copy(conn, dmsgConn) + if err != nil { + dlog.WithError(err).Debug("io.Copy(conn, dmsgConn) ended") + } - wg.Wait() + // Close both to unblock the goroutine's io.Copy. + if err := conn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing client conn") + } + if err := dmsgConn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing dmsg conn") + } + <-done }(conn, n, dmsgC) } } -func proxyHTTPConn(n int) { +func proxyHTTPConn(ctx context.Context, n int) { //nolint:unparam r := gin.New() r.Use(gin.Recovery()) @@ -380,6 +383,10 @@ func proxyHTTPConn(n int) { r.Use(loggingMiddleware()) r.Any("/*path", func(c *gin.Context) { + // Limit request body to 10MB to prevent resource exhaustion. 
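+		// http.MaxBytesReader returns an error on reads past the limit, which
+		// surfaces below when the proxied request body is consumed.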
+		const maxBodySize = 10 << 20
+		c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBodySize)
+
 		var urlStr string
 		if n > -1 {
 			urlStr = fmt.Sprintf("dmsg://%s%s", resolveDmsgAddr[n], c.Param("path"))
@@ -401,7 +408,7 @@ func proxyHTTPConn(n int) {
 		}
 
 		dlog.Debug(fmt.Sprintf("Proxying request: %s %s", c.Request.Method, urlStr))
-		req, err := http.NewRequest(c.Request.Method, urlStr, c.Request.Body)
+		req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, urlStr, c.Request.Body)
 		if err != nil {
 			c.String(http.StatusInternalServerError, "Failed to create HTTP request")
 			dlog.WithError(err).Warn("Failed to create HTTP request")
@@ -430,23 +437,44 @@ func proxyHTTPConn(n int) {
 		c.Status(resp.StatusCode)
 
 		if _, err := io.Copy(c.Writer, resp.Body); err != nil {
-			c.String(http.StatusInternalServerError, "Failed to copy response body")
+			// Status header is already written; cannot override with 500.
+			// Just log the error.
 			dlog.WithError(err).Warn("Failed to copy response body")
 		}
 	})
+
+	var thiswebport uint
+	if n == -1 {
+		thiswebport = webPort[0]
+	} else {
+		thiswebport = webPort[n]
+	}
+
+	srv := &http.Server{
+		Addr:              fmt.Sprintf(":%v", thiswebport),
+		Handler:           r,
+		ReadHeaderTimeout: 5 * time.Second,
+	}
+
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		var thiswebport uint
-		if n == -1 {
-			thiswebport = webPort[0]
-		} else {
-			thiswebport = webPort[n]
-		}
 		dlog.Debug(fmt.Sprintf("Serving http on: http://127.0.0.1:%v", thiswebport))
-		r.Run(":" + fmt.Sprintf("%v", thiswebport)) //nolint
+		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			dlog.WithError(err).Error("HTTP server error")
+		}
 		dlog.Debug(fmt.Sprintf("Stopped serving http on: http://127.0.0.1:%v", thiswebport))
 	}()
+
+	// Graceful shutdown on context cancellation.
+	go func() { // context.Background below is intentional: shutdown must be able to outlive the parent ctx
+		<-ctx.Done()
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		if err := srv.Shutdown(shutdownCtx); err != nil {
+			dlog.WithError(err).Warn("HTTP server shutdown error")
+		}
+	}()
 }
 
 const envfileLinux = //nolint unused
diff --git a/cmd/dmsgweb/commands/dmsgwebsrv.go b/cmd/dmsgweb/commands/dmsgwebsrv.go
index de4dd7eb..3622c1c6 100644
--- a/cmd/dmsgweb/commands/dmsgwebsrv.go
+++ b/cmd/dmsgweb/commands/dmsgwebsrv.go
@@ -177,12 +177,13 @@ func proxyHTTPConnections(ctx context.Context, localPort uint, listener net.List
 	}
 	authRoute.Any("/*path", func(c *gin.Context) {
 		targetURL := fmt.Sprintf("http://127.0.0.1:%d%s?%s", localPort, c.Request.URL.Path, c.Request.URL.RawQuery)
+		parsed, err := url.Parse(targetURL)
+		if err != nil {
+			dlog.Errorf("failed to parse target URL %q: %v", targetURL, err)
+			c.String(http.StatusInternalServerError, "Bad target URL")
+			return
+		}
 		proxy := httputil.ReverseProxy{Director: func(req *http.Request) {
-			parsed, err := url.Parse(targetURL)
-			if err != nil {
-				dlog.Errorf("failed to parse target URL %q: %v", targetURL, err)
-				return
-			}
 			req.URL = parsed
 			req.Host = req.URL.Host
 		}}
@@ -211,12 +212,16 @@ func proxyHTTPConnections(ctx context.Context, localPort uint, listener net.List
 	}
 }
 
+// maxTCPConns is the maximum number of concurrent TCP proxy connections.
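+// Connections over the limit are rejected at accept time by a non-blocking
+// semaphore send in proxyTCPConnections rather than queued.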
+const maxTCPConns = 256 + func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Listener) { // To track active connections for cleanup var connWg sync.WaitGroup connChan := make(chan net.Conn) activeConns := make(map[net.Conn]struct{}) connMutex := &sync.Mutex{} // Protect access to activeConns + sem := make(chan struct{}, maxTCPConns) // Goroutine to accept new connections go func() { @@ -241,11 +246,15 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste select { case <-ctx.Done(): dlog.Info("Shutting down TCP proxy connections...") - listener.Close() //nolint + if err := listener.Close(); err != nil { + dlog.WithError(err).Debug("Error closing TCP listener") + } connMutex.Lock() for conn := range activeConns { - conn.Close() //nolint + if err := conn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing active connection") + } } connMutex.Unlock() @@ -257,14 +266,30 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste return } + // Limit concurrent connections. + select { + case sem <- struct{}{}: + default: + dlog.Warn("Max TCP connections reached, rejecting connection") + if err := conn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing rejected connection") + } + continue + } + connMutex.Lock() activeConns[conn] = struct{}{} connMutex.Unlock() connWg.Add(1) go func(dmsgConn net.Conn) { + defer func() { <-sem }() defer connWg.Done() - defer dmsgConn.Close() //nolint + defer func() { + if err := dmsgConn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing dmsg connection") + } + }() localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort)) if err != nil { @@ -276,21 +301,27 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste return } - defer localConn.Close() //nolint + done := make(chan struct{}) go func() { + defer close(done) _, err1 := io.Copy(dmsgConn, localConn) if err1 != nil { - dlog.WithError(err1).Warn("Error on io.Copy(dmsgConn, localConn)") + dlog.WithError(err1).Debug("io.Copy(dmsgConn, localConn) ended") } }() _, err2 := io.Copy(localConn, dmsgConn) if err2 != nil { - dlog.WithError(err2).Warn("Error on io.Copy(localConn, dmsgConn)") + dlog.WithError(err2).Debug("io.Copy(localConn, dmsgConn) ended") } // Close both to unblock the goroutine - dmsgConn.Close() //nolint - localConn.Close() //nolint + if err := dmsgConn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing dmsg conn") + } + if err := localConn.Close(); err != nil { + dlog.WithError(err).Debug("Error closing local conn") + } + <-done connMutex.Lock() delete(activeConns, dmsgConn) diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 5e5a0f44..916aa830 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -16,6 +16,14 @@ import ( "github.com/skycoin/dmsg/pkg/disc" ) +// entryCacheEntry holds a cached discovery entry with a timestamp. +type entryCacheEntry struct { + entry *disc.Entry + fetchedAt time.Time +} + +const entryCacheTTL = 30 * time.Second + // SessionDialCallback is triggered BEFORE a session is dialed to. // If a non-nil error is returned, the session dial is instantly terminated. 
type SessionDialCallback func(network, addr string) (err error) @@ -79,6 +87,16 @@ type Client struct { maxBO time.Duration // maximum backoff duration factor float64 // multiplier for the backoff duration that is applied on every retry + // routeCache maps destination client PK → server PK that last successfully + // relayed to that destination. Evicted on failure. + routeCache map[cipher.PubKey]cipher.PubKey + routeCacheMx sync.RWMutex + + // entryCache caches discovery entry lookups with TTL to avoid + // re-querying HTTP discovery on every request. + entryCache map[cipher.PubKey]entryCacheEntry + entryCacheMx sync.RWMutex + errCh chan error done chan struct{} once sync.Once @@ -97,15 +115,17 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf conf.Ensure() c := &Client{ - ready: make(chan struct{}), - porter: netutil.NewPorter(netutil.PorterMinEphemeral), - errCh: make(chan error, 10), - done: make(chan struct{}), - conf: conf, - initBO: time.Second * 5, - bo: time.Second * 5, - maxBO: time.Minute, - factor: netutil.DefaultFactor, + ready: make(chan struct{}), + porter: netutil.NewPorter(netutil.PorterMinEphemeral), + routeCache: make(map[cipher.PubKey]cipher.PubKey), + entryCache: make(map[cipher.PubKey]entryCacheEntry), + errCh: make(chan error, 10), + done: make(chan struct{}), + conf: conf, + initBO: time.Second * 5, + bo: time.Second * 5, + maxBO: time.Minute, + factor: netutil.DefaultFactor, } // Init common fields. @@ -163,6 +183,7 @@ func (ce *Client) Serve(ctx context.Context) { }(cancellabelCtx) updateEntryLoopOnce := new(sync.Once) + pingLoopOnce := new(sync.Once) needInitialPost := true @@ -297,6 +318,7 @@ func (ce *Client) Serve(ctx context.Context) { // Only start the update entry loop once we have at least one session established. updateEntryLoopOnce.Do(func() { go ce.updateClientEntryLoop(cancellabelCtx, ce.done, ce.conf.ClientType) }) + pingLoopOnce.Do(func() { go ce.pingSessionsLoop(cancellabelCtx) }) // We dial all servers and wait for error or done signal. select { @@ -399,3 +421,78 @@ func hasPK(pks []cipher.PubKey, pk cipher.PubKey) bool { } return false } + +// getCachedRoute returns the server PK that last successfully reached the given destination. +func (ce *Client) getCachedRoute(dst cipher.PubKey) (cipher.PubKey, bool) { + ce.routeCacheMx.RLock() + srvPK, ok := ce.routeCache[dst] + ce.routeCacheMx.RUnlock() + return srvPK, ok +} + +// setCachedRoute records a successful route to a destination via a server. +func (ce *Client) setCachedRoute(dst, srvPK cipher.PubKey) { + ce.routeCacheMx.Lock() + ce.routeCache[dst] = srvPK + ce.routeCacheMx.Unlock() +} + +// evictCachedRoute removes a cached route on failure. +func (ce *Client) evictCachedRoute(dst cipher.PubKey) { + ce.routeCacheMx.Lock() + delete(ce.routeCache, dst) + ce.routeCacheMx.Unlock() +} + +// getCachedEntry returns a cached discovery entry if it exists and hasn't expired. +func (ce *Client) getCachedEntry(pk cipher.PubKey) (*disc.Entry, bool) { + ce.entryCacheMx.RLock() + cached, ok := ce.entryCache[pk] + ce.entryCacheMx.RUnlock() + if !ok || time.Since(cached.fetchedAt) > entryCacheTTL { + return nil, false + } + return cached.entry, true +} + +// setCachedEntry stores a discovery entry in the cache. 
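+// Expiry is checked on read in getCachedEntry; a stale entry is simply
+// overwritten by the next successful discovery lookup.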
+func (ce *Client) setCachedEntry(pk cipher.PubKey, entry *disc.Entry) {
+	ce.entryCacheMx.Lock()
+	ce.entryCache[pk] = entryCacheEntry{entry: entry, fetchedAt: time.Now()}
+	ce.entryCacheMx.Unlock()
+}
+
+// pingSessionsLoop periodically pings all sessions to measure latency.
+func (ce *Client) pingSessionsLoop(ctx context.Context) {
+	ticker := time.NewTicker(30 * time.Second)
+	defer ticker.Stop()
+
+	// Do an initial ping immediately.
+	ce.pingSessions()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ce.done:
+			return
+		case <-ticker.C:
+			ce.pingSessions()
+		}
+	}
+}
+
+func (ce *Client) pingSessions() {
+	sessions := ce.allClientSessions(ce.porter)
+	for _, ses := range sessions {
+		// smuxPing is the session-level ping primitive defined in
+		// session_common.go; it is accessible here within package dmsg.
+		rtt, err := ses.smuxPing()
+		if err != nil {
+			ce.log.WithError(err).WithField("server", ses.RemotePK()).
+				Debug("Ping failed, keeping previous latency measurement")
+			continue
+		}
+		ses.SetLastPing(rtt)
+		ce.log.WithField("server", ses.RemotePK()).WithField("rtt", rtt).
+			Debug("Session ping measured")
+	}
+}
diff --git a/pkg/dmsg/client_dial.go b/pkg/dmsg/client_dial.go
index cf8fa608..88c437b3 100644
--- a/pkg/dmsg/client_dial.go
+++ b/pkg/dmsg/client_dial.go
@@ -3,10 +3,14 @@ package dmsg
 
 import (
 	"context"
+	"math"
 	"net"
+	"sort"
 
 	"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
 	"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/netutil"
+
+	"github.com/skycoin/dmsg/pkg/disc"
 )
 
 // Listen listens on a given dmsg port.
@@ -28,40 +32,58 @@ func (ce *Client) Dial(ctx context.Context, addr Addr) (net.Conn, error) {
 
 // DialStream dials to a remote client entity with the given address.
 func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
-	entry, err := getClientEntry(ctx, ce.dc, addr.PK)
+	entry, err := ce.getClientEntryCached(ctx, addr.PK)
 	if err != nil {
 		return nil, err
 	}
 
-	// 1. Try existing sessions to the target's delegated servers (direct path, cheapest).
-	for _, srvPK := range entry.Client.DelegatedServers {
-		if dSes, ok := ce.clientSession(ce.porter, srvPK); ok {
+	// Phase 0: Try cached route first (server that last successfully reached this destination).
+	if cachedSrvPK, ok := ce.getCachedRoute(addr.PK); ok {
+		if dSes, ok := ce.clientSession(ce.porter, cachedSrvPK); ok {
 			stream, err := dSes.DialStream(addr)
 			if err != nil {
-				ce.log.WithError(err).WithField("server", srvPK).
-					Debug("DialStream failed via existing session, trying next server")
-				continue
+				ce.log.WithError(err).WithField("server", cachedSrvPK).
+					Debug("DialStream failed via cached route, evicting")
+				ce.evictCachedRoute(addr.PK)
+			} else {
+				return stream, nil
 			}
-			return stream, nil
+		} else {
+			// Session no longer exists, evict stale route.
+			ce.evictCachedRoute(addr.PK)
 		}
 	}
 
-	// 2. Try all other existing sessions (mesh path — already connected, no new handshake).
-	// If servers are meshed, our server forwards the request to the target's server.
-	for _, ses := range ce.allClientSessions(ce.porter) {
-		if hasPK(entry.Client.DelegatedServers, ses.RemotePK()) {
-			continue // already tried above
+	// Phase 1: Try existing sessions to the target's delegated servers (direct path, cheapest).
+	// Sort by latency so the lowest-latency server is tried first.
+	delegatedSessions := ce.sortedDelegatedSessions(entry.Client.DelegatedServers)
+	for _, dSes := range delegatedSessions {
+		stream, err := dSes.DialStream(addr)
+		if err != nil {
+			ce.log.WithError(err).WithField("server", dSes.RemotePK()).
+				Debug("DialStream failed via existing session, trying next server")
+			continue
 		}
+		ce.setCachedRoute(addr.PK, dSes.RemotePK())
+		return stream, nil
+	}
+
+	// Phase 2: Try all other existing sessions (mesh path: already connected, no new handshake).
+	// If servers are meshed, our server forwards the request to the target's server.
+	// Sorted by latency.
+	meshSessions := ce.sortedMeshSessions(entry.Client.DelegatedServers)
+	for _, ses := range meshSessions {
 		stream, err := ses.DialStream(addr)
 		if err != nil {
 			ce.log.WithError(err).WithField("server", ses.RemotePK()).
 				Debug("DialStream failed via mesh, trying next server")
 			continue
 		}
+		ce.setCachedRoute(addr.PK, ses.RemotePK())
 		return stream, nil
 	}
 
-	// 3. Last resort: establish new sessions to the target's delegated servers.
+	// Phase 3: Last resort: establish new sessions to the target's delegated servers.
 	for _, srvPK := range entry.Client.DelegatedServers {
 		dSes, err := ce.EnsureAndObtainSession(ctx, srvPK)
 		if err != nil {
@@ -73,12 +95,70 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) {
 				Debug("DialStream failed via new session, trying next server")
 			continue
 		}
+		ce.setCachedRoute(addr.PK, srvPK)
 		return stream, nil
 	}
 
 	return nil, ErrCannotConnectToDelegated
 }
 
+// getClientEntryCached returns a client entry, using the entry cache when possible.
+func (ce *Client) getClientEntryCached(ctx context.Context, clientPK cipher.PubKey) (*disc.Entry, error) {
+	if entry, ok := ce.getCachedEntry(clientPK); ok {
+		return entry, nil
+	}
+	entry, err := getClientEntry(ctx, ce.dc, clientPK)
+	if err != nil {
+		return nil, err
+	}
+	ce.setCachedEntry(clientPK, entry)
+	return entry, nil
+}
+
+// sortedDelegatedSessions returns existing sessions to the given delegated servers,
+// sorted by ascending latency (lowest ping first).
+func (ce *Client) sortedDelegatedSessions(delegatedServers []cipher.PubKey) []ClientSession {
+	var sessions []ClientSession
+	for _, srvPK := range delegatedServers {
+		if dSes, ok := ce.clientSession(ce.porter, srvPK); ok {
+			sessions = append(sessions, dSes)
+		}
+	}
+	sortSessionsByLatency(sessions)
+	return sessions
+}
+
+// sortedMeshSessions returns all sessions NOT in the delegated list,
+// sorted by ascending latency.
+func (ce *Client) sortedMeshSessions(delegatedServers []cipher.PubKey) []ClientSession {
+	var sessions []ClientSession
+	for _, ses := range ce.allClientSessions(ce.porter) {
+		if hasPK(delegatedServers, ses.RemotePK()) {
+			continue
+		}
+		sessions = append(sessions, ses)
+	}
+	sortSessionsByLatency(sessions)
+	return sessions
+}
+
+// sortSessionsByLatency sorts sessions by last measured ping latency (ascending).
+// Sessions with no measurement (0) are sorted last.
+func sortSessionsByLatency(sessions []ClientSession) {
+	sort.Slice(sessions, func(i, j int) bool {
+		pi := sessions[i].LastPing()
+		pj := sessions[j].LastPing()
+		// Treat 0 (unmeasured) as maximum latency.
+		if pi == 0 {
+			pi = math.MaxInt64
+		}
+		if pj == 0 {
+			pj = math.MaxInt64
+		}
+		return pi < pj
+	})
+}
+
 // LookupIP dials dmsg servers for the public IP of the client.
func (ce *Client) LookupIP(ctx context.Context, servers []cipher.PubKey) (myIP net.IP, err error) { diff --git a/pkg/dmsg/session_common.go b/pkg/dmsg/session_common.go index 9413f042..fa0f2bb7 100644 --- a/pkg/dmsg/session_common.go +++ b/pkg/dmsg/session_common.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/chen3feng/safecast" @@ -34,6 +35,11 @@ type SessionCommon struct { rMx sync.Mutex wMx sync.Mutex + // lastPingNs stores the last measured round-trip latency in nanoseconds. + // Updated by background ping goroutine; read by DialStream for sorting. + // A value of 0 means no measurement yet (treated as max latency for sorting). + lastPingNs atomic.Int64 + log logrus.FieldLogger } @@ -210,6 +216,17 @@ func (sc *SessionCommon) smuxPing() (time.Duration, error) { return time.Since(start), nil } +// LastPing returns the last measured round-trip latency. +// Returns 0 if no measurement has been taken yet. +func (sc *SessionCommon) LastPing() time.Duration { + return time.Duration(sc.lastPingNs.Load()) +} + +// SetLastPing records a latency measurement. +func (sc *SessionCommon) SetLastPing(d time.Duration) { + sc.lastPingNs.Store(int64(d)) +} + // Close closes the session. func (sc *SessionCommon) Close() error { if sc == nil { diff --git a/pkg/dmsghttp/http_transport.go b/pkg/dmsghttp/http_transport.go index 873a58b9..f7efdc48 100644 --- a/pkg/dmsghttp/http_transport.go +++ b/pkg/dmsghttp/http_transport.go @@ -31,6 +31,12 @@ func MakeHTTPTransport(ctx context.Context, dmsgC *dmsg.Client) HTTPTransport { // RoundTrip implements golang's http package support for alternative HTTP transport protocols. // In this case dmsg is used instead of TCP to initiate the communication with the server. func (t HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Normalize scheme: callers may use "dmsg://" URLs. + if req.URL.Scheme == "dmsg" { + req = req.Clone(req.Context()) + req.URL.Scheme = "http" + } + var hostAddr dmsg.Addr if err := hostAddr.Set(req.Host); err != nil { return nil, fmt.Errorf("invalid host address: %w", err) @@ -44,23 +50,22 @@ func (t HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - // Ensure stream is closed if we return an error before wrapping the response body + // Ensure stream is closed if we return an error before wrapping the response body. defer func() { if err != nil { - _ = stream.Close() //nolint:errcheck // best-effort cleanup on error path + stream.Close() //nolint:errcheck,gosec } }() if err = req.Write(stream); err != nil { return nil, err } - bufR := bufio.NewReader(stream) - resp, err := http.ReadResponse(bufR, req) + resp, err := http.ReadResponse(bufio.NewReader(stream), req) if err != nil { return nil, err } - // Wrap resp.Body to ensure the stream is closed when the body is closed + // Wrap resp.Body to ensure the stream is closed when the body is closed. resp.Body = &wrappedBody{ ReadCloser: resp.Body, stream: stream, @@ -76,10 +81,6 @@ type wrappedBody struct { } func (wb *wrappedBody) Close() error { - // Drain the response body up to a limit (e.g., 512KB). - const maxDrainBytes = 512 * 1024 - _, _ = io.CopyN(io.Discard, wb.ReadCloser, maxDrainBytes) //nolint - err1 := wb.ReadCloser.Close() err2 := wb.stream.Close() diff --git a/pkg/dmsghttp/util.go b/pkg/dmsghttp/util.go index 7f7508fa..f22d91a8 100644 --- a/pkg/dmsghttp/util.go +++ b/pkg/dmsghttp/util.go @@ -14,7 +14,7 @@ import ( // GetServers is used to get all the available servers from the dmsg-discovery. 
func GetServers(ctx context.Context, dmsgDisc string, dmsgServerType string, log *logging.Logger) (entries []*disc.Entry) { - dmsgclient := disc.NewHTTP(dmsgDisc, &http.Client{}, log) + dmsgclient := disc.NewHTTP(dmsgDisc, &http.Client{Timeout: 30 * time.Second}, log) ticker := time.NewTicker(time.Second * 10) defer ticker.Stop() for { @@ -52,7 +52,7 @@ func GetServers(ctx context.Context, dmsgDisc string, dmsgServerType string, log // UpdateServers is used to update the servers in the direct client. func UpdateServers(ctx context.Context, dClient disc.APIClient, dmsgDisc string, dmsgC *dmsg.Client, dmsgServerType string, log *logging.Logger) (entries []*disc.Entry) { - dmsgclient := disc.NewHTTP(dmsgDisc, &http.Client{}, log) + dmsgclient := disc.NewHTTP(dmsgDisc, &http.Client{Timeout: 30 * time.Second}, log) ticker := time.NewTicker(time.Minute * 10) defer ticker.Stop() for {