diff --git a/agent.go b/agent.go
index df83c330..1f651579 100644
--- a/agent.go
+++ b/agent.go
@@ -8,9 +8,11 @@ package ice
 import (
 	"context"
 	"fmt"
+	"hash/crc32"
 	"math"
 	"net"
 	"net/netip"
+	"slices"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -35,6 +37,34 @@ type bindingRequest struct {
 	destination     net.Addr
 	isUseCandidate  bool
 	nominationValue *uint32 // Tracks nomination value for renomination requests
+	// TODO: having a callback or the original request would be useful for SPED
+	// so that the response can do the "implicit" ack of the packet in the request
+}
+
+type packetWithCrc struct {
+	data []byte // Packet payload.
+	crc  uint32 // IEEE CRC32 of data; compared against the acks reported by the peer.
+}
+
+type piggybackingState int // State of the DTLS-in-STUN (piggybacking) state machine.
+
+const (
+	PiggybackingStateTentative piggybackingState = iota // Callback installed; peer support not yet known.
+	PiggybackingStateConfirmed                          // Peer sent embedded DTLS data or acks.
+	PiggybackingStatePending                            // Piggyback was called with a nil packet.
+	PiggybackingStateComplete                           // Peer stopped embedding; handshake is done.
+	PiggybackingStateOff                                // Default; nothing is embedded in STUN.
+)
+
+// piggybackingController holds the DTLS-in-STUN state, the outgoing packet queue and the acks; mu guards it all.
+type piggybackingController struct {
+	mu           sync.Mutex
+	state        piggybackingState
+	packets      []packetWithCrc
+	packetsIndex int
+	acks         []uint32
+	dtlsCallback func(packet []byte, rAddr net.Addr)
+	newFlight    bool
 }
 
 // Agent represents the ICE agent.
@@ -175,6 +205,8 @@ type Agent struct {
 	lastRenominationTime time.Time
 
 	turnClientFactory func(*turn.ClientConfig) (turnClient, error)
+
+	piggyback piggybackingController
 }
 
 // NewAgent creates a new Agent.
@@ -218,6 +250,13 @@ func newAgentFromConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error
 		agent.addressRewriteRules = rules
 	}
 
+	// Embedding DTLS in STUN. This is off by default and enabled
+	// by the use of `SetDtlsCallback`.
+	agent.piggyback.mu.Lock()
+	agent.piggyback.acks = []uint32{}
+	agent.piggyback.state = PiggybackingStateOff
+	agent.piggyback.mu.Unlock()
+
 	return newAgentWithConfig(agent, opts...)
} @@ -676,6 +715,22 @@ func (a *Agent) updateConnectionState(newState ConnectionState) { a.deleteAllCandidates() } + var packetsToFlush []packetWithCrc + a.piggyback.mu.Lock() + if newState == ConnectionStateConnected && a.piggyback.state == PiggybackingStateOff { + // Piggybacking was discovered as not supported. + // Flush any pending DTLS packets. + packetsToFlush = a.piggyback.packets + a.piggyback.packets = []packetWithCrc{} + } + a.piggyback.mu.Unlock() + + if pair := a.getSelectedPair(); pair != nil && len(packetsToFlush) > 0 { + for _, p := range packetsToFlush { + _, _ = pair.Write(p.data) + } + } + a.log.Infof("Setting new connection state: %s", newState) a.connectionState = newState a.connectionStateNotifier.EnqueueConnectionState(newState) @@ -1304,14 +1359,27 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { return } - if out, err := stun.Build(m, stun.BindingSuccess, + attributes := []stun.Setter{ + m, + stun.BindingSuccess, &stun.XORMappedAddress{ IP: ip.AsSlice(), Port: port, }, + } + if packet, acks := a.GetPiggybackDataAndAcks(); acks != nil { + if acks != nil { + attributes = append(attributes, DtlsInStunAckAttribute(acks)) + } + if packet != nil { + attributes = append(attributes, DtlsInStunAttribute(packet)) + } + } + attributes = append(attributes, stun.NewShortTermIntegrity(a.localPwd), - stun.Fingerprint, - ); err != nil { + stun.Fingerprint) + + if out, err := stun.Build(attributes...); err != nil { a.log.Warnf("Failed to handle inbound ICE from: %s to: %s error: %s", local, remote, err) } else { if pair := a.findPair(local, remote); pair != nil { @@ -1557,6 +1625,161 @@ func (a *Agent) getSelectedPair() *CandidatePair { return nil } +// SetDtlsCallback sets the callback for DTLS packets. Setting this callback +// initializes state of the piggybacking state machine to "tentative", i.e. +// expecting embedded packets. 
+func (a *Agent) SetDtlsCallback(cb func(packet []byte, rAddr net.Addr)) {
+	a.piggyback.mu.Lock()
+	defer a.piggyback.mu.Unlock()
+	a.piggyback.dtlsCallback = cb
+	if cb != nil {
+		a.piggyback.state = PiggybackingStateTentative
+	}
+}
+
+// Piggyback queues a copy of packet to be picked in a round-robin fashion; `end` marks the last packet of a
+// flight and a nil packet moves the state machine to pending. Returns `true` if packet is to be consumed.
+func (a *Agent) Piggyback(packet []byte, end bool) bool {
+	a.piggyback.mu.Lock()
+	defer a.piggyback.mu.Unlock()
+	if a.piggyback.state == PiggybackingStateOff {
+		// TODO: should we store the packet for later so we
+		// can send it when the connection gets established?
+		return a.connectionState != ConnectionStateConnected
+	}
+
+	if packet != nil {
+		// If we receive a packet after the end of a flight we need to clear the outgoing
+		// list, together with the round-robin cursor that indexes into it.
+		if a.piggyback.newFlight {
+			a.piggyback.packets, a.piggyback.packetsIndex = nil, 0
+		}
+		a.piggyback.newFlight = end
+		crc := crc32.ChecksumIEEE(packet)
+		a.piggyback.packets = append(a.piggyback.packets, packetWithCrc{data: slices.Clone(packet), crc: crc})
+	} else {
+		a.piggyback.state = PiggybackingStatePending
+	}
+	// If we are connected we could send DTLS plain.
+	return true // a.connectionState == ConnectionStateConnected
+}
+
+// GetPiggybackDataAndAcks returns a packet from the stored list in a round-robin fashion and a list of acks.
+func (a *Agent) GetPiggybackDataAndAcks() ([]byte, []uint32) {
+	a.piggyback.mu.Lock()
+	defer a.piggyback.mu.Unlock()
+
+	if a.piggyback.state == PiggybackingStateOff || a.piggyback.state == PiggybackingStateComplete {
+		return nil, nil
+	}
+	if len(a.piggyback.packets) == 0 {
+		return nil, slices.Clone(a.piggyback.acks)
+	}
+
+	// Reduce the cursor modulo the queue length so that a stale index can never go out of range.
+	packet := a.piggyback.packets[a.piggyback.packetsIndex%len(a.piggyback.packets)]
+	a.piggyback.packetsIndex = (a.piggyback.packetsIndex + 1) % len(a.piggyback.packets)
+
+	// Return copies so that callers can neither modify nor race with the internal buffers.
+	return append([]byte{}, packet.data...), slices.Clone(a.piggyback.acks)
+}
+
+// ReportPiggybacking feeds the DTLS packet and acks carried by an inbound STUN message into the piggybacking
+// state machine; either may be nil. A non-empty packet is handed to the DTLS callback outside the lock.
+func (a *Agent) ReportPiggybacking(packet []byte, acks []uint32, rAddr net.Addr) { //nolint:cyclop
+	a.piggyback.mu.Lock()
+
+	if a.piggyback.state == PiggybackingStateComplete || a.piggyback.state == PiggybackingStateOff {
+		a.piggyback.mu.Unlock()
+
+		return
+	}
+	if packet == nil && acks == nil && a.piggyback.state == PiggybackingStateTentative {
+		// Any pending packets will be flushed later when the ICE connection gets established.
+		a.log.Infof("Piggybacking discovered as not supported, falling back to normal state")
+		a.piggyback.dtlsCallback = nil
+		a.piggyback.state = PiggybackingStateOff
+		a.piggyback.mu.Unlock()
+
+		return
+	}
+	if packet == nil && acks == nil && a.piggyback.acks != nil {
+		a.log.Infof("Done with the SPED handshake (state=%d)", a.piggyback.state)
+		// TODO: check that we are in pending state?
+		a.piggyback.acks = nil
+		a.piggyback.state = PiggybackingStateComplete
+		a.piggyback.mu.Unlock()
+
+		return
+	}
+	if a.piggyback.state == PiggybackingStateTentative {
+		a.piggyback.state = PiggybackingStateConfirmed
+	}
+	// Handle incoming acks.
+	if size := len(acks); size > 0 {
+		beforeLen := len(a.piggyback.packets)
+		a.piggyback.packets = slices.DeleteFunc(a.piggyback.packets, func(p packetWithCrc) bool {
+			for _, ackCrc := range acks {
+				if p.crc == ackCrc {
+					return true // This packet is acknowledged, so remove it.
+				}
+			}
+
+			return false // This packet is not acknowledged, so keep it.
+		})
+		removed := beforeLen - len(a.piggyback.packets)
+
+		// Adjust the index if it's out of bounds after deletion
+		// TODO: for fairness one should only adjust if the index was affected?
+		a.piggyback.packetsIndex = max(a.piggyback.packetsIndex-removed, 0)
+		if a.piggyback.packetsIndex >= len(a.piggyback.packets) {
+			// Never leave the round-robin cursor past the end of the shrunken queue.
+			a.piggyback.packetsIndex = 0
+		}
+	}
+	if len(packet) == 0 {
+		a.piggyback.acks = []uint32{}
+	}
+
+	var dtlsCallback func(packet []byte, rAddr net.Addr)
+	// Handle the incoming packet. Calculate and store the crc32 of the packet
+	// for acks (only the 4 most recent are kept), then notify the DTLS callback.
+	if a.piggyback.dtlsCallback != nil && len(packet) > 0 {
+		crc := crc32.ChecksumIEEE(packet)
+		if !slices.Contains(a.piggyback.acks, crc) {
+			a.piggyback.acks = append(a.piggyback.acks, crc)
+			if len(a.piggyback.acks) > 4 {
+				a.piggyback.acks = a.piggyback.acks[1:]
+			}
+		}
+		dtlsCallback = a.piggyback.dtlsCallback
+	}
+
+	a.piggyback.mu.Unlock()
+
+	if dtlsCallback != nil {
+		dtlsCallback(packet, rAddr)
+	}
+}
+
+// ReportDtlsPacket records the CRC32 of packet in the list of acks returned by GetPiggybackDataAndAcks,
+// keeping only the 4 most recent. It does nothing once piggybacking is off or complete.
+func (a *Agent) ReportDtlsPacket(packet []byte) {
+	a.piggyback.mu.Lock()
+	defer a.piggyback.mu.Unlock()
+
+	if a.piggyback.state == PiggybackingStateComplete || a.piggyback.state == PiggybackingStateOff {
+		return
+	}
+	crc := crc32.ChecksumIEEE(packet)
+	if !slices.Contains(a.piggyback.acks, crc) {
+		a.piggyback.acks = append(a.piggyback.acks, crc)
+		if len(a.piggyback.acks) > 4 {
+			a.piggyback.acks = a.piggyback.acks[1:]
+		}
+	}
+}
+
 func (a *Agent) closeMulticastConn() {
 	if a.mDNSConn != nil {
 		if err := a.mDNSConn.Close(); err != nil {
@@ -1770,6 +1993,14 @@ func (a *Agent) sendNominationRequest(pair *CandidatePair, nominationValue
uint3 a.log.Tracef("Sending renomination request from %s to %s with nomination value %d", pair.Local, pair.Remote, nominationValue) } + if packet, acks := a.GetPiggybackDataAndAcks(); acks != nil { + if acks != nil { + attributes = append(attributes, DtlsInStunAckAttribute(acks)) + } + if packet != nil { + attributes = append(attributes, DtlsInStunAttribute(packet)) + } + } msg, err := stun.Build(append([]stun.Setter{stun.BindingRequest}, attributes...)...) if err != nil { diff --git a/agent_test.go b/agent_test.go index 0dd922af..4c2491b0 100644 --- a/agent_test.go +++ b/agent_test.go @@ -2593,3 +2593,60 @@ func TestAgentUpdateOptions(t *testing.T) { } }) } + +func TestSnap(t *testing.T) { + defer test.CheckRoutines(t)() + + t.Run("Basic embedding", func(t *testing.T) { + aNotifier, aConnected := onConnected() + aAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + }) + require.NoError(t, err) + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + + var toA string + fromA := "Hello from A" + aAgent.SetDtlsCallback(func(packet []byte, rAddr net.Addr) { + toA = string(packet) + }) + require.True(t, aAgent.Piggyback([]byte(fromA), true)) + + bNotifier, bConnected := onConnected() + bAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + }) + require.NoError(t, err) + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + + var toB string + fromB := "Hello from B" + bAgent.SetDtlsCallback(func(packet []byte, rAddr net.Addr) { + toB = string(packet) + }) + require.True(t, bAgent.Piggyback([]byte(fromB), true)) + + gatherAndExchangeCandidates(t, aAgent, bAgent) + go func() { + bUfrag, bPwd, err := bAgent.GetLocalUserCredentials() + require.NoError(t, err) + _, err = aAgent.Accept(context.TODO(), bUfrag, bPwd) + require.NoError(t, err) + }() + + go func() { + aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() + require.NoError(t, err) + _, err = bAgent.Dial(context.TODO(), aUfrag, aPwd) + 
require.NoError(t, err) + }() + + <-aConnected + <-bConnected + require.NoError(t, aAgent.Close()) + require.NoError(t, bAgent.Close()) + + require.Equal(t, toA, fromB) + require.Equal(t, toB, fromA) + }) +} diff --git a/selection.go b/selection.go index 78e86d7b..f51090f1 100644 --- a/selection.go +++ b/selection.go @@ -26,6 +26,14 @@ type controllingSelector struct { log logging.LeveledLogger } +func reportPiggybacking(agent *Agent, message *stun.Message, remote Candidate) { + var dtls DtlsInStunAttribute + _ = dtls.GetFrom(message) + var ack DtlsInStunAckAttribute + _ = ack.GetFrom(message) + agent.ReportPiggybacking(dtls, ack, remote.addr()) +} + func (s *controllingSelector) Start() { s.startTime = time.Now() s.nominatedPair = nil @@ -84,14 +92,26 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) { // order to nominate a candidate pair (Section 8.1.1). The controlled // agent MUST NOT include the USE-CANDIDATE attribute in a Binding // request. - msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, - stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), + attributes := []stun.Setter{ + stun.BindingRequest, + stun.TransactionID, + stun.NewUsername(s.agent.remoteUfrag + ":" + s.agent.localUfrag), UseCandidate(), AttrControlling(s.agent.tieBreaker), PriorityAttr(pair.Local.Priority()), + } + if packet, acks := s.agent.GetPiggybackDataAndAcks(); acks != nil { + if acks != nil { + attributes = append(attributes, DtlsInStunAckAttribute(acks)) + } + if packet != nil { + attributes = append(attributes, DtlsInStunAttribute(packet)) + } + } + attributes = append(attributes, stun.NewShortTermIntegrity(s.agent.remotePwd), - stun.Fingerprint, - ) + stun.Fingerprint) + msg, err := stun.Build(attributes...) 
 	if err != nil {
 		s.log.Error(err.Error())
 
@@ -103,6 +123,8 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
 }
 
 func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
+	reportPiggybacking(s.agent, message, remote) // Runs first so the response below can carry fresh acks.
+
 	s.agent.sendBindingSuccess(message, local, remote)
 
 	pair := s.agent.findPair(local, remote)
@@ -137,10 +159,11 @@ func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local,
 	}
 }
 
-func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) {
-	ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
+func (s *controllingSelector) HandleSuccessResponse(message *stun.Message, local, remote Candidate,
+	remoteAddr net.Addr) {
+	ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(message.TransactionID)
 	if !ok {
-		s.log.Warnf("Discard success response from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
+		s.log.Warnf("Discard success response from (%s), unknown TransactionID 0x%x", remote, message.TransactionID)
 
 		return
 	}
@@ -159,6 +182,9 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
 		return
 	}
 
+	// TODO: get the implicit ack from the pendingRequest.
+	reportPiggybacking(s.agent, message, remote)
+
 	s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
 	pair := s.agent.findPair(local, remote)
 
@@ -191,13 +217,26 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
 }
 
 func (s *controllingSelector) PingCandidate(local, remote Candidate) {
-	msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
-		stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag),
+	attributes := []stun.Setter{
+		stun.BindingRequest,
+		stun.TransactionID,
+		stun.NewUsername(s.agent.remoteUfrag + ":" + s.agent.localUfrag),
 		AttrControlling(s.agent.tieBreaker),
 		PriorityAttr(local.Priority()),
+	}
+	if packet, acks := s.agent.GetPiggybackDataAndAcks(); acks != nil || packet != nil { // Either may be absent.
+		if acks != nil {
+			attributes = append(attributes, DtlsInStunAckAttribute(acks))
+		}
+		if packet != nil {
+			attributes = append(attributes, DtlsInStunAttribute(packet))
+		}
+	}
+	attributes = append(attributes,
 		stun.NewShortTermIntegrity(s.agent.remotePwd),
-		stun.Fingerprint,
-	)
+		stun.Fingerprint)
+
+	msg, err := stun.Build(attributes...)
if err != nil { s.log.Error(err.Error()) @@ -338,13 +377,26 @@ func (s *controlledSelector) ContactCandidates() { } func (s *controlledSelector) PingCandidate(local, remote Candidate) { - msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, - stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), + attributes := []stun.Setter{ + stun.BindingRequest, + stun.TransactionID, + stun.NewUsername(s.agent.remoteUfrag + ":" + s.agent.localUfrag), AttrControlled(s.agent.tieBreaker), PriorityAttr(local.Priority()), + } + if packet, acks := s.agent.GetPiggybackDataAndAcks(); acks != nil { + if acks != nil { + attributes = append(attributes, DtlsInStunAckAttribute(acks)) + } + if packet != nil { + attributes = append(attributes, DtlsInStunAttribute(packet)) + } + } + attributes = append(attributes, stun.NewShortTermIntegrity(s.agent.remotePwd), - stun.Fingerprint, - ) + stun.Fingerprint) + + msg, err := stun.Build(attributes...) if err != nil { s.log.Error(err.Error()) @@ -354,7 +406,8 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) { s.agent.sendBindingRequest(msg, local, remote) } -func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { +func (s *controlledSelector) HandleSuccessResponse(message *stun.Message, local, remote Candidate, + remoteAddr net.Addr) { //nolint:godox // TODO according to the standard we should specifically answer a failed nomination: // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 @@ -363,9 +416,9 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot // request with an appropriate error code response (e.g., 400) // [RFC5389]. 
 
-	ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
+	ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(message.TransactionID)
 	if !ok {
-		s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
+		s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, message.TransactionID)
 
 		return
 	}
@@ -407,9 +460,14 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
 	}
 
 	pair.UpdateRoundTripTime(rtt)
+
+	// TODO: get the implicit ack from the pendingRequest.
+	reportPiggybacking(s.agent, message, remote) // NOTE(review): last statement, so early returns above skip it.
 }
 
 func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
+	reportPiggybacking(s.agent, message, remote) // Process embedded DTLS data/acks before handling the request.
+
 	pair := s.agent.findPair(local, remote)
 	if pair == nil {
 		pair = s.agent.addPair(local, remote)
diff --git a/transport.go b/transport.go
index 40ac049b..19b30eff 100644
--- a/transport.go
+++ b/transport.go
@@ -12,16 +12,67 @@ import (
 	"github.com/pion/stun/v3"
 )
 
-// Dial connects to the remote agent, acting as the controlling ice agent.
+// AwaitConnect blocks until a pair is selected, the agent loop ends, or ctx is done (ErrCanceledByCaller).
+func (a *Agent) AwaitConnect(ctx context.Context) error {
+	select {
+	case <-a.loop.Done():
+		return a.loop.Err()
+	case <-ctx.Done():
+		return ErrCanceledByCaller
+	case <-a.onConnected:
+	}
+
+	return nil
+}
+
+// StartDial sets the agent up for connecting to the remote agent, acting as the
+// controlling agent and returns immediately. Use AwaitConnect to wait for a selected pair.
+func (a *Agent) StartDial(remoteUfrag, remotePwd string) (*Conn, error) {
+	conn, err := a.startConnect(true, remoteUfrag, remotePwd)
+	if err != nil {
+		return nil, err
+	}
+
+	return conn, nil
+}
+
 // Dial blocks until at least one ice candidate pair has successfully connected.
 func (a *Agent) Dial(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) {
-	return a.connect(ctx, true, remoteUfrag, remotePwd)
+	conn, err := a.StartDial(remoteUfrag, remotePwd) //nolint:contextcheck
+	if err != nil {
+		return nil, err
+	}
+	err = a.AwaitConnect(ctx) // Block until a pair is selected, the agent loop ends or ctx is done.
+	if err != nil {
+		return nil, err
+	}
+
+	return conn, nil
+}
+
+// StartAccept sets the agent up for connecting to the remote agent, acting as the
+// controlled agent and returns immediately. Use AwaitConnect to wait for a selected pair.
+func (a *Agent) StartAccept(remoteUfrag, remotePwd string) (*Conn, error) {
+	conn, err := a.startConnect(false, remoteUfrag, remotePwd)
+	if err != nil {
+		return nil, err
+	}
+
+	return conn, nil
 }
 
-// Accept connects to the remote agent, acting as the controlled ice agent.
 // Accept blocks until at least one ice candidate pair has successfully connected.
 func (a *Agent) Accept(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) {
-	return a.connect(ctx, false, remoteUfrag, remotePwd)
+	conn, err := a.StartAccept(remoteUfrag, remotePwd) //nolint:contextcheck
+	if err != nil {
+		return nil, err
+	}
+	err = a.AwaitConnect(ctx) // Block until a pair is selected, the agent loop ends or ctx is done.
+	if err != nil {
+		return nil, err
+	}
+
+	return conn, nil
 }
 
 // Conn represents the ICE connection.
@@ -42,7 +93,7 @@ func (c *Conn) BytesReceived() uint64 {
 	return c.bytesReceived.Load()
 }
 
-func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
+func (a *Agent) startConnect(isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
 	err := a.loop.Err()
 	if err != nil {
 		return nil, err
@@ -52,15 +103,6 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
 		return nil, err
 	}
 
-	// Block until pair selected
-	select {
-	case <-a.loop.Done():
-		return nil, a.loop.Err()
-	case <-ctx.Done():
-		return nil, ErrCanceledByCaller
-	case <-a.onConnected:
-	}
-
 	return &Conn{
 		agent: a,
 	}, nil
diff --git a/transport_test.go b/transport_test.go
index 5886f90d..f379c81a 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -401,18 +401,21 @@ func TestAgent_connect_ErrEarly(t *testing.T) {
 	cfg := &AgentConfig{
 		NetworkTypes: supportedNetworkTypes(),
 	}
-	a, err := NewAgent(cfg)
+	agent, err := NewAgent(cfg)
 	require.NoError(t, err)
 
-	require.NoError(t, a.Close())
+	require.NoError(t, agent.Close())
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	defer cancel()
 
 	// isControlling = true
-	conn, cerr := a.connect(ctx, true, "ufragX", "pwdX")
+	conn, cerr := agent.startConnect(true, "ufragX", "pwdX")
 	require.Nil(t, conn)
 	require.Error(t, cerr, "expected error from a.loop.Err() short-circuit")
+
+	err2 := agent.AwaitConnect(ctx) // The agent was closed above; at the latest the 1s ctx timeout ends the wait.
+	require.Error(t, err2, "the agent is closed")
 }
 
 func TestConn_Write_RejectsSTUN(t *testing.T) {