Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pkg/dmsg/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/xtaci/smux"

"github.com/skycoin/dmsg/pkg/dmsg/metrics"
"github.com/skycoin/dmsg/pkg/noise"
)

const (
Expand All @@ -35,7 +34,6 @@ type ServerSession struct {
func makeServerSession(m metrics.Metrics, entity *EntityCommon, conn net.Conn) (ServerSession, error) {
var sSes ServerSession
sSes.SessionCommon = new(SessionCommon)
sSes.nMap = make(noise.NonceMap)
if err := sSes.SessionCommon.initServer(entity, conn); err != nil {
m.RecordSession(metrics.DeltaFailed) // record failed connection
return sSes, err
Expand Down
20 changes: 10 additions & 10 deletions pkg/dmsg/session_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ type SessionCommon struct {
netConn net.Conn // underlying net.Conn (TCP connection to the dmsg server)
// ys *yamux.Session
// ss *smux.Session
sm SessionManager
ns *noise.Noise
nMap noise.NonceMap
rMx sync.Mutex
wMx sync.Mutex
sm SessionManager
ns *noise.Noise
nw *noise.NonceWindow
rMx sync.Mutex
wMx sync.Mutex

log logrus.FieldLogger
}
Expand Down Expand Up @@ -86,7 +86,7 @@ func (sc *SessionCommon) initClient(entity *EntityCommon, conn net.Conn, rPK cip
sc.rPK = rPK
sc.netConn = conn
sc.ns = ns
sc.nMap = make(noise.NonceMap)
sc.nw = noise.NewNonceWindow()
sc.log = entity.log.WithField("session", ns.RemoteStatic())
return nil
}
Expand All @@ -113,7 +113,7 @@ func (sc *SessionCommon) initServer(entity *EntityCommon, conn net.Conn) error {
sc.rPK = ns.RemoteStatic()
sc.netConn = conn
sc.ns = ns
sc.nMap = make(noise.NonceMap)
sc.nw = noise.NewNonceWindow()
sc.log = entity.log.WithField("session", ns.RemoteStatic())
return nil
}
Expand Down Expand Up @@ -144,11 +144,11 @@ func (sc *SessionCommon) readObject(r io.Reader) (SignedObject, error) {
}

sc.rMx.Lock()
if sc.nMap == nil {
if sc.nw == nil {
sc.rMx.Unlock()
return nil, ErrSessionClosed
}
obj, err := sc.ns.DecryptWithNonceMap(sc.nMap, pb)
obj, err := sc.ns.DecryptWithNonceWindow(sc.nw, pb)
sc.rMx.Unlock()

return obj, err
Expand Down Expand Up @@ -192,7 +192,7 @@ func (sc *SessionCommon) Close() error {
}
sc.sm.mutx.Unlock()
sc.rMx.Lock()
sc.nMap = nil
sc.nw = nil
sc.rMx.Unlock()
return err
}
20 changes: 20 additions & 0 deletions pkg/noise/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,11 @@ func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) {
}

// NonceMap is a map of used nonces.
// Deprecated: Use NonceWindow instead for bounded memory usage.
type NonceMap map[uint64]struct{}

// DecryptWithNonceMap is equivalent to DecryptNonce, instead it uses NonceMap to track nonces instead of a counter.
// Deprecated: Use DecryptWithNonceWindow instead.
func (ns *Noise) DecryptWithNonceMap(nm NonceMap, ciphertext []byte) ([]byte, error) {
if len(ciphertext) < nonceSize {
return nil, ErrInvalidCipherText
Expand All @@ -190,3 +192,21 @@ func (ns *Noise) DecryptWithNonceMap(nm NonceMap, ciphertext []byte) ([]byte, er
nm[recvSeq] = struct{}{}
return plaintext, nil
}

// DecryptWithNonceWindow decrypts ciphertext using a sliding window for nonce tracking.
// Unlike DecryptWithNonceMap, memory usage is bounded to O(NonceWindowSize) regardless
// of how many messages are decrypted over the session's lifetime.
func (ns *Noise) DecryptWithNonceWindow(nw *NonceWindow, ciphertext []byte) ([]byte, error) {
if len(ciphertext) < nonceSize {
return nil, ErrInvalidCipherText
}
recvSeq := binary.BigEndian.Uint64(ciphertext[:nonceSize])
if err := nw.Check(recvSeq); err != nil {
return nil, err
}
plaintext, err := ns.dec.Cipher().Decrypt(nil, recvSeq, nil, ciphertext[nonceSize:])
if err != nil {
return nil, err
}
return plaintext, nil
}
85 changes: 85 additions & 0 deletions pkg/noise/nonce_window.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Package noise pkg/noise/nonce_window.go
package noise

import "fmt"

// NonceWindowSize is the number of recent nonces tracked for replay prevention.
// Must be a multiple of 64 for the bitmap implementation.
const NonceWindowSize = 1024

// NonceWindow tracks nonces using a sliding window to prevent replay attacks
// without unbounded memory growth. It stores the highest nonce seen and a
// bitmap of the last NonceWindowSize nonces. Nonces older than the window
// are assumed to be replays.
//
// This replaces the unbounded NonceMap which grew forever on long-lived
// sessions (e.g., setup-node handling thousands of streams).
type NonceWindow struct {
maxNonce uint64
bitmap [NonceWindowSize / 64]uint64
}

// NewNonceWindow creates a new NonceWindow.
func NewNonceWindow() *NonceWindow {
return &NonceWindow{}
}

// Check returns nil if the nonce is valid (not a replay, not too old).
// If valid, the nonce is recorded in the window.
func (nw *NonceWindow) Check(nonce uint64) error {
if nonce == 0 {
return fmt.Errorf("nonce cannot be zero")
}

if nw.maxNonce == 0 {
// First nonce seen.
nw.maxNonce = nonce
nw.setBit(nonce)
return nil
}

if nonce > nw.maxNonce {
// New highest nonce — advance the window.
diff := nonce - nw.maxNonce
if diff >= NonceWindowSize {
// Nonce jumped far ahead — clear entire bitmap.
nw.bitmap = [NonceWindowSize / 64]uint64{}
} else {
// Clear the bits that are now outside the window.
for i := nw.maxNonce + 1; i <= nonce; i++ {
nw.clearBit(i)
}
}
nw.maxNonce = nonce
nw.setBit(nonce)
return nil
}

// Nonce is <= maxNonce. Check if it's within the window.
if nw.maxNonce-nonce >= NonceWindowSize {
return fmt.Errorf("nonce (%d) is too old (window starts at %d)", nonce, nw.maxNonce-NonceWindowSize+1)
}

// Check for replay.
if nw.hasBit(nonce) {
return fmt.Errorf("received decryption nonce (%d) is repeated", nonce)
}

nw.setBit(nonce)
return nil
}

func (nw *NonceWindow) setBit(nonce uint64) {
idx := nonce % NonceWindowSize
nw.bitmap[idx/64] |= 1 << (idx % 64)
}

func (nw *NonceWindow) clearBit(nonce uint64) {
idx := nonce % NonceWindowSize
nw.bitmap[idx/64] &^= 1 << (idx % 64)
}

func (nw *NonceWindow) hasBit(nonce uint64) bool {
idx := nonce % NonceWindowSize
return nw.bitmap[idx/64]&(1<<(idx%64)) != 0
}
Loading