From 71f32cb46dfd679d3506c38c61a4974a213475fe Mon Sep 17 00:00:00 2001 From: Moses Narrow <36607567+0pcom@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:16:44 -0500 Subject: [PATCH] Implement ping for smux sessions smux (unlike yamux) has no built-in ping. Implement it using a lightweight stream-level ping protocol: Client side (SessionCommon.Ping): - Opens a temporary smux stream - Writes a 2-byte zero marker [0x00, 0x00] (ping) - Reads 2-byte echo, measures RTT - Closes stream (5s deadline) Server side (serveStream): - Reads first 2 bytes of each new stream - If [0x00, 0x00]: echoes the marker back and closes (ping response) - Otherwise: passes the bytes through to readRequest via MultiReader The [0x00, 0x00] marker is safe because it represents a zero-length object, which cannot occur in normal session traffic (valid SignedObjects always have length > 0). Yamux sessions continue to use the built-in yamux.Ping(). --- pkg/dmsg/server_session.go | 20 +++++++++++++++++++- pkg/dmsg/session_common.go | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index 861fa36f..3ff0c90b 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -2,6 +2,7 @@ package dmsg import ( + "bytes" "fmt" "io" "net" @@ -144,8 +145,25 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCl } } + // Check for ping marker: a 2-byte zero-length prefix [0x00, 0x00]. + // This cannot occur in normal traffic since valid objects always have length > 0. + // Read the first 2 bytes to check before passing to readRequest. + header := make([]byte, 2) + if _, err := io.ReadFull(yStr, header); err != nil { + return err + } + if header[0] == 0 && header[1] == 0 { + // Ping: echo back the marker and close. + _, err := yStr.Write(pingMarker) + return err + } + + // Not a ping — the 2 bytes are the length prefix of a normal object. + // Pass them through to readRequest via a prefixed reader. + prefixedReader := io.MultiReader(bytes.NewReader(header), yStr) + readRequest := func() (StreamRequest, error) { - obj, err := ss.readObject(yStr) + obj, err := ss.readObject(prefixedReader) if err != nil { return StreamRequest{}, err } diff --git a/pkg/dmsg/session_common.go b/pkg/dmsg/session_common.go index 25e5b737..9413f042 100644 --- a/pkg/dmsg/session_common.go +++ b/pkg/dmsg/session_common.go @@ -168,6 +168,11 @@ func (sc *SessionCommon) LocalTCPAddr() net.Addr { return sc.netConn.LocalAddr() // RemoteTCPAddr returns the remote address of the underlying TCP connection. func (sc *SessionCommon) RemoteTCPAddr() net.Addr { return sc.netConn.RemoteAddr() } +// pingMarker is a 2-byte zero-length prefix that cannot occur in normal +// session traffic (valid SignedObjects always have length > 0). Used to +// implement ping over smux streams. +var pingMarker = []byte{0x00, 0x00} + // Ping obtains the round trip latency of the session. func (sc *SessionCommon) Ping() (time.Duration, error) { sc.sm.mutx.RLock() @@ -175,7 +180,34 @@ func (sc *SessionCommon) Ping() (time.Duration, error) { if sc.sm.yamux != nil { return sc.sm.yamux.Ping() } - return 0, fmt.Errorf("Ping not available on SMUX protocol") + if sc.sm.smux != nil { + return sc.smuxPing() + } + return 0, fmt.Errorf("no mux session available for ping") +} + +// smuxPing implements ping over smux by opening a temporary stream, +// writing a ping marker, and waiting for the echo. +func (sc *SessionCommon) smuxPing() (time.Duration, error) { + str, err := sc.sm.smux.OpenStream() + if err != nil { + return 0, fmt.Errorf("smux ping: open stream: %w", err) + } + defer str.Close() //nolint:errcheck + + if err := str.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + return 0, fmt.Errorf("smux ping: set deadline: %w", err) + } + + start := time.Now() + if _, err := str.Write(pingMarker); err != nil { + return 0, fmt.Errorf("smux ping: write: %w", err) + } + resp := make([]byte, 2) + if _, err := io.ReadFull(str, resp); err != nil { + return 0, fmt.Errorf("smux ping: read: %w", err) + } + return time.Since(start), nil } // Close closes the session.