diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 28236d5c..12054cbd 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -184,6 +184,7 @@ func (ce *Client) Serve(ctx context.Context) { updateEntryLoopOnce := new(sync.Once) pingLoopOnce := new(sync.Once) + reconnectLoopOnce := new(sync.Once) needInitialPost := true @@ -306,10 +307,13 @@ func (ce *Client) Serve(ctx context.Context) { } ce.sesMx.Unlock() } + // Only backoff after all servers have been tried + ce.log.WithField("current_backoff", ce.bo.String()). + Warn("All servers failed, backing off.") + ce.serveWait() } - ce.log.WithField("remote_pk", entry.Static).WithError(err).WithField("current_backoff", ce.bo.String()). + ce.log.WithField("remote_pk", entry.Static).WithError(err). Warn("Failed to establish session.") - ce.serveWait() } else { // Reset backoff on successful session establishment. ce.bo = ce.initBO @@ -319,6 +323,11 @@ func (ce *Client) Serve(ctx context.Context) { // Only start the update entry loop once we have at least one session established. updateEntryLoopOnce.Do(func() { go ce.updateClientEntryLoop(cancellabelCtx, ce.done, ce.conf.ClientType) }) pingLoopOnce.Do(func() { go ce.pingSessionsLoop(cancellabelCtx) }) + // When MinSessions is 0 (connect to all), start a reconnect loop that + // aggressively retries connecting to servers we failed to reach on the first pass. + if ce.conf.MinSessions == 0 { + reconnectLoopOnce.Do(func() { go ce.reconnectLoop(cancellabelCtx) }) + } // We dial all servers and wait for error or done signal. select { @@ -467,6 +476,56 @@ func (ce *Client) setCachedEntry(pk cipher.PubKey, entry *disc.Entry) { ce.entryCacheMx.Unlock() } +// reconnectLoop periodically discovers all available servers and attempts to +// connect to any that don't have an active session. This ensures services using +// MinSessions=0 (connect to all) maintain sessions to all servers, even if some +// were unavailable during initial startup. +func (ce *Client) reconnectLoop(ctx context.Context) { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ce.done: + return + case <-ticker.C: + ce.reconnectMissing(ctx) + } + } +} + +func (ce *Client) reconnectMissing(ctx context.Context) { + entries, err := ce.discoverServers(ctx, false) + if err != nil { + return + } + for _, entry := range entries { + if isClosed(ce.done) { + return + } + // Skip servers we already have sessions with + if _, ok := ce.session(entry.Static); ok { + continue + } + // Filter by server type if configured + if ce.conf.ConnectedServersType == "official" && entry.Server.ServerType != "official" { + continue + } + if ce.conf.ConnectedServersType == "community" && entry.Server.ServerType != "community" { + continue + } + ce.log.WithField("remote_pk", entry.Static).Debug("Reconnecting to missing server...") + if err := ce.EnsureSession(ctx, entry); err != nil { + ce.log.WithField("remote_pk", entry.Static).WithError(err). + Debug("Reconnect failed, will retry next cycle.") + } else { + ce.log.WithField("remote_pk", entry.Static).Info("Reconnected to server.") + } + } +} + // pingSessionsLoop periodically pings all sessions to measure latency. func (ce *Client) pingSessionsLoop(ctx context.Context) { ticker := time.NewTicker(30 * time.Second) diff --git a/pkg/dmsg/client_sessions.go b/pkg/dmsg/client_sessions.go index 1c8ae358..d6b6ccd7 100644 --- a/pkg/dmsg/client_sessions.go +++ b/pkg/dmsg/client_sessions.go @@ -81,7 +81,7 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs Client return ClientSession{}, fmt.Errorf("failed to dial through SOCKS5 proxy: %w", err) } } else { - conn, err = net.Dial(network, entry.Server.Address) + conn, err = net.DialTimeout(network, entry.Server.Address, DialTimeout) if err != nil { return ClientSession{}, fmt.Errorf("failed to dial: %w", err) } diff --git a/pkg/dmsg/const.go b/pkg/dmsg/const.go index 04163e5e..8311bb69 100644 --- a/pkg/dmsg/const.go +++ b/pkg/dmsg/const.go @@ -90,17 +90,21 @@ func InitConfig() error { if err != nil { return err } - err = json.Unmarshal(envServices.Prod, &Prod) - if err != nil { - return err - } - Prod.DmsgServers, err = shuffleServers(Prod.DmsgServers) - if err != nil { - return err + if envServices.Prod != nil { + err = json.Unmarshal(envServices.Prod, &Prod) + if err != nil { + return err + } + Prod.DmsgServers, err = shuffleServers(Prod.DmsgServers) + if err != nil { + return err + } } - err = json.Unmarshal(envServices.Test, &Test) - if err != nil { - return err + if envServices.Test != nil { + err = json.Unmarshal(envServices.Test, &Test) + if err != nil { + return err + } } return nil } diff --git a/pkg/dmsg/types.go b/pkg/dmsg/types.go index f57a9337..dfde5ade 100644 --- a/pkg/dmsg/types.go +++ b/pkg/dmsg/types.go @@ -21,6 +21,10 @@ const ( ) var ( + // DialTimeout defines the duration a TCP dial to a dmsg server should take. + // This prevents blocking for minutes on unresponsive/overloaded servers. + DialTimeout = 10 * time.Second + // HandshakeTimeout defines the duration a stream handshake should take. HandshakeTimeout = time.Second * 20 diff --git a/pkg/dmsgclient/cli.go b/pkg/dmsgclient/cli.go index f9ebfb9b..52e4a9c3 100644 --- a/pkg/dmsgclient/cli.go +++ b/pkg/dmsgclient/cli.go @@ -63,6 +63,13 @@ Default mode of operation is dmsghttp: // InitDmsgWithFlags starts dmsg with flags from the flags package func InitDmsgWithFlags(ctx context.Context, dlog *logging.Logger, pk cipher.PubKey, sk cipher.SecKey, httpClient *http.Client, destination string) (dmsgC *dmsg.Client, stop func(), err error) { + if DmsgServerAddr != "" { + srvEntry, err := ParseServerAddr(DmsgServerAddr) + if err != nil { + return nil, nil, err + } + return StartDmsgDirectWithServers(ctx, dlog, pk, sk, "", []*disc.Entry{srvEntry}, 1, dmsg.ExtractPKFromDmsgAddr(destination)) + } if UseDC { return StartDmsgDirect(ctx, dlog, pk, sk, "", DmsgSessions, dmsg.ExtractPKFromDmsgAddr(destination)) } diff --git a/pkg/dmsgclient/flags.go b/pkg/dmsgclient/flags.go index 004764fb..71ae9ca4 100644 --- a/pkg/dmsgclient/flags.go +++ b/pkg/dmsgclient/flags.go @@ -2,10 +2,14 @@ package dmsgclient import ( + "fmt" "os" + "strings" + "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" "github.com/spf13/cobra" + "github.com/skycoin/dmsg/pkg/disc" "github.com/skycoin/dmsg/pkg/dmsg" ) @@ -27,6 +31,10 @@ var ( // UseDC use dmsg direct client with embedded dmsg server configuration and don't connect to discovery server UseDC = false + + // DmsgServerAddr specifies a specific dmsg server to connect through. + // Format: pk@ip:port (e.g., 02a2d4c3...@139.162.173.101:30082) + DmsgServerAddr string ) // InitFlags is used to set command flags for the above variables @@ -37,6 +45,25 @@ func InitFlags(cmd *cobra.Command) { cmd.Flags().StringVarP(&DmsgDiscAddr, "disc-addr", "A", DmsgDiscAddr, "DMSG Discovery dmsg address\033[0m\n\r") cmd.Flags().StringVarP(&DmsgHTTPPath, "dmsgconf", "D", "", "dmsghttp-config path") cmd.Flags().IntVarP(&DmsgSessions, "sess", "e", DmsgSessions, "number of DMSG Servers to connect to\033[0m\n\r") + cmd.Flags().StringVarP(&DmsgServerAddr, "srv", "S", "", "connect via specific dmsg server `pk@ip:port`\033[0m\n\r") +} + +// ParseServerAddr parses the --srv flag value into a disc.Entry. +// Format: pk@ip:port +func ParseServerAddr(s string) (*disc.Entry, error) { + parts := strings.SplitN(s, "@", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return nil, fmt.Errorf("invalid server address %q, expected pk@ip:port", s) + } + var pk cipher.PubKey + if err := pk.Set(parts[0]); err != nil { + return nil, fmt.Errorf("invalid server public key: %w", err) + } + return &disc.Entry{ + Version: "0.0.1", + Static: pk, + Server: &disc.Server{Address: parts[1], AvailableSessions: 2048}, + }, nil } // InitConfig is used to set command flags for the above variables