From e10e15fc5134b29a9ce412069090cdcfb6819fe9 Mon Sep 17 00:00:00 2001
From: sirzooro
Date: Tue, 20 Jan 2026 07:41:24 +0100
Subject: [PATCH] Improve handling of received trickled candidates

Implements RFC 8838 section 11.4 (Receiving Trickled Candidates): when a
remote peer-reflexive (prflx) candidate is discovered first, a later
signaled equivalent candidate (host/srflx/relay) replaces it in place
while preserving checklist/pair behavior.

Also sets the priority of a prflx candidate to the value received in the
STUN PRIORITY attribute.
---
 agent.go               |  56 +++++++-
 agent_handlers.go      |  24 ++--
 agent_handlers_test.go |   4 +-
 agent_test.go          | 311 ++++++++++++++++++++++++++++++++++++++++-
 candidate.go           |   6 +
 candidate_base.go      |  22 ++-
 candidatepair.go       |  11 ++
 selection.go           |   8 +-
 8 files changed, 420 insertions(+), 22 deletions(-)

diff --git a/agent.go b/agent.go
index df83c330..cc51e5a2 100644
--- a/agent.go
+++ b/agent.go
@@ -1014,6 +1014,45 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
 	}
 }
 
+// replaceRedundantPeerReflexiveCandidates removes any peer-reflexive candidates
+// from the given set that have the same transport address as cand.
+// It also updates any candidate pairs and local candidate caches that
+// referenced the removed peer-reflexive candidates to reference cand instead.
+// This implements RFC 8838 §11.4.
+// It returns the updated set of candidates.
+func (a *Agent) replaceRedundantPeerReflexiveCandidates(set []Candidate, cand Candidate) []Candidate {
+	if cand.Type() != CandidateTypePeerReflexive {
+		var replacedPrflx []Candidate
+
+		for i := 0; i < len(set); i++ {
+			existing := set[i]
+			if existing.Type() == CandidateTypePeerReflexive && existing.transportAddressEqual(cand) {
+				replacedPrflx = append(replacedPrflx, existing)
+				set = append(set[:i], set[i+1:]...)
+				i--
+			}
+		}
+
+		for _, oldRemote := range replacedPrflx {
+			for _, pair := range a.checklist {
+				if pair.Remote == oldRemote {
+					oldPriority := pair.priority()
+					pair.Remote = cand
+					pair.setPriorityOverride(oldPriority)
+				}
+			}
+
+			for _, locals := range a.localCandidates {
+				for _, local := range locals {
+					local.replaceRemoteCandidateCacheValues(oldRemote, cand)
+				}
+			}
+		}
+	}
+
+	return set
+}
+
 // addRemoteCandidate assumes you are holding the lock (must be execute using a.run).
 func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
 	set := a.remoteCandidates[cand.NetworkType()]
@@ -1024,6 +1063,11 @@ func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
 		}
 	}
 
+	// RFC 8838 §11.4: If a trickled candidate is redundant with an existing
+	// peer-reflexive candidate (same transport address), prefer the signaled
+	// candidate and replace the peer-reflexive one.
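+	// For example: if 172.17.0.3:999 (as in the tests below) was first learned
+	// as a prflx candidate from a Binding Request and is later signaled as a
+	// host candidate, the host candidate takes the prflx candidate's place in
+	// this set, in existing checklist pairs, and in the local candidates'
+	// remote-candidate caches, without resetting pair state or priority.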
+	set = a.replaceRedundantPeerReflexiveCandidates(set, cand)
+
 	acceptRemotePassiveTCPCandidate := false
 	// Assert that TCP4 or TCP6 is a enabled NetworkType locally
 	if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive {
@@ -1044,7 +1088,9 @@ func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
 	if cand.TCPType() != TCPTypePassive {
 		if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok {
 			for _, localCandidate := range localCandidates {
-				a.addPair(localCandidate, cand)
+				if a.findPair(localCandidate, cand) == nil {
+					a.addPair(localCandidate, cand)
+				}
 			}
 		}
 	}
@@ -1487,6 +1533,14 @@ func (a *Agent) handleInboundRequest(
 		RelPort: 0,
 	}
 
+	// A peer-reflexive candidate SHOULD take its priority from the PRIORITY
+	// attribute in the Binding Request that discovered it.
+	var prio PriorityAttr
+	err = prio.GetFrom(msg)
+	if err == nil {
+		prflxCandidateConfig.Priority = uint32(prio)
+	}
+
 	prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
 	if err != nil {
 		a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)
diff --git a/agent_handlers.go b/agent_handlers.go
index 95ee9256..0f96cd77 100644
--- a/agent_handlers.go
+++ b/agent_handlers.go
@@ -48,8 +48,10 @@ func (a *Agent) onConnectionStateChange(s ConnectionState) {
 
 type handlerNotifier struct {
 	sync.Mutex
-	running   bool
-	notifiers sync.WaitGroup
+	runningConnectionStates bool
+	runningCandidates       bool
+	runningCandidatePairs   bool
+	notifiers               sync.WaitGroup
 
 	connectionStates    []ConnectionState
 	connectionStateFunc func(ConnectionState)
@@ -99,7 +101,7 @@ func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) {
 	for {
 		h.Lock()
 		if len(h.connectionStates) == 0 {
-			h.running = false
+			h.runningConnectionStates = false
 			h.Unlock()
 
 			return
@@ -112,8 +114,8 @@ func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) {
 	}
 
 	h.connectionStates = append(h.connectionStates, state)
-	if !h.running {
-		h.running = true
+	if !h.runningConnectionStates {
+		h.runningConnectionStates = true
 		h.notifiers.Add(1)
 		go notify()
 	}
@@ -134,7 +136,7 @@ func (h *handlerNotifier) EnqueueCandidate(cand Candidate) {
 	for {
 		h.Lock()
 		if len(h.candidates) == 0 {
-			h.running = false
+			h.runningCandidates = false
 			h.Unlock()
 
 			return
@@ -147,8 +149,8 @@ func (h *handlerNotifier) EnqueueCandidate(cand Candidate) {
 	}
 
 	h.candidates = append(h.candidates, cand)
-	if !h.running {
-		h.running = true
+	if !h.runningCandidates {
+		h.runningCandidates = true
 		h.notifiers.Add(1)
 		go notify()
 	}
@@ -169,7 +171,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) {
 	for {
 		h.Lock()
 		if len(h.selectedCandidatePairs) == 0 {
-			h.running = false
+			h.runningCandidatePairs = false
 			h.Unlock()
 
 			return
@@ -182,8 +184,8 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) {
 	}
 
 	h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair)
-	if !h.running {
-		h.running = true
+	if !h.runningCandidatePairs {
+		h.runningCandidatePairs = true
 		h.notifiers.Add(1)
 		go notify()
 	}
diff --git a/agent_handlers_test.go b/agent_handlers_test.go
index 0b731597..28b66721 100644
--- a/agent_handlers_test.go
+++ b/agent_handlers_test.go
@@ -116,7 +116,9 @@ func TestHandlerNotifier_Close_AlreadyClosed(t *testing.T) {
 	assert.True(t, isClosed(notifier.done), "expected h.done to remain closed after second Close")
 
 	// sanity: no enqueues should start after close.
-	require.False(t, notifier.running)
+	require.False(t, notifier.runningConnectionStates)
+	require.False(t, notifier.runningCandidates)
+	require.False(t, notifier.runningCandidatePairs)
 	require.Zero(t, len(notifier.connectionStates))
 	require.Zero(t, len(notifier.candidates))
 	require.Zero(t, len(notifier.selectedCandidatePairs))
diff --git a/agent_test.go b/agent_test.go
index 0dd922af..deaa3519 100644
--- a/agent_test.go
+++ b/agent_test.go
@@ -51,7 +51,7 @@ func (r *recordingSelector) HandleBindingRequest(*stun.Message, Candidate, Candi
 	r.handledBindingRequest = true
 }
 
-func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop
+func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop,maintidx
 	defer test.CheckRoutines(t)()
 
 	// Limit runtime in case of deadlocks
@@ -107,6 +107,315 @@ func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop
 		}))
 	})
 
+	t.Run("prflx candidate priority comes from inbound PRIORITY", func(t *testing.T) {
+		agent, err := NewAgent(&AgentConfig{})
+		require.NoError(t, err)
+		defer func() {
+			require.NoError(t, agent.Close())
+		}()
+
+		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
+			agent.selector = &controllingSelector{agent: agent, log: agent.log}
+
+			local, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "192.168.0.2",
+				Port:      777,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			local.conn = &fakenet.MockPacketConn{}
+
+			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
+			remotePriority := uint32(123456)
+
+			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
+				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
+				UseCandidate(),
+				AttrControlling(agent.tieBreaker),
+				PriorityAttr(remotePriority),
+				stun.NewShortTermIntegrity(agent.localPwd),
+				stun.Fingerprint,
+			)
+			require.NoError(t, err)
+
+			// nolint: contextcheck
+			agent.handleInbound(msg, local, remote)
+
+			set := agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+
+			c := set[0]
+			require.Equal(t, CandidateTypePeerReflexive, c.Type())
+			require.Equal(t, remotePriority, c.Priority())
+		}))
+	})
+
+	t.Run("Signaled host candidate replaces existing remote prflx candidate", func(t *testing.T) {
+		agent, err := NewAgent(&AgentConfig{})
+		require.NoError(t, err)
+		defer func() {
+			require.NoError(t, agent.Close())
+		}()
+
+		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
+			agent.selector = &controllingSelector{agent: agent, log: agent.log}
+
+			local, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "192.168.0.2",
+				Port:      777,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			local.conn = &fakenet.MockPacketConn{}
+			agent.localCandidates[local.NetworkType()] = []Candidate{local}
+
+			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
+			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
+				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
+				UseCandidate(),
+				AttrControlling(agent.tieBreaker),
+				PriorityAttr(uint32(99999)),
+				stun.NewShortTermIntegrity(agent.localPwd),
+				stun.Fingerprint,
+			)
+			require.NoError(t, err)
+
+			// nolint: contextcheck
+			agent.handleInbound(msg, local, remote)
+
+			set := agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			prflx := set[0]
+			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())
+			require.Len(t, agent.checklist, 1)
+			pair := agent.checklist[0]
+			require.Equal(t, prflx, pair.Remote)
+
+			local.addRemoteCandidateCache(prflx, remote)
+			oldPriority := pair.priority()
+
+			host, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "172.17.0.3",
+				Port:      999,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			agent.addRemoteCandidate(host) // nolint:contextcheck
+
+			set = agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			require.Equal(t, CandidateTypeHost, set[0].Type())
+			require.Equal(t, host, set[0])
+			require.Equal(t, host, pair.Remote)
+			require.Equal(t, oldPriority, pair.priority())
+			require.Equal(t, host, local.remoteCandidateCaches[toAddrPort(remote)])
+		}))
+	})
+
+	t.Run("Signaled srflx candidate replaces existing remote prflx candidate", func(t *testing.T) { // nolint:dupl
+		agent, err := NewAgent(&AgentConfig{})
+		require.NoError(t, err)
+		defer func() {
+			require.NoError(t, agent.Close())
+		}()
+
+		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
+			agent.selector = &controllingSelector{agent: agent, log: agent.log}
+
+			local, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "192.168.0.2",
+				Port:      777,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			local.conn = &fakenet.MockPacketConn{}
+			agent.localCandidates[local.NetworkType()] = []Candidate{local}
+
+			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
+			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
+				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
+				UseCandidate(),
+				AttrControlling(agent.tieBreaker),
+				PriorityAttr(uint32(99999)),
+				stun.NewShortTermIntegrity(agent.localPwd),
+				stun.Fingerprint,
+			)
+			require.NoError(t, err)
+
+			// nolint: contextcheck
+			agent.handleInbound(msg, local, remote)
+
+			set := agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			prflx := set[0]
+			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())
+			require.Len(t, agent.checklist, 1)
+			pair := agent.checklist[0]
+			require.Equal(t, prflx, pair.Remote)
+
+			local.addRemoteCandidateCache(prflx, remote)
+			oldPriority := pair.priority()
+
+			srflx, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{
+				Network:   "udp",
+				Address:   "172.17.0.3",
+				Port:      999,
+				Component: 1,
+				RelAddr:   "0.0.0.0",
+				RelPort:   0,
+			})
+			require.NoError(t, err)
+			agent.addRemoteCandidate(srflx) // nolint:contextcheck
+
+			set = agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			require.Equal(t, CandidateTypeServerReflexive, set[0].Type())
+			require.Equal(t, srflx, set[0])
+			require.Equal(t, srflx, pair.Remote)
+			require.Equal(t, oldPriority, pair.priority())
+			require.Equal(t, srflx, local.remoteCandidateCaches[toAddrPort(remote)])
+		}))
+	})
+
+	t.Run("Signaled relay candidate replaces existing remote prflx candidate", func(t *testing.T) { // nolint:dupl
+		agent, err := NewAgent(&AgentConfig{})
+		require.NoError(t, err)
+		defer func() {
+			require.NoError(t, agent.Close())
+		}()
+
+		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
+			agent.selector = &controllingSelector{agent: agent, log: agent.log}
+
+			local, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "192.168.0.2",
+				Port:      777,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			local.conn = &fakenet.MockPacketConn{}
+			agent.localCandidates[local.NetworkType()] = []Candidate{local}
+
+			remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
+			msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
+				stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
+				UseCandidate(),
+				AttrControlling(agent.tieBreaker),
+				PriorityAttr(uint32(99999)),
+				stun.NewShortTermIntegrity(agent.localPwd),
+				stun.Fingerprint,
+			)
+			require.NoError(t, err)
+
+			// nolint: contextcheck
+			agent.handleInbound(msg, local, remote)
+
+			set := agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			prflx := set[0]
+			require.Equal(t, CandidateTypePeerReflexive, prflx.Type())
+			require.Len(t, agent.checklist, 1)
+			pair := agent.checklist[0]
+			require.Equal(t, prflx, pair.Remote)
+
+			local.addRemoteCandidateCache(prflx, remote)
+			oldPriority := pair.priority()
+
+			relay, err := NewCandidateRelay(&CandidateRelayConfig{
+				Network:   "udp",
+				Address:   "172.17.0.3",
+				Port:      999,
+				Component: 1,
+				RelAddr:   "0.0.0.0",
+				RelPort:   0,
+			})
+			require.NoError(t, err)
+			agent.addRemoteCandidate(relay) // nolint:contextcheck
+
+			set = agent.remoteCandidates[local.NetworkType()]
+			require.Len(t, set, 1)
+			require.Equal(t, CandidateTypeRelay, set[0].Type())
+			require.Equal(t, relay, set[0])
+			require.Equal(t, relay, pair.Remote)
+			require.Equal(t, oldPriority, pair.priority())
+			require.Equal(t, relay, local.remoteCandidateCaches[toAddrPort(remote)])
+		}))
+	})
+
+	t.Run("AcceptanceMinWait: prflx not accepted until replaced", func(t *testing.T) {
+		agent, err := NewAgent(&AgentConfig{})
+		require.NoError(t, err)
+		defer func() {
+			require.NoError(t, agent.Close())
+		}()
+
+		require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
+			agent.isControlling.Store(true)
+			agent.remoteUfrag = "remoteUfrag"
+			agent.remotePwd = "remotePwd"
+			agent.hostAcceptanceMinWait = 0
+			agent.prflxAcceptanceMinWait = time.Hour
+
+			local, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "192.168.0.2",
+				Port:      777,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			local.conn = &fakenet.MockPacketConn{}
+			agent.localCandidates[local.NetworkType()] = []Candidate{local}
+
+			prflx, err := NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{
+				Network:   "udp",
+				Address:   "1.2.3.4",
+				Port:      999,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			agent.addRemoteCandidate(prflx) // nolint:contextcheck
+
+			pair := agent.findPair(local, prflx)
+			require.NotNil(t, pair)
+			require.Equal(t, CandidateTypeHost, pair.Local.Type())
+			require.Equal(t, CandidateTypePeerReflexive, pair.Remote.Type())
+			pair.state = CandidatePairStateSucceeded
+
+			sel := &controllingSelector{agent: agent, log: agent.log}
+			sel.Start()
+
+			// With prflxAcceptanceMinWait set high, remote prflx candidate should not be nominatable.
+			sel.ContactCandidates()
+			require.Nil(t, sel.nominatedPair)
+			require.False(t, pair.nominated)
+
+			// Trickle the signaled candidate for the same transport address.
+			signaled, err := NewCandidateHost(&CandidateHostConfig{
+				Network:   "udp",
+				Address:   "1.2.3.4",
+				Port:      999,
+				Component: 1,
+			})
+			require.NoError(t, err)
+			agent.addRemoteCandidate(signaled) // nolint:contextcheck
+			require.Equal(t, signaled, pair.Remote)
+			require.Equal(t, CandidateTypeHost, pair.Remote.Type())
+
+			// Now the (updated) pair should be nominatable and become nominated.
+			sel.ContactCandidates()
+			require.NotNil(t, sel.nominatedPair)
+			require.Same(t, pair, sel.nominatedPair)
+			require.Equal(t, CandidateTypeHost, sel.nominatedPair.Local.Type())
+			require.Equal(t, CandidateTypeHost, sel.nominatedPair.Remote.Type())
+			require.True(t, pair.nominated)
+		}))
+	})
+
 	t.Run("Bad network type with handleInbound()", func(t *testing.T) {
 		agent, err := NewAgent(&AgentConfig{})
 		require.NoError(t, err)
diff --git a/candidate.go b/candidate.go
index dafe07ca..362d925f 100644
--- a/candidate.go
+++ b/candidate.go
@@ -82,6 +82,10 @@ type Candidate interface {
 
 	Marshal() string
 
+	// transportAddressEqual reports whether the transport address (IP, Port, NetworkType, TCPType) equals
+	// that of another candidate.
+	transportAddressEqual(other Candidate) bool
+
 	addr() net.Addr
 	filterForLocationTracking() bool
 	agent() *Agent
@@ -92,4 +96,6 @@ type Candidate interface {
 	seen(outbound bool)
 	start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{})
 	writeTo(raw []byte, dst Candidate) (int, error)
+
+	replaceRemoteCandidateCacheValues(oldRemote, newRemote Candidate)
 }
diff --git a/candidate_base.go b/candidate_base.go
index 6ddffed7..f41f8ee1 100644
--- a/candidate_base.go
+++ b/candidate_base.go
@@ -301,6 +301,14 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net
 	c.remoteCandidateCaches[toAddrPort(srcAddr)] = candidate
 }
 
+func (c *candidateBase) replaceRemoteCandidateCacheValues(oldRemote, newRemote Candidate) {
+	for k, v := range c.remoteCandidateCaches {
+		if v == oldRemote {
+			c.remoteCandidateCaches[k] = newRemote
+		}
+	}
+}
+
 func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
 	agent := c.agent()
 
@@ -445,8 +453,9 @@ func (c *candidateBase) Priority() uint32 {
 		(1<<0)*uint32(256-c.Component())
 }
 
-// Equal is used to compare two candidateBases.
-func (c *candidateBase) Equal(other Candidate) bool {
+// transportAddressEqual reports whether the transport address (IP, Port, NetworkType, TCPType) equals
+// that of another candidate.
+func (c *candidateBase) transportAddressEqual(other Candidate) bool {
 	if c.addr() != other.addr() {
 		if c.addr() == nil || other.addr() == nil {
 			return false
@@ -457,10 +466,15 @@ func (c *candidateBase) Equal(other Candidate) bool {
 	}
 
 	return c.NetworkType() == other.NetworkType() &&
-		c.Type() == other.Type() &&
 		c.Address() == other.Address() &&
 		c.Port() == other.Port() &&
-		c.TCPType() == other.TCPType() &&
+		c.TCPType() == other.TCPType()
+}
+
+// Equal is used to compare two candidateBases.
+func (c *candidateBase) Equal(other Candidate) bool {
+	return c.transportAddressEqual(other) &&
+		c.Type() == other.Type() &&
 		c.RelatedAddress().Equal(other.RelatedAddress())
 }
diff --git a/candidatepair.go b/candidatepair.go
index c0a7dafa..65378b0f 100644
--- a/candidatepair.go
+++ b/candidatepair.go
@@ -26,6 +26,8 @@ type CandidatePair struct {
 	iceRoleControlling  bool
 	Remote              Candidate
 	Local               Candidate
+	priorityOverride    uint64
+	hasPriorityOverride bool
 	bindingRequestCount uint16
 	state               CandidatePairState
 	nominated           bool
@@ -84,12 +86,21 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
 	return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote)
 }
 
+func (p *CandidatePair) setPriorityOverride(prio uint64) {
+	p.priorityOverride = prio
+	p.hasPriorityOverride = true
+}
+
 // RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs
 // Let G be the priority for the candidate provided by the controlling
 // agent. Let D be the priority for the candidate provided by the
 // controlled agent.
 // pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0).
 func (p *CandidatePair) priority() uint64 {
+	if p.hasPriorityOverride {
+		return p.priorityOverride
+	}
+
 	var g, d uint32 //nolint:varnamelen // clearer to use g and d here
 	if p.iceRoleControlling {
 		g = p.Local.Priority()
diff --git a/selection.go b/selection.go
index 78e86d7b..2f783a4d 100644
--- a/selection.go
+++ b/selection.go
@@ -34,13 +34,13 @@ func (s *controllingSelector) Start() {
 func (s *controllingSelector) isNominatable(c Candidate) bool {
 	switch {
 	case c.Type() == CandidateTypeHost:
-		return time.Since(s.startTime).Nanoseconds() > s.agent.hostAcceptanceMinWait.Nanoseconds()
+		return time.Since(s.startTime).Nanoseconds() >= s.agent.hostAcceptanceMinWait.Nanoseconds()
 	case c.Type() == CandidateTypeServerReflexive:
-		return time.Since(s.startTime).Nanoseconds() > s.agent.srflxAcceptanceMinWait.Nanoseconds()
+		return time.Since(s.startTime).Nanoseconds() >= s.agent.srflxAcceptanceMinWait.Nanoseconds()
 	case c.Type() == CandidateTypePeerReflexive:
-		return time.Since(s.startTime).Nanoseconds() > s.agent.prflxAcceptanceMinWait.Nanoseconds()
+		return time.Since(s.startTime).Nanoseconds() >= s.agent.prflxAcceptanceMinWait.Nanoseconds()
 	case c.Type() == CandidateTypeRelay:
-		return time.Since(s.startTime).Nanoseconds() > s.agent.relayAcceptanceMinWait.Nanoseconds()
+		return time.Since(s.startTime).Nanoseconds() >= s.agent.relayAcceptanceMinWait.Nanoseconds()
 	}
 
 	s.log.Errorf("Invalid candidate type: %s", c.Type())
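
Reviewer note (not part of the patch): a minimal, self-contained Go sketch of the RFC 5245 §5.7.2 pair-priority formula quoted in candidatepair.go, illustrating why setPriorityOverride pins the previous value — recomputing the pair priority from the replacing candidate's (usually different) priority could reorder the checklist mid-session. All priority values below are illustrative.

package main

import "fmt"

// pairPriority computes RFC 5245 §5.7.2: 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0),
// where G is the controlling agent's candidate priority and D the controlled one's.
func pairPriority(g, d uint32) uint64 {
	minP, maxP, bump := uint64(g), uint64(d), uint64(0)
	if g > d {
		minP, maxP, bump = uint64(d), uint64(g), 1
	}

	return (1<<32)*minP + 2*maxP + bump
}

func main() {
	const g = uint32(2130706431) // hypothetical local host candidate priority

	// A prflx candidate whose priority came from the STUN PRIORITY attribute...
	before := pairPriority(g, 123456)
	// ...versus the signaled host candidate that later replaces it.
	after := pairPriority(g, 2130706431)

	fmt.Println(before, after) // different values; the override keeps `before`
}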