diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index b8ff1586..861fa36f 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -13,7 +13,6 @@ import ( "github.com/xtaci/smux" "github.com/skycoin/dmsg/pkg/dmsg/metrics" - "github.com/skycoin/dmsg/pkg/noise" ) const ( @@ -35,7 +34,6 @@ type ServerSession struct { func makeServerSession(m metrics.Metrics, entity *EntityCommon, conn net.Conn) (ServerSession, error) { var sSes ServerSession sSes.SessionCommon = new(SessionCommon) - sSes.nMap = make(noise.NonceMap) if err := sSes.SessionCommon.initServer(entity, conn); err != nil { m.RecordSession(metrics.DeltaFailed) // record failed connection return sSes, err diff --git a/pkg/dmsg/session_common.go b/pkg/dmsg/session_common.go index 1c47a506..25e5b737 100644 --- a/pkg/dmsg/session_common.go +++ b/pkg/dmsg/session_common.go @@ -28,11 +28,11 @@ type SessionCommon struct { netConn net.Conn // underlying net.Conn (TCP connection to the dmsg server) // ys *yamux.Session // ss *smux.Session - sm SessionManager - ns *noise.Noise - nMap noise.NonceMap - rMx sync.Mutex - wMx sync.Mutex + sm SessionManager + ns *noise.Noise + nw *noise.NonceWindow + rMx sync.Mutex + wMx sync.Mutex log logrus.FieldLogger } @@ -86,7 +86,7 @@ func (sc *SessionCommon) initClient(entity *EntityCommon, conn net.Conn, rPK cip sc.rPK = rPK sc.netConn = conn sc.ns = ns - sc.nMap = make(noise.NonceMap) + sc.nw = noise.NewNonceWindow() sc.log = entity.log.WithField("session", ns.RemoteStatic()) return nil } @@ -113,7 +113,7 @@ func (sc *SessionCommon) initServer(entity *EntityCommon, conn net.Conn) error { sc.rPK = ns.RemoteStatic() sc.netConn = conn sc.ns = ns - sc.nMap = make(noise.NonceMap) + sc.nw = noise.NewNonceWindow() sc.log = entity.log.WithField("session", ns.RemoteStatic()) return nil } @@ -144,11 +144,11 @@ func (sc *SessionCommon) readObject(r io.Reader) (SignedObject, error) { } sc.rMx.Lock() - if sc.nMap == nil { + if sc.nw == nil { sc.rMx.Unlock() return nil, ErrSessionClosed } - obj, err := sc.ns.DecryptWithNonceMap(sc.nMap, pb) + obj, err := sc.ns.DecryptWithNonceWindow(sc.nw, pb) sc.rMx.Unlock() return obj, err @@ -192,7 +192,7 @@ func (sc *SessionCommon) Close() error { } sc.sm.mutx.Unlock() sc.rMx.Lock() - sc.nMap = nil + sc.nw = nil sc.rMx.Unlock() return err } diff --git a/pkg/noise/noise.go b/pkg/noise/noise.go index 759131d3..3750aaec 100644 --- a/pkg/noise/noise.go +++ b/pkg/noise/noise.go @@ -172,9 +172,11 @@ func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) { } // NonceMap is a map of used nonces. +// Deprecated: Use NonceWindow instead for bounded memory usage. type NonceMap map[uint64]struct{} // DecryptWithNonceMap is equivalent to DecryptNonce, instead it uses NonceMap to track nonces instead of a counter. +// Deprecated: Use DecryptWithNonceWindow instead. func (ns *Noise) DecryptWithNonceMap(nm NonceMap, ciphertext []byte) ([]byte, error) { if len(ciphertext) < nonceSize { return nil, ErrInvalidCipherText @@ -190,3 +192,21 @@ func (ns *Noise) DecryptWithNonceMap(nm NonceMap, ciphertext []byte) ([]byte, er nm[recvSeq] = struct{}{} return plaintext, nil } + +// DecryptWithNonceWindow decrypts ciphertext using a sliding window for nonce tracking. +// Unlike DecryptWithNonceMap, memory usage is bounded to O(NonceWindowSize) regardless +// of how many messages are decrypted over the session's lifetime. +func (ns *Noise) DecryptWithNonceWindow(nw *NonceWindow, ciphertext []byte) ([]byte, error) { + if len(ciphertext) < nonceSize { + return nil, ErrInvalidCipherText + } + recvSeq := binary.BigEndian.Uint64(ciphertext[:nonceSize]) + if err := nw.Check(recvSeq); err != nil { + return nil, err + } + plaintext, err := ns.dec.Cipher().Decrypt(nil, recvSeq, nil, ciphertext[nonceSize:]) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/pkg/noise/nonce_window.go b/pkg/noise/nonce_window.go new file mode 100644 index 00000000..da035618 --- /dev/null +++ b/pkg/noise/nonce_window.go @@ -0,0 +1,85 @@ +// Package noise pkg/noise/nonce_window.go +package noise + +import "fmt" + +// NonceWindowSize is the number of recent nonces tracked for replay prevention. +// Must be a multiple of 64 for the bitmap implementation. +const NonceWindowSize = 1024 + +// NonceWindow tracks nonces using a sliding window to prevent replay attacks +// without unbounded memory growth. It stores the highest nonce seen and a +// bitmap of the last NonceWindowSize nonces. Nonces older than the window +// are assumed to be replays. +// +// This replaces the unbounded NonceMap which grew forever on long-lived +// sessions (e.g., setup-node handling thousands of streams). +type NonceWindow struct { + maxNonce uint64 + bitmap [NonceWindowSize / 64]uint64 +} + +// NewNonceWindow creates a new NonceWindow. +func NewNonceWindow() *NonceWindow { + return &NonceWindow{} +} + +// Check returns nil if the nonce is valid (not a replay, not too old). +// If valid, the nonce is recorded in the window. +func (nw *NonceWindow) Check(nonce uint64) error { + if nonce == 0 { + return fmt.Errorf("nonce cannot be zero") + } + + if nw.maxNonce == 0 { + // First nonce seen. + nw.maxNonce = nonce + nw.setBit(nonce) + return nil + } + + if nonce > nw.maxNonce { + // New highest nonce — advance the window. + diff := nonce - nw.maxNonce + if diff >= NonceWindowSize { + // Nonce jumped far ahead — clear entire bitmap. + nw.bitmap = [NonceWindowSize / 64]uint64{} + } else { + // Clear the bits that are now outside the window. + for i := nw.maxNonce + 1; i <= nonce; i++ { + nw.clearBit(i) + } + } + nw.maxNonce = nonce + nw.setBit(nonce) + return nil + } + + // Nonce is <= maxNonce. Check if it's within the window. + if nw.maxNonce-nonce >= NonceWindowSize { + return fmt.Errorf("nonce (%d) is too old (window starts at %d)", nonce, nw.maxNonce-NonceWindowSize+1) + } + + // Check for replay. + if nw.hasBit(nonce) { + return fmt.Errorf("received decryption nonce (%d) is repeated", nonce) + } + + nw.setBit(nonce) + return nil +} + +func (nw *NonceWindow) setBit(nonce uint64) { + idx := nonce % NonceWindowSize + nw.bitmap[idx/64] |= 1 << (idx % 64) +} + +func (nw *NonceWindow) clearBit(nonce uint64) { + idx := nonce % NonceWindowSize + nw.bitmap[idx/64] &^= 1 << (idx % 64) +} + +func (nw *NonceWindow) hasBit(nonce uint64) bool { + idx := nonce % NonceWindowSize + return nw.bitmap[idx/64]&(1<<(idx%64)) != 0 +}