diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index 861fa36f..3ff0c90b 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -2,6 +2,7 @@ package dmsg import ( + "bytes" "fmt" "io" "net" @@ -144,8 +145,25 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCl } } + // Check for ping marker: a 2-byte zero-length prefix [0x00, 0x00]. + // This cannot occur in normal traffic since valid objects always have length > 0. + // Read the first 2 bytes to check before passing to readRequest. + header := make([]byte, 2) + if _, err := io.ReadFull(yStr, header); err != nil { + return err + } + if header[0] == 0 && header[1] == 0 { + // Ping: echo back the marker and close. + _, err := yStr.Write(pingMarker) + return err + } + + // Not a ping — the 2 bytes are the length prefix of a normal object. + // Pass them through to readRequest via a prefixed reader. + prefixedReader := io.MultiReader(bytes.NewReader(header), yStr) + readRequest := func() (StreamRequest, error) { - obj, err := ss.readObject(yStr) + obj, err := ss.readObject(prefixedReader) if err != nil { return StreamRequest{}, err } diff --git a/pkg/dmsg/session_common.go b/pkg/dmsg/session_common.go index 25e5b737..9413f042 100644 --- a/pkg/dmsg/session_common.go +++ b/pkg/dmsg/session_common.go @@ -168,6 +168,11 @@ func (sc *SessionCommon) LocalTCPAddr() net.Addr { return sc.netConn.LocalAddr() // RemoteTCPAddr returns the remote address of the underlying TCP connection. func (sc *SessionCommon) RemoteTCPAddr() net.Addr { return sc.netConn.RemoteAddr() } +// pingMarker is a 2-byte zero-length prefix that cannot occur in normal +// session traffic (valid SignedObjects always have length > 0). Used to +// implement ping over smux streams. +var pingMarker = []byte{0x00, 0x00} + // Ping obtains the round trip latency of the session. func (sc *SessionCommon) Ping() (time.Duration, error) { sc.sm.mutx.RLock() @@ -175,7 +180,34 @@ func (sc *SessionCommon) Ping() (time.Duration, error) { if sc.sm.yamux != nil { return sc.sm.yamux.Ping() } - return 0, fmt.Errorf("Ping not available on SMUX protocol") + if sc.sm.smux != nil { + return sc.smuxPing() + } + return 0, fmt.Errorf("no mux session available for ping") +} + +// smuxPing implements ping over smux by opening a temporary stream, +// writing a ping marker, and waiting for the echo. +func (sc *SessionCommon) smuxPing() (time.Duration, error) { + str, err := sc.sm.smux.OpenStream() + if err != nil { + return 0, fmt.Errorf("smux ping: open stream: %w", err) + } + defer str.Close() //nolint:errcheck + + if err := str.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + return 0, fmt.Errorf("smux ping: set deadline: %w", err) + } + + start := time.Now() + if _, err := str.Write(pingMarker); err != nil { + return 0, fmt.Errorf("smux ping: write: %w", err) + } + resp := make([]byte, 2) + if _, err := io.ReadFull(str, resp); err != nil { + return 0, fmt.Errorf("smux ping: read: %w", err) + } + return time.Since(start), nil } // Close closes the session.