diff --git a/cmd/dmsg-discovery/commands/dmsg-discovery.go b/cmd/dmsg-discovery/commands/dmsg-discovery.go index af17c93e4..59feede0d 100644 --- a/cmd/dmsg-discovery/commands/dmsg-discovery.go +++ b/cmd/dmsg-discovery/commands/dmsg-discovery.go @@ -375,8 +375,8 @@ Example: go a.RunBackgroundTasks(ctx, log) log.WithField("addr", addr).Info("Serving discovery API...") go func() { - if err = listenAndServe(addr, a); err != nil { - log.Errorf("ListenAndServe: %v", err) + if listenErr := listenAndServe(addr, a); listenErr != nil { + log.Errorf("ListenAndServe: %v", listenErr) cancel() } }() @@ -408,8 +408,8 @@ Example: go updateServers(ctx, a, dClient, dmsgDC, dmsgServerType, log) go func() { - if err = dmsghttp.ListenAndServe(ctx, sk, a, dClient, dmsg.DefaultDmsgHTTPPort, dmsgDC, log); err != nil { - log.Errorf("dmsghttp.ListenAndServe: %v", err) + if dmsgErr := dmsghttp.ListenAndServe(ctx, sk, a, dClient, dmsg.DefaultDmsgHTTPPort, dmsgDC, log); dmsgErr != nil { + log.Errorf("dmsghttp.ListenAndServe: %v", dmsgErr) cancel() } }() @@ -467,7 +467,7 @@ func getServers(ctx context.Context, a *api.API, dmsgServerType string, log logr case <-ctx.Done(): return []*disc.Entry{} case <-ticker.C: - getServers(ctx, a, dmsgServerType, log) + return getServers(ctx, a, dmsgServerType, log) } } } diff --git a/cmd/dmsgcurl/commands/dmsgcurl.go b/cmd/dmsgcurl/commands/dmsgcurl.go index ef51f68f9..a2f6d9dfe 100644 --- a/cmd/dmsgcurl/commands/dmsgcurl.go +++ b/cmd/dmsgcurl/commands/dmsgcurl.go @@ -142,7 +142,7 @@ var RootCmd = &cobra.Command{ httpClient = &http.Client{ Transport: transport, } - ctx = context.WithValue(context.Background(), "socks5_proxy", proxyAddr) //nolint + ctx = context.WithValue(ctx, "socks5_proxy", proxyAddr) //nolint } cErr = handleRequest(ctx, pk, sk, httpClient, parsedURL, dmsgcurlData) @@ -166,7 +166,7 @@ func handleRequest(ctx context.Context, pk cipher.PubKey, sk cipher.SecKey, http Code: errorCode["WRITE_INIT"], } } - defer closeAndCleanFile(file, err) + defer func() { closeAndCleanFile(file, err) }() var httpC http.Client if flags.UseDC { @@ -256,9 +256,8 @@ func handleRequest(ctx context.Context, pk cipher.PubKey, sk cipher.SecKey, http dlog.WithError(err).Debug("Failed to perform HTTP request after maximum retries") continue // Retry outer attempt } - defer closeResponseBody(resp) - n, err := cancellableCopy(ctx, file, resp.Body, resp.ContentLength) + closeResponseBody(resp) if err != nil { dlog.WithError(err).Errorf("Download failed at %d/%dB", n, resp.ContentLength) select { @@ -373,7 +372,12 @@ func (pw *progressWriter) Write(p []byte) (int, error) { n := len(p) current := atomic.AddInt64(&pw.Current, int64(n)) total := atomic.LoadInt64(&pw.Total) - pc := fmt.Sprintf("%d%%", current*100/total) + var pc string + if total > 0 { + pc = fmt.Sprintf("%d%%", current*100/total) + } else { + pc = "unknown" + } if dmsgcurlOutput != "" { fmt.Printf("Downloading: %d/%dB (%s)", current, total, pc) if current != total { diff --git a/cmd/dmsgweb/commands/dmsgweb.go b/cmd/dmsgweb/commands/dmsgweb.go index fb33cb15a..c4605ac7a 100644 --- a/cmd/dmsgweb/commands/dmsgweb.go +++ b/cmd/dmsgweb/commands/dmsgweb.go @@ -219,7 +219,7 @@ dmsgweb conf file detected: ` + dwcfg httpClient = &http.Client{ Transport: transport, } - ctx = context.WithValue(context.Background(), "socks5_proxy", proxyAddr) //nolint + ctx = context.WithValue(ctx, "socks5_proxy", proxyAddr) //nolint } dmsgC, closeDmsg, err = cli.InitDmsgWithFlags(ctx, dlog, pk, sk, httpClient, "") @@ -390,7 +390,7 @@ func proxyHTTPConn(n int) { } else { dmsgp = "80" } - urlStr = fmt.Sprintf("dmsg://%s:%s%s", strings.TrimRight(hostParts[0], filterDomainSuffix), dmsgp, c.Param("path")) + urlStr = fmt.Sprintf("dmsg://%s:%s%s", strings.TrimSuffix(hostParts[0], filterDomainSuffix), dmsgp, c.Param("path")) if c.Request.URL.RawQuery != "" { urlStr = fmt.Sprintf("%s?%s", urlStr, c.Request.URL.RawQuery) } diff --git a/cmd/dmsgweb/commands/dmsgwebsrv.go b/cmd/dmsgweb/commands/dmsgwebsrv.go index 3bb56fa9d..4599f1ba1 100644 --- a/cmd/dmsgweb/commands/dmsgwebsrv.go +++ b/cmd/dmsgweb/commands/dmsgwebsrv.go @@ -130,7 +130,7 @@ func server() { httpClient = &http.Client{ Transport: transport, } - ctx = context.WithValue(context.Background(), "socks5_proxy", proxyAddr) //nolint + ctx = context.WithValue(ctx, "socks5_proxy", proxyAddr) //nolint } dmsgC, closeDmsg, err = cli.InitDmsgWithFlags(ctx, dlog, pk, sk, httpClient, "") diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 963f9d3f4..d10310e5b 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -2,6 +2,7 @@ package cli import ( + "bytes" "context" "fmt" "io" @@ -519,8 +520,28 @@ func NewFallbackRoundTripper(ctx context.Context, clients []*dmsg.Client) http.R // RoundTrip tries each DMSG client in order until a successful response is received. func (f *FallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Buffer the request body so it can be replayed on retry. + // Without this, the first failed transport consumes the body + // and subsequent transports receive an empty body. + var bodyBytes []byte + if req.Body != nil { + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body for retry: %w", err) + } + req.Body.Close() //nolint:errcheck,gosec + } + var lastErr error for _, client := range f.clients { + // Reset the body for each attempt + if bodyBytes != nil { + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } else { + req.Body = nil + } + rt := dmsghttp.MakeHTTPTransport(f.ctx, client) resp, err := rt.RoundTrip(req) if err != nil { diff --git a/internal/dmsg-discovery/api/api.go b/internal/dmsg-discovery/api/api.go index 4902691e5..c1d69a673 100644 --- a/internal/dmsg-discovery/api/api.go +++ b/internal/dmsg-discovery/api/api.go @@ -51,7 +51,7 @@ type API struct { // New returns a new API object, which can be started as a server func New(log logrus.FieldLogger, db store.Storer, m discmetrics.Metrics, testMode, enableLoadTesting, enableMetrics bool, dmsgAddr, authPassphrase string) *API { - if log != nil { + if log == nil { log = logging.MustGetLogger("dmsg_disc") } @@ -358,6 +358,7 @@ func (a *API) setEntry() func(w http.ResponseWriter, r *http.Request) { // json serialized entry object func (a *API) delEntry() func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() //nolint:errcheck entry := new(disc.Entry) if err := json.NewDecoder(r.Body).Decode(entry); err != nil { a.handleError(w, r, disc.ErrUnexpected) @@ -526,7 +527,11 @@ func isLoopbackAddr(addr string) (bool, error) { return true, nil } - return net.ParseIP(host).IsLoopback(), nil + ip := net.ParseIP(host) + if ip == nil { + return false, nil + } + return ip.IsLoopback(), nil } // writeJSON writes a json object on a http.ResponseWriter with the given code. diff --git a/internal/dmsg-discovery/api/error_handler.go b/internal/dmsg-discovery/api/error_handler.go index ab458ef07..6ab202230 100644 --- a/internal/dmsg-discovery/api/error_handler.go +++ b/internal/dmsg-discovery/api/error_handler.go @@ -2,6 +2,7 @@ package api import ( + "errors" "net/http" "github.com/skycoin/dmsg/pkg/disc" @@ -36,9 +37,12 @@ func (a *API) handleError(w http.ResponseWriter, r *http.Request, e error) { code = http.StatusUnprocessableEntity msg = e.Error() } else { - f, ok := apiErrors[e] - if !ok { - f = func() (int, string) { return http.StatusInternalServerError, disc.ErrUnexpected.Error() } + f := func() (int, string) { return http.StatusInternalServerError, disc.ErrUnexpected.Error() } + for target, handler := range apiErrors { + if errors.Is(e, target) { + f = handler + break + } } code, msg = f() } diff --git a/internal/dmsg-server/api/api.go b/internal/dmsg-server/api/api.go index 1276f431b..4249e497e 100644 --- a/internal/dmsg-server/api/api.go +++ b/internal/dmsg-server/api/api.go @@ -82,26 +82,24 @@ func (a *API) SetDmsgServer(srv *dmsg.Server) { // ListenAndServe runs dmsg Serve function alongside health endpoint func (a *API) ListenAndServe(lAddr, pAddr, httpAddr string) error { - errCh := make(chan error) + errCh := make(chan error, 2) dmsgLn, err := net.Listen("tcp", lAddr) if err != nil { return err } dmsgLis := &proxyproto.Listener{Listener: dmsgLn} - defer dmsgLis.Close() // nolint:errcheck go func(l net.Listener, address string) { - if err := a.dmsgServer.Serve(l, address); err != nil { - errCh <- err - } + errCh <- a.dmsgServer.Serve(l, address) + l.Close() //nolint:errcheck,gosec }(dmsgLis, pAddr) ln, err := net.Listen("tcp", httpAddr) if err != nil { + dmsgLis.Close() //nolint:errcheck,gosec return err } lis := &proxyproto.Listener{Listener: ln} - defer lis.Close() // nolint:errcheck srv := &http.Server{ ReadTimeout: 3 * time.Second, WriteTimeout: 3 * time.Second, @@ -110,9 +108,8 @@ func (a *API) ListenAndServe(lAddr, pAddr, httpAddr string) error { //Addr: lis, Handler: a.router, } - if err := srv.Serve(lis); err != nil { - errCh <- err - } + errCh <- srv.Serve(lis) + lis.Close() //nolint:errcheck,gosec return <-errCh } @@ -161,6 +158,9 @@ func (a *API) updateInternalState() { // UpdateAverageNumberOfPacketsPerMinute is function which needs to called every minute. func (a *API) updateAverageNumberOfPacketsPerMinute() { if a.dmsgServer != nil { + a.sMu.Lock() + defer a.sMu.Unlock() + newDecValues, newEncValues, average := calculateThroughput( a.dmsgServer.GetSessions(), a.minuteDecValues, @@ -169,8 +169,6 @@ func (a *API) updateAverageNumberOfPacketsPerMinute() { a.metrics.SetPacketsPerMinute(average) - a.sMu.Lock() - defer a.sMu.Unlock() a.minuteDecValues = newDecValues a.minuteEncValues = newEncValues } diff --git a/pkg/direct/client.go b/pkg/direct/client.go index 92fca6c12..e2240e8ec 100644 --- a/pkg/direct/client.go +++ b/pkg/direct/client.go @@ -36,10 +36,8 @@ func NewClient(entries []*disc.Entry, log *logging.Logger) disc.APIClient { func (c *directClient) Entry(_ context.Context, pubKey cipher.PubKey) (*disc.Entry, error) { c.mx.RLock() defer c.mx.RUnlock() - for _, entry := range c.entries { - if entry.Static == pubKey { - return entry, nil - } + if entry, ok := c.entries[pubKey]; ok { + return entry, nil } return nil, disc.ErrKeyNotFound } diff --git a/pkg/disc/client.go b/pkg/disc/client.go index 4317f5898..a8160d221 100644 --- a/pkg/disc/client.go +++ b/pkg/disc/client.go @@ -144,6 +144,7 @@ func (c *httpClient) PostEntry(ctx context.Context, entry *Entry) error { Error() return errFromString(httpResponse.Message) } + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck return nil } @@ -198,6 +199,7 @@ func (c *httpClient) DelEntry(ctx context.Context, entry *Entry) error { Error() return errFromString(httpResponse.Message) } + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck return nil } @@ -206,10 +208,11 @@ func (c *httpClient) PutEntry(ctx context.Context, sk cipher.SecKey, entry *Entr c.updateMux.Lock() defer c.updateMux.Unlock() - entry.Sequence++ + sequence := entry.Sequence + 1 entry.Timestamp = time.Now().UnixNano() for { + entry.Sequence = sequence err := entry.Sign(sk) if err != nil { return err @@ -219,18 +222,17 @@ func (c *httpClient) PutEntry(ctx context.Context, sk cipher.SecKey, entry *Entr return nil } if err != ErrValidationWrongSequence { - entry.Sequence-- return err } rE, entryErr := c.Entry(ctx, entry.Static) if entryErr != nil { - return err + return entryErr } if rE.Timestamp > entry.Timestamp { // If there is a more up to date entry drop update entry.Sequence = rE.Sequence return nil } - entry.Sequence = rE.Sequence + 1 + sequence = rE.Sequence + 1 } } diff --git a/pkg/disc/entry.go b/pkg/disc/entry.go index dbb286a35..9d3f73d2f 100644 --- a/pkg/disc/entry.go +++ b/pkg/disc/entry.go @@ -326,6 +326,10 @@ func Copy(dst, src *Entry) { dst.Client = nil } else { *dst.Client = *src.Client + if src.Client.DelegatedServers != nil { + dst.Client.DelegatedServers = make([]cipher.PubKey, len(src.Client.DelegatedServers)) + copy(dst.Client.DelegatedServers, src.Client.DelegatedServers) + } } dst.Static = src.Static diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 97fcb6fca..18396ba0d 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -80,7 +80,8 @@ type Client struct { conf *Config porter *netutil.Porter - bo time.Duration // initial backoff duration + initBO time.Duration // initial backoff duration (constant) + bo time.Duration // current backoff duration maxBO time.Duration // maximum backoff duration factor float64 // multiplier for the backoff duration that is applied on every retry @@ -106,6 +107,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf errCh: make(chan error, 10), done: make(chan struct{}), conf: conf, + initBO: time.Second * 5, bo: time.Second * 5, maxBO: time.Minute, factor: netutil.DefaultFactor, @@ -210,6 +212,7 @@ func (ce *Client) Serve(ctx context.Context) { if len(entries) == 0 { ce.log.Warnf("No entries found. Retrying after %s...", ce.bo.String()) ce.serveWait() + continue } // randomize dmsg servers list using crypto/rand seed for true randomization // This ensures each client connects to servers in a different order, @@ -280,6 +283,9 @@ func (ce *Client) Serve(ctx context.Context) { ce.log.WithField("remote_pk", entry.Static).WithError(err).WithField("current_backoff", ce.bo.String()). Warn("Failed to establish session.") ce.serveWait() + } else { + // Reset backoff on successful session establishment. + ce.bo = ce.initBO } } @@ -373,10 +379,16 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { } // Range client's delegated servers. - // See if we are already connected to a delegated server. + // Try existing sessions first, falling back to next server on failure. for _, srvPK := range entry.Client.DelegatedServers { if dSes, ok := ce.clientSession(ce.porter, srvPK); ok { - return dSes.DialStream(addr) + stream, err := dSes.DialStream(addr) + if err != nil { + ce.log.WithError(err).WithField("server", srvPK). + Debug("DialStream failed via existing session, trying next server") + continue + } + return stream, nil } } @@ -387,7 +399,13 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { if err != nil { continue } - return dSes.DialStream(addr) + stream, err := dSes.DialStream(addr) + if err != nil { + ce.log.WithError(err).WithField("server", srvPK). + Debug("DialStream failed via new session, trying next server") + continue + } + return stream, nil } return nil, ErrCannotConnectToDelegated diff --git a/pkg/dmsg/entity_common.go b/pkg/dmsg/entity_common.go index 2e6ce82c4..4c40c43a1 100644 --- a/pkg/dmsg/entity_common.go +++ b/pkg/dmsg/entity_common.go @@ -143,6 +143,7 @@ func (c *EntityCommon) delSession(ctx context.Context, pk cipher.PubKey) { // updateServerEntry updates the dmsg server's entry within dmsg discovery. // If 'addr' is an empty string, the Entry.addr field will not be updated in discovery. +// Caller must hold c.sessionsMx. func (c *EntityCommon) updateServerEntry(ctx context.Context, addr string, maxSessions int, authPassphrase string) (err error) { if addr == "" { panic("updateServerEntry cannot accept empty 'addr' input") // this should never happen @@ -233,7 +234,9 @@ func (c *EntityCommon) initilizeClientEntry(ctx context.Context, clientType stri } }() + c.sessionsMx.Lock() srvPKs := make([]cipher.PubKey, 0, len(c.sessions)) + c.sessionsMx.Unlock() _, err = c.dc.Entry(ctx, c.pk) if err != nil { diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index e9dbb82dd..f69f56381 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -100,14 +100,15 @@ func (s *Server) Close() error { if s == nil { return nil } + var err error s.once.Do(func() { close(s.done) s.wg.Wait() + err = s.delEntry(context.Background()) + if err != nil { + s.log.Warn("Cannot delete entry from db.") + } }) - err := s.delEntry(context.Background()) - if err != nil { - s.log.Warn("Cannot delete entry from db.") - } return nil } @@ -159,20 +160,21 @@ func (s *Server) Serve(lis net.Listener, addr string) error { s.wg.Add(1) go func(conn net.Conn) { + defer s.wg.Done() defer func() { - err := recover() - if err != nil { + if err := recover(); err != nil { log.Warnf("panic in handleSession: %+v", err) } }() s.handleSession(conn) - s.wg.Done() }(conn) } } func (s *Server) startUpdateEntryLoop(ctx context.Context) error { err := netutil.NewDefaultRetrier(s.log).Do(ctx, func() error { + s.sessionsMx.Lock() + defer s.sessionsMx.Unlock() return s.updateServerEntry(ctx, s.AdvertisedAddr(), s.maxSessions, s.authPassphrase) }) if err != nil { diff --git a/pkg/dmsgctrl/control.go b/pkg/dmsgctrl/control.go index 5825bea9f..0120fd8ce 100644 --- a/pkg/dmsgctrl/control.go +++ b/pkg/dmsgctrl/control.go @@ -28,9 +28,11 @@ const ( // Control wraps and takes over a dmsg.Stream and provides control features. type Control struct { conn net.Conn + wMu sync.Mutex // protects concurrent writes to conn pongCh chan time.Time doneCh chan struct{} - err error // the resultant error after control stops serving + errMu sync.Mutex // protects c.err + err error // the resultant error after control stops serving errOnce sync.Once } @@ -60,7 +62,10 @@ func (c *Control) serve() { switch pt := PacketType(rawType[0]); pt { case PingType: - if _, err := c.conn.Write([]byte{byte(PongType)}); err != nil { + c.wMu.Lock() + _, err := c.conn.Write([]byte{byte(PongType)}) + c.wMu.Unlock() + if err != nil { c.reportErr(fmt.Errorf("failed to write pong: %w", err)) return } @@ -84,7 +89,10 @@ func (c *Control) serve() { func (c *Control) Ping(ctx context.Context) (time.Duration, error) { start := time.Now() - if _, err := c.conn.Write([]byte{byte(PingType)}); err != nil { + c.wMu.Lock() + _, err := c.conn.Write([]byte{byte(PingType)}) + c.wMu.Unlock() + if err != nil { return 0, err } @@ -94,7 +102,10 @@ func (c *Control) Ping(ctx context.Context) (time.Duration, error) { case t, ok := <-c.pongCh: if !ok { - return 0, c.err + c.errMu.Lock() + err = c.err + c.errMu.Unlock() + return 0, err } return t.Sub(start), nil } @@ -108,7 +119,10 @@ func (c *Control) Conn() net.Conn { // Close implements io.Closer func (c *Control) Close() error { if isDone(c.doneCh) { - return c.err + c.errMu.Lock() + err := c.err + c.errMu.Unlock() + return err } c.reportErr(ErrClosed) @@ -126,12 +140,17 @@ func (c *Control) Err() error { if !isDone(c.doneCh) { return nil } - return c.err + c.errMu.Lock() + err := c.err + c.errMu.Unlock() + return err } func (c *Control) reportErr(err error) { c.errOnce.Do(func() { + c.errMu.Lock() c.err = err + c.errMu.Unlock() close(c.doneCh) }) } diff --git a/pkg/dmsgctrl/serve_listener.go b/pkg/dmsgctrl/serve_listener.go index c1659e42f..4fb2bd6d4 100644 --- a/pkg/dmsgctrl/serve_listener.go +++ b/pkg/dmsgctrl/serve_listener.go @@ -29,8 +29,14 @@ func ServeListener(l net.Listener, chanLen int) <-chan *Control { log.Warnf("Failed to accept dmsgctrl conn, continuing: %v", err) continue } - if ctrl := ControlStream(conn); ch != nil && len(ch) < cap(ch) { - ch <- ctrl + ctrl := ControlStream(conn) + select { + case ch <- ctrl: + default: + log.Warnf("Control channel full, dropping and closing control") + if err := ctrl.Close(); err != nil { + log.Warnf("Failed to close dropped control: %v", err) + } } } }() diff --git a/pkg/dmsgcurl/dmsgcurl.go b/pkg/dmsgcurl/dmsgcurl.go index 785c471d1..2d4618100 100644 --- a/pkg/dmsgcurl/dmsgcurl.go +++ b/pkg/dmsgcurl/dmsgcurl.go @@ -113,7 +113,7 @@ func (dg *DmsgCurl) Run(ctx context.Context, log *logging.Logger, skStr string, httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(ctx, dmsgC)} - for i := 0; i < dg.dlF.Tries; i++ { + for i := 0; dg.dlF.Tries == 0 || i < dg.dlF.Tries; i++ { log.Infof("Download attempt %d/%d ...", i, dg.dlF.Tries) if _, err := file.Seek(0, 0); err != nil { @@ -220,12 +220,17 @@ func (dg *DmsgCurl) StartDmsg(ctx context.Context, log *logging.Logger, pk ciphe func Download(ctx context.Context, log logrus.FieldLogger, httpC *http.Client, w io.Writer, urlStr string, maxSize int64) error { req, err := http.NewRequest(http.MethodGet, urlStr, nil) if err != nil { - log.WithError(err).Fatal("Failed to formulate HTTP request.") + return fmt.Errorf("failed to formulate HTTP request: %w", err) } resp, err := httpC.Do(req) if err != nil { return fmt.Errorf("failed to connect to HTTP server: %w", err) } + defer func() { + if err := resp.Body.Close(); err != nil { + log.WithError(err).Warn("HTTP Response body closed with non-nil error.") + } + }() if maxSize > 0 { if resp.ContentLength > maxSize*1024 { return fmt.Errorf("requested file size is more than allowed size: %d KB > %d KB", (resp.ContentLength / 1024), maxSize) @@ -235,11 +240,6 @@ func Download(ctx context.Context, log logrus.FieldLogger, httpC *http.Client, w if err != nil { return fmt.Errorf("download failed at %d/%dB: %w", n, resp.ContentLength, err) } - defer func() { - if err := resp.Body.Close(); err != nil { - log.WithError(err).Warn("HTTP Response body closed with non-nil error.") - } - }() return nil } diff --git a/pkg/dmsgcurl/progress_writer.go b/pkg/dmsgcurl/progress_writer.go index e104337a5..34da4b05a 100644 --- a/pkg/dmsgcurl/progress_writer.go +++ b/pkg/dmsgcurl/progress_writer.go @@ -19,13 +19,18 @@ func (pw *ProgressWriter) Write(p []byte) (int, error) { current := atomic.AddInt64(&pw.Current, int64(n)) total := atomic.LoadInt64(&pw.Total) - pc := fmt.Sprintf("%d%%", current*100/total) - fmt.Printf("Downloading: %d/%dB (%s)", current, total, pc) - if current != total { + if total <= 0 { + fmt.Printf("Downloading: %dB", current) fmt.Print("\r") } else { - fmt.Print("\n") + pc := fmt.Sprintf("%d%%", current*100/total) + fmt.Printf("Downloading: %d/%dB (%s)", current, total, pc) + if current != total { + fmt.Print("\r") + } else { + fmt.Print("\n") + } } return n, nil diff --git a/pkg/dmsghttp/http.go b/pkg/dmsghttp/http.go index 9f45927df..616fca4dc 100644 --- a/pkg/dmsghttp/http.go +++ b/pkg/dmsghttp/http.go @@ -22,12 +22,6 @@ func ListenAndServe(ctx context.Context, _ cipher.SecKey, a http.Handler, _ disc if err != nil { return fmt.Errorf("dmsg listen on port %d: %w", dmsgPort, err) } - go func() { - <-ctx.Done() - if err := lis.Close(); err != nil { - log.WithError(err).Error() - } - }() log.WithField("dmsg_addr", fmt.Sprintf("dmsg://%v", lis.Addr().String())). Debug("Serving...") @@ -39,5 +33,18 @@ func ListenAndServe(ctx context.Context, _ cipher.SecKey, a http.Handler, _ disc Handler: a, } - return srv.Serve(lis) + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + if err := srv.Shutdown(context.Background()); err != nil { + log.WithError(err).Error() + } + case <-done: + } + }() + + err = srv.Serve(lis) + close(done) + return err } diff --git a/pkg/dmsghttp/util.go b/pkg/dmsghttp/util.go index fdc6fb47f..7f7508fa5 100644 --- a/pkg/dmsghttp/util.go +++ b/pkg/dmsghttp/util.go @@ -31,7 +31,7 @@ func GetServers(ctx context.Context, dmsgDisc string, dmsgServerType string, log if dmsgServerType != "" { var filteredServers []*disc.Entry for _, server := range servers { - if server.Server.ServerType == dmsgServerType { + if server.Server != nil && server.Server.ServerType == dmsgServerType { filteredServers = append(filteredServers, server) } } @@ -69,7 +69,7 @@ func UpdateServers(ctx context.Context, dClient disc.APIClient, dmsgDisc string, if dmsgServerType != "" { var filteredServers []*disc.Entry for _, server := range servers { - if server.Server.ServerType == dmsgServerType { + if server.Server != nil && server.Server.ServerType == dmsgServerType { filteredServers = append(filteredServers, server) } } diff --git a/pkg/dmsgpty/conf.go b/pkg/dmsgpty/conf.go index 61a1a3c58..1d0089a29 100644 --- a/pkg/dmsgpty/conf.go +++ b/pkg/dmsgpty/conf.go @@ -40,12 +40,13 @@ func WriteConfig(conf Config, path string) error { if err != nil { return fmt.Errorf("failed to open config file: %w", err) } + defer f.Close() //nolint:errcheck enc := json.NewEncoder(f) enc.SetIndent("", " ") if err = enc.Encode(&conf); err != nil { return err } - return f.Close() + return nil } func findStringsEnclosedBy(str string, sep string, result []string, lastIndex int) ([]string, int) { @@ -88,7 +89,7 @@ func ParseWindowsEnv(cliAddr string) string { } paths[len(paths)-1] = strings.Replace(cliAddr[lastIndex:], string(filepath.Separator), "", 1) cliAddr = filepath.Join(paths...) - _ = strings.ReplaceAll(cliAddr, `\`, `\\`) + cliAddr = strings.ReplaceAll(cliAddr, `\`, `\\`) return cliAddr } } diff --git a/pkg/dmsgpty/pty_unix.go b/pkg/dmsgpty/pty_unix.go index 02f12142b..1d83485be 100644 --- a/pkg/dmsgpty/pty_unix.go +++ b/pkg/dmsgpty/pty_unix.go @@ -23,6 +23,7 @@ var ( // Pty runs a local pty. type Pty struct { pty *os.File + cmd *exec.Cmd mx sync.RWMutex } @@ -42,6 +43,11 @@ func (s *Pty) Stop() error { err := s.pty.Close() s.pty = nil + // Reap the child process to avoid zombies. + if s.cmd != nil { + _ = s.cmd.Wait() //nolint:errcheck + s.cmd = nil + } return err } @@ -96,6 +102,7 @@ func (s *Pty) Start(name string, args []string, size *WinSize, env []string) err } s.pty = f + s.cmd = cmd return nil } diff --git a/pkg/dmsgpty/pty_windows.go b/pkg/dmsgpty/pty_windows.go index fe2e262a2..e4b889ab7 100644 --- a/pkg/dmsgpty/pty_windows.go +++ b/pkg/dmsgpty/pty_windows.go @@ -106,6 +106,7 @@ func (s *Pty) Start(name string, args []string, size *WinSize, env []string) err ) if err != nil { + pty.Close() //nolint:errcheck return err } diff --git a/pkg/dmsgpty/ui.go b/pkg/dmsgpty/ui.go index 4dcd8bdf6..09741d070 100644 --- a/pkg/dmsgpty/ui.go +++ b/pkg/dmsgpty/ui.go @@ -230,12 +230,6 @@ func writeWSError(log logrus.FieldLogger, wsConn net.Conn, err error) { log.WithError(err).Error("Failed to write error msg to ws conn.") } logWS(wsConn, "Stopped!") - for { - if _, err := wsConn.Write([]byte("\x00")); err != nil { - return - } - time.Sleep(10 * time.Second) - } } func writeError(log logrus.FieldLogger, w http.ResponseWriter, r *http.Request, err error, code int) { @@ -284,6 +278,7 @@ type wsReader struct { ctx *http.Request closed bool mu sync.Mutex + buf []byte // buffered remainder from previous read } func newWSReader(ws *websocket.Conn, ptyC *PtyClient, log logrus.FieldLogger, r *http.Request) *wsReader { @@ -302,6 +297,16 @@ func (wr *wsReader) Read(p []byte) (int, error) { wr.mu.Unlock() return 0, io.EOF } + // Return buffered remainder from a previous read first. + if len(wr.buf) > 0 { + n := copy(p, wr.buf) + wr.buf = wr.buf[n:] + if len(wr.buf) == 0 { + wr.buf = nil + } + wr.mu.Unlock() + return n, nil + } wr.mu.Unlock() msgType, data, err := wr.ws.Read(wr.ctx.Context()) @@ -331,8 +336,13 @@ func (wr *wsReader) Read(p []byte) (int, error) { } } - // Regular data - copy to output buffer + // Regular data - copy to output buffer, save remainder n := copy(p, data) + if n < len(data) { + wr.mu.Lock() + wr.buf = append([]byte(nil), data[n:]...) + wr.mu.Unlock() + } return n, nil } } diff --git a/pkg/dmsgpty/whitelist.go b/pkg/dmsgpty/whitelist.go index a8bcfb4fb..db4af0205 100644 --- a/pkg/dmsgpty/whitelist.go +++ b/pkg/dmsgpty/whitelist.go @@ -18,6 +18,8 @@ import ( var ( json = jsoniter.ConfigFastest wl cipher.PubKeys + // wlMu protects the global wl and conf variables from concurrent access. + wlMu sync.Mutex ) // Whitelist represents a whitelist of public keys. @@ -49,6 +51,9 @@ type configWhitelist struct { } func (w *configWhitelist) Get(pk cipher.PubKey) (bool, error) { + wlMu.Lock() + defer wlMu.Unlock() + var ok bool err := w.open() if err != nil { @@ -63,6 +68,9 @@ func (w *configWhitelist) Get(pk cipher.PubKey) (bool, error) { } func (w *configWhitelist) All() (map[cipher.PubKey]bool, error) { + wlMu.Lock() + defer wlMu.Unlock() + err := w.open() if err != nil { return nil, err @@ -75,6 +83,9 @@ func (w *configWhitelist) All() (map[cipher.PubKey]bool, error) { } func (w *configWhitelist) Add(pks ...cipher.PubKey) error { + wlMu.Lock() + defer wlMu.Unlock() + err := w.open() if err != nil { return err @@ -123,6 +134,9 @@ func (w *configWhitelist) Add(pks ...cipher.PubKey) error { } func (w *configWhitelist) Remove(pks ...cipher.PubKey) error { + wlMu.Lock() + defer wlMu.Unlock() + err := w.open() if err != nil { return err @@ -156,12 +170,19 @@ func (w *configWhitelist) open() error { info, err := os.Stat(w.confPath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - _, err = os.Create(w.confPath) + f, createErr := os.Create(w.confPath) + if createErr != nil { + return createErr + } + f.Close() //nolint:errcheck,gosec + // Re-stat to get the info for the newly created file. + info, err = os.Stat(w.confPath) if err != nil { return err } + } else { + return err } - return err } if info.Size() == 0 { diff --git a/pkg/noise/dh.go b/pkg/noise/dh.go index e627ae68a..6c208f079 100644 --- a/pkg/noise/dh.go +++ b/pkg/noise/dh.go @@ -2,6 +2,7 @@ package noise import ( + "fmt" "io" "github.com/skycoin/noise" @@ -22,22 +23,17 @@ func (Secp256k1) GenerateKeypair(_ io.Reader) (noise.DHKey, error) { // DH helps to implement `noise.DHFunc`. func (Secp256k1) DH(sk, pk []byte) []byte { - // Use non-panic versions to handle invalid keys gracefully pubKey, err := cipher.NewPubKey(pk) if err != nil { - // Return empty key on error to prevent panic - // The handshake will fail with this invalid key - return make([]byte, 33) + panic(fmt.Sprintf("noise DH: invalid public key: %v", err)) } secKey, err := cipher.NewSecKey(sk) if err != nil { - // Return empty key on error to prevent panic - return make([]byte, 33) + panic(fmt.Sprintf("noise DH: invalid secret key: %v", err)) } ecdh, err := cipher.ECDH(pubKey, secKey) if err != nil { - // Return empty key on error to prevent panic - return make([]byte, 33) + panic(fmt.Sprintf("noise DH: ECDH failed: %v", err)) } return append(ecdh, byte(0)) } diff --git a/pkg/noise/net.go b/pkg/noise/net.go index 4080fa0e1..94a0c8a20 100644 --- a/pkg/noise/net.go +++ b/pkg/noise/net.go @@ -110,13 +110,15 @@ func (d *RPCClientDialer) establishConn() error { } ns, err := New(d.pattern, d.config) if err != nil { + conn.Close() //nolint:errcheck,gosec return err } - conn, err = WrapConn(conn, ns, time.Second*5) + wrappedConn, err := WrapConn(conn, ns, time.Second*5) if err != nil { + conn.Close() //nolint:errcheck,gosec return err } - d.conn = conn + d.conn = wrappedConn return nil } @@ -231,6 +233,7 @@ func (ml *Listener) Accept() (net.Conn, error) { rw := NewReadWriter(conn, ns) if err := rw.Handshake(AcceptHandshakeTimeout); err != nil { noiseLogger.WithError(err).Warn("accept: noise handshake failed.") + conn.Close() //nolint:errcheck,gosec continue } noiseLogger.Infoln("accepted:", rw.RemoteStatic()) diff --git a/pkg/noise/noise.go b/pkg/noise/noise.go index eb41da75e..54ba9d29e 100644 --- a/pkg/noise/noise.go +++ b/pkg/noise/noise.go @@ -177,5 +177,10 @@ func (ns *Noise) DecryptWithNonceMap(nm NonceMap, ciphertext []byte) ([]byte, er if _, ok := nm[recvSeq]; ok { return nil, fmt.Errorf("received decryption nonce (%d) is repeated", recvSeq) } - return ns.dec.Cipher().Decrypt(nil, recvSeq, nil, ciphertext[nonceSize:]) + plaintext, err := ns.dec.Cipher().Decrypt(nil, recvSeq, nil, ciphertext[nonceSize:]) + if err != nil { + return nil, err + } + nm[recvSeq] = struct{}{} + return plaintext, nil } diff --git a/pkg/noise/read_writer.go b/pkg/noise/read_writer.go index 397045d7e..c2c82d307 100644 --- a/pkg/noise/read_writer.go +++ b/pkg/noise/read_writer.go @@ -135,8 +135,6 @@ func (rw *ReadWriter) Write(p []byte) (n int, err error) { return 0, err } - p = p[:] - for len(p) > 0 { // Enforce max frame size. wn := len(p) @@ -192,6 +190,13 @@ func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error { case err := <-errCh: return err case <-time.After(hsTimeout): + // Set a past deadline on the underlying connection to unblock the + // handshake goroutine which may be stuck in a Read or Write call. + if conn, ok := rw.origin.(net.Conn); ok { + conn.SetDeadline(time.Now()) //nolint:errcheck,gosec + } + // Drain the goroutine result to avoid a leak. + <-errCh return timeoutError{} } } @@ -302,11 +307,13 @@ func ReadRawFrame(r *bufio.Reader) (p []byte, err error) { if err != nil { return nil, err } + out := make([]byte, prefix) + copy(out, b[prefixSize:]) if _, err := r.Discard(prefixSize + prefix); err != nil { return nil, fmt.Errorf("unexpected error when discarding %d bytes: %w", prefixSize+prefix, err) } - return b[prefixSize:], nil + return out, nil } func isTemp(err error) bool {