Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 234 additions & 3 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import (
"context"
"fmt"
"hash/crc32"
"math"
"net"
"net/netip"
"slices"
"strings"
"sync"
"sync/atomic"
Expand All @@ -35,6 +37,34 @@
destination net.Addr
isUseCandidate bool
nominationValue *uint32 // Tracks nomination value for renomination requests
// TODO: having a callback or the original request would be useful for SPED

Check failure on line 40 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

Line contains TODO/BUG/FIXME: "TODO: having a callback or the original ..." (godox)
// so that the response can do the "implicit" ack of the packet in the request
}

type packetWithCrc struct {
data []byte
crc uint32
}

type piggybackingState int

const (
PiggybackingStateTentative = iota
PiggybackingStateConfirmed
PiggybackingStatePending
PiggybackingStateComplete
PiggybackingStateOff
)

// DTLS-in-STUN controller.
type piggybackingController struct {
mu sync.Mutex
state piggybackingState
packets []packetWithCrc
packetsIndex int
acks []uint32
dtlsCallback func(packet []byte, rAddr net.Addr)
newFlight bool
}

// Agent represents the ICE agent.
Expand Down Expand Up @@ -175,6 +205,8 @@
lastRenominationTime time.Time

turnClientFactory func(*turn.ClientConfig) (turnClient, error)

piggyback piggybackingController
}

// NewAgent creates a new Agent.
Expand Down Expand Up @@ -218,6 +250,13 @@
agent.addressRewriteRules = rules
}

// Embedding DTLS in STUN. This is off by default and enabled
// by the use of `SetDtlsCallback`.
agent.piggyback.mu.Lock()
agent.piggyback.acks = []uint32{}
agent.piggyback.state = PiggybackingStateOff
agent.piggyback.mu.Unlock()

return newAgentWithConfig(agent, opts...)
}

Expand Down Expand Up @@ -676,6 +715,22 @@
a.deleteAllCandidates()
}

var packetsToFlush []packetWithCrc
a.piggyback.mu.Lock()
if newState == ConnectionStateConnected && a.piggyback.state == PiggybackingStateOff {
// Piggybacking was discovered as not supported.
// Flush any pending DTLS packets.
packetsToFlush = a.piggyback.packets
a.piggyback.packets = []packetWithCrc{}
}
a.piggyback.mu.Unlock()

if pair := a.getSelectedPair(); pair != nil && len(packetsToFlush) > 0 {
for _, p := range packetsToFlush {
_, _ = pair.Write(p.data)
}
}

a.log.Infof("Setting new connection state: %s", newState)
a.connectionState = newState
a.connectionStateNotifier.EnqueueConnectionState(newState)
Expand Down Expand Up @@ -1304,14 +1359,27 @@
return
}

if out, err := stun.Build(m, stun.BindingSuccess,
attributes := []stun.Setter{
m,
stun.BindingSuccess,
&stun.XORMappedAddress{
IP: ip.AsSlice(),
Port: port,
},
}
if packet, acks := a.GetPiggybackDataAndAcks(); acks != nil {
if acks != nil {
attributes = append(attributes, DtlsInStunAckAttribute(acks))
}
if packet != nil {
attributes = append(attributes, DtlsInStunAttribute(packet))
}
}
attributes = append(attributes,
stun.NewShortTermIntegrity(a.localPwd),
stun.Fingerprint,
); err != nil {
stun.Fingerprint)

if out, err := stun.Build(attributes...); err != nil {
a.log.Warnf("Failed to handle inbound ICE from: %s to: %s error: %s", local, remote, err)
} else {
if pair := a.findPair(local, remote); pair != nil {
Expand Down Expand Up @@ -1557,6 +1625,161 @@
return nil
}

// SetDtlsCallback sets the callback for DTLS packets. Setting this callback
// initializes state of the piggybacking state machine to "tentative", i.e.
// expecting embedded packets.
func (a *Agent) SetDtlsCallback(cb func(packet []byte, rAddr net.Addr)) {
a.piggyback.mu.Lock()
defer a.piggyback.mu.Unlock()
a.piggyback.dtlsCallback = cb
if cb != nil {
a.piggyback.state = PiggybackingStateTentative
}
}

// Piggyback stores a packet to be picked in a round-robin fashion.
// Returns `true` if packet is to be consumed.
func (a *Agent) Piggyback(packet []byte, end bool) bool {
a.piggyback.mu.Lock()
defer a.piggyback.mu.Unlock()
if a.piggyback.state == PiggybackingStateOff {
// TODO: ѕhould we store the packet for later so we

Check failure on line 1646 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

Line contains TODO/BUG/FIXME: "TODO: ѕhould we store the packet for lat..." (godox)
// can send it when the connection gets established?
return a.connectionState != ConnectionStateConnected
}

if packet != nil {
// If we receive a packet after the end of a flight we need
// to clear the outgoing list.
if a.piggyback.newFlight {
a.piggyback.packets = []packetWithCrc{}
}
a.piggyback.newFlight = end
crc := crc32.ChecksumIEEE(packet)
a.piggyback.packets = append(a.piggyback.packets, packetWithCrc{packet, crc})
} else {
a.piggyback.state = PiggybackingStatePending
}
// If we are connected we could send DTLS plain.
return true // a.connectionState == ConnectionStateConnected
}

// GetPiggybackDataAndAcks returns a packet from the stored list in a round-robin fashion and a list of acks.
func (a *Agent) GetPiggybackDataAndAcks() ([]byte, []uint32) {
a.piggyback.mu.Lock()
defer a.piggyback.mu.Unlock()

if a.piggyback.state == PiggybackingStateOff || a.piggyback.state == PiggybackingStateComplete {
return nil, nil
}
if len(a.piggyback.packets) == 0 {
return nil, a.piggyback.acks
}

packet := a.piggyback.packets[a.piggyback.packetsIndex]
a.piggyback.packetsIndex = (a.piggyback.packetsIndex + 1) % len(a.piggyback.packets)

// Return a copy to prevent external modification of the internal buffer
result := make([]byte, len(packet.data))
copy(result, packet.data)

return result, a.piggyback.acks
}

func (a *Agent) ReportPiggybacking(packet []byte, acks []uint32, rAddr net.Addr) { //nolint:cyclop
a.piggyback.mu.Lock()

if a.piggyback.state == PiggybackingStateComplete || a.piggyback.state == PiggybackingStateOff {
a.piggyback.mu.Unlock()

return
}
if packet == nil && acks == nil && a.piggyback.state == PiggybackingStateTentative {
// Any pending packets will be flushed later when the ICE connection gets established.
a.log.Infof("Piggybacking discovered as not supported, falling back to normal state")
a.piggyback.dtlsCallback = nil
a.piggyback.state = PiggybackingStateOff
a.piggyback.mu.Unlock()

return
}
if packet == nil && acks == nil && a.piggyback.acks != nil {
a.log.Infof("Done with the SPED handshake", a.piggyback.state)
// TODO: check that we are in pending state?

Check failure on line 1708 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

Line contains TODO/BUG/FIXME: "TODO: check that we are in pending state..." (godox)
a.piggyback.acks = nil
a.piggyback.state = PiggybackingStateComplete
a.piggyback.mu.Unlock()

return
}
if a.piggyback.state == PiggybackingStateTentative {
a.piggyback.state = PiggybackingStateConfirmed
}
// Handle incoming acks.
if size := len(acks); size > 0 {
beforeLen := len(a.piggyback.packets)
a.piggyback.packets = slices.DeleteFunc(a.piggyback.packets, func(p packetWithCrc) bool {
for _, ackCrc := range acks {
if p.crc == ackCrc {
return true // This packet is acknowledged, so remove it.
}
}

return false // This packet is not acknowledged, so keep it.
})
removed := beforeLen - len(a.piggyback.packets)

// Adjust the index if it's out of bounds after deletion
// TODO: for fairness one should only adjust if the index was affected?

Check failure on line 1733 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

Line contains TODO/BUG/FIXME: "TODO: for fairness one should only adjus..." (godox)
if a.piggyback.packetsIndex >= removed {
a.piggyback.packetsIndex -= removed
} else {
a.piggyback.packetsIndex = 0
}
}
if len(packet) == 0 {
a.piggyback.acks = []uint32{}
}

var dtlsCallback func(packet []byte, rAddr net.Addr)
// Handle the incoming packet. Calculate and store the crc32 of the packet
// for acks, then notify the DTLS packet.
if a.piggyback.dtlsCallback != nil && len(packet) > 0 {
crc := crc32.ChecksumIEEE(packet)
if !slices.Contains(a.piggyback.acks, crc) {
a.piggyback.acks = append(a.piggyback.acks, crc)
if len(a.piggyback.acks) > 4 {
a.piggyback.acks = a.piggyback.acks[1:]
}
}
dtlsCallback = a.piggyback.dtlsCallback
}

a.piggyback.mu.Unlock()

if dtlsCallback != nil {
dtlsCallback(packet, rAddr)
}
}

func (a *Agent) ReportDtlsPacket(packet []byte) {
a.piggyback.mu.Lock()

if a.piggyback.state == PiggybackingStateComplete || a.piggyback.state == PiggybackingStateOff {
a.piggyback.mu.Unlock()

return
}
crc := crc32.ChecksumIEEE(packet)
if !slices.Contains(a.piggyback.acks, crc) {
a.piggyback.acks = append(a.piggyback.acks, crc)
if len(a.piggyback.acks) > 4 {
a.piggyback.acks = a.piggyback.acks[1:]
}
}
a.piggyback.mu.Unlock()
}

func (a *Agent) closeMulticastConn() {
if a.mDNSConn != nil {
if err := a.mDNSConn.Close(); err != nil {
Expand Down Expand Up @@ -1770,6 +1993,14 @@
a.log.Tracef("Sending renomination request from %s to %s with nomination value %d",
pair.Local, pair.Remote, nominationValue)
}
if packet, acks := a.GetPiggybackDataAndAcks(); acks != nil {
if acks != nil {
attributes = append(attributes, DtlsInStunAckAttribute(acks))
}
if packet != nil {
attributes = append(attributes, DtlsInStunAttribute(packet))
}
}

msg, err := stun.Build(append([]stun.Setter{stun.BindingRequest}, attributes...)...)
if err != nil {
Expand Down
57 changes: 57 additions & 0 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2593,3 +2593,60 @@ func TestAgentUpdateOptions(t *testing.T) {
}
})
}

func TestSnap(t *testing.T) {
defer test.CheckRoutines(t)()

t.Run("Basic embedding", func(t *testing.T) {
aNotifier, aConnected := onConnected()
aAgent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
})
require.NoError(t, err)
require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

var toA string
fromA := "Hello from A"
aAgent.SetDtlsCallback(func(packet []byte, rAddr net.Addr) {
toA = string(packet)
})
require.True(t, aAgent.Piggyback([]byte(fromA), true))

bNotifier, bConnected := onConnected()
bAgent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
})
require.NoError(t, err)
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))

var toB string
fromB := "Hello from B"
bAgent.SetDtlsCallback(func(packet []byte, rAddr net.Addr) {
toB = string(packet)
})
require.True(t, bAgent.Piggyback([]byte(fromB), true))

gatherAndExchangeCandidates(t, aAgent, bAgent)
go func() {
bUfrag, bPwd, err := bAgent.GetLocalUserCredentials()
require.NoError(t, err)
_, err = aAgent.Accept(context.TODO(), bUfrag, bPwd)
require.NoError(t, err)
}()

go func() {
aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
require.NoError(t, err)
_, err = bAgent.Dial(context.TODO(), aUfrag, aPwd)
require.NoError(t, err)
}()

<-aConnected
<-bConnected
require.NoError(t, aAgent.Close())
require.NoError(t, bAgent.Close())

require.Equal(t, toA, fromB)
require.Equal(t, toB, fromA)
})
}
Loading
Loading