diff --git a/agent.go b/agent.go index 665c26aa..44257dea 100644 --- a/agent.go +++ b/agent.go @@ -174,6 +174,9 @@ type Agent struct { renominationInterval time.Duration lastRenominationTime time.Time + // Port mapping support for container + mapPort func(candidate Candidate) int + turnClientFactory func(*turn.ClientConfig) (turnClient, error) } diff --git a/agent_options.go b/agent_options.go index 420d2d1d..19002c59 100644 --- a/agent_options.go +++ b/agent_options.go @@ -960,3 +960,13 @@ func WithLoggerFactory(loggerFactory logging.LoggerFactory) AgentOption { return nil } } + +func WithMapPortHandler(handler func(cand Candidate) int) AgentOption { + return func(a *Agent) error { + a.mapPort = func(candidate Candidate) int { + return handler(candidate) + } + + return nil + } +} diff --git a/candidate_base.go b/candidate_base.go index a0a493bc..f7626586 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -440,7 +440,7 @@ func (c *candidateBase) Priority() uint32 { } // Equal is used to compare two candidateBases. -func (c *candidateBase) Equal(other Candidate) bool { +func (c *candidateBase) Equal(other Candidate) bool { //nolint:cyclop if c.addr() != other.addr() { if c.addr() == nil || other.addr() == nil { return false diff --git a/gather.go b/gather.go index 1fc25e22..2d38b0b5 100644 --- a/gather.go +++ b/gather.go @@ -414,6 +414,12 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ continue } + if a.mapPort != nil { + mappedPort := a.mapPort(candidateHost) + if mappedPort != 0 { + candidateHost.port = mappedPort + } + } if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil { if closeErr := candidateHost.close(); closeErr != nil { @@ -512,15 +518,22 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin return err } - c, err := NewCandidateHost(&hostConfig) + cand, err := NewCandidateHost(&hostConfig) + if a.mapPort != nil { + mappedPort := a.mapPort(cand) + if mappedPort != 0 { + cand.port = mappedPort + } + } + if err != nil { closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err) continue } - if err := a.addCandidate(ctx, c, conn); err != nil { - if closeErr := c.close(); closeErr != nil { + if err := a.addCandidate(ctx, cand, conn); err != nil { + if closeErr := cand.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } @@ -623,7 +636,7 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] RelAddr: currentAddr.IP.String(), RelPort: currentAddr.Port, } - c, err := NewCandidateServerReflexive(&srflxConfig) + cand, err := NewCandidateServerReflexive(&srflxConfig) if err != nil { closeConnAndLog(currentConn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, @@ -633,9 +646,15 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] continue } + if a.mapPort != nil { + mappedPort := a.mapPort(cand) + if mappedPort != 0 { + cand.port = mappedPort + } + } - if err := a.addCandidate(ctx, c, currentConn); err != nil { - if closeErr := c.close(); closeErr != nil { + if err := a.addCandidate(ctx, cand, currentConn); err != nil { + if closeErr := cand.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) @@ -712,15 +731,21 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR RelAddr: localAddr.IP.String(), RelPort: localAddr.Port, } - c, err := NewCandidateServerReflexive(&srflxConfig) + cand, err := NewCandidateServerReflexive(&srflxConfig) if err != nil { closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) return } + if a.mapPort != nil { + mappedPort := a.mapPort(cand) + if mappedPort != 0 { + cand.port = mappedPort + } + } - if err := a.addCandidate(ctx, c, conn); err != nil { - if closeErr := c.close(); closeErr != nil { + if err := a.addCandidate(ctx, cand, conn); err != nil { + if closeErr := cand.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) @@ -805,15 +830,21 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net RelAddr: lAddr.IP.String(), RelPort: lAddr.Port, } - c, err := NewCandidateServerReflexive(&srflxConfig) + cand, err := NewCandidateServerReflexive(&srflxConfig) if err != nil { closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) return } + if a.mapPort != nil { + mappedPort := a.mapPort(cand) + if mappedPort != 0 { + cand.port = mappedPort + } + } - if err := a.addCandidate(ctx, c, conn); err != nil { - if closeErr := c.close(); closeErr != nil { + if err := a.addCandidate(ctx, cand, conn); err != nil { + if closeErr := cand.close(); closeErr != nil { a.log.Warnf("Failed to close candidate: %v", closeErr) } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) @@ -1158,6 +1189,12 @@ func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip n return err } + if a.mapPort != nil { + mappedPort := a.mapPort(candidate) + if mappedPort != 0 { + candidate.port = mappedPort + } + } if err := a.addCandidate(ctx, candidate, ep.conn); err != nil { if closeErr := candidate.close(); closeErr != nil { diff --git a/gather_test.go b/gather_test.go index 342cdc29..4dbf1a40 100644 --- a/gather_test.go +++ b/gather_test.go @@ -3628,3 +3628,279 @@ func (m *mockUniversalUDPMux) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr func (m *mockUniversalUDPMux) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { return m.mockUDPMux.GetConn(ufrag+url, addr) } + +func TestMapPort(t *testing.T) { + listener, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx + skipOnPermission(t, err, "listening for TURN server") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + relayPort := uint16(40000) + server, err := turn.NewServer(turn.ServerConfig{ + Realm: "pion.ly", + AuthHandler: optimisticAuthHandler, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: listener, + RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{ + RelayAddress: net.ParseIP("127.0.0.1"), + MinPort: relayPort, + MaxPort: relayPort, + MaxRetries: 1, + Address: "127.0.0.1", + }, + }, + }, + }) + require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() + + serverPort := listener.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + turnURL := &stun.URI{ + Scheme: stun.SchemeTypeTURN, + Host: "127.0.0.1", + Port: serverPort, + Username: "username", + Password: "password", + Proto: stun.ProtoTypeUDP, + } + agent, err := NewAgentWithOptions( + WithCandidateTypes([]CandidateType{CandidateTypeHost, CandidateTypeRelay, CandidateTypeServerReflexive}), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}), + WithUrls([]*stun.URI{turnURL}), + WithMapPortHandler(func(cand Candidate) int { + if cand.Type() != CandidateTypeHost { + return cand.Port() + } + + return 50000 + }), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + require.NoError(t, err) + + gathered := make(chan (struct{})) + + var cands []Candidate + var mu sync.Mutex + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c == nil { + close(gathered) + + return + } + mu.Lock() + cands = append(cands, c) + mu.Unlock() + })) + + require.NoError(t, agent.GatherCandidates()) + + <-gathered + var ( + sawHost bool + sawRelay bool + sawSrflx bool + ) + for _, cand := range cands { + switch cand.Type() { + case CandidateTypeHost: + sawHost = true + require.Equal(t, 50000, cand.Port()) + case CandidateTypeRelay: + sawRelay = true + require.Equal(t, int(relayPort), cand.Port()) + case CandidateTypeServerReflexive: + sawSrflx = true + default: + require.Failf(t, "unexpected cand type", "got: %v", cand.Type()) + } + } + + require.True(t, sawHost) + require.True(t, sawRelay) + require.True(t, sawSrflx) +} + +func TestMapPortSrflx(t *testing.T) { + stunURI := &stun.URI{ + Scheme: stun.SchemeTypeSTUN, + Host: "127.0.0.1", + Port: 3478, + } + relatedAddr := &net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 49000} + srflxAddr := &stun.XORMappedAddress{ + IP: net.IP{203, 0, 113, 5}, + Port: 50000, + } + + udpMuxSrflx := newMockUniversalUDPMux([]net.Addr{relatedAddr}, srflxAddr) + + agent, err := NewAgentWithOptions( + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithCandidateTypes([]CandidateType{CandidateTypeServerReflexive}), + WithUDPMuxSrflx(udpMuxSrflx), + WithMapPortHandler(func(cand Candidate) int { + if cand.Type() != CandidateTypeServerReflexive { + return cand.Port() + } + + return 50001 + }), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + + require.NoError(t, agent.OnCandidate(func(Candidate) {})) + + agent.gatherCandidatesSrflxUDPMux(context.Background(), []*stun.URI{stunURI}, []NetworkType{NetworkTypeUDP4}) + + agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4}) + + candidates, err := agent.GetLocalCandidates() + require.NoError(t, err) + require.Len(t, candidates, 2) + + for _, cand := range candidates { + srflx, ok := cand.(*CandidateServerReflexive) + require.True(t, ok) + require.Equal(t, 50001, srflx.Port()) + } +} + +func TestRewriteAndMapPort(t *testing.T) { //nolint:cyclop + t.Run("replace host via UDPMux", func(t *testing.T) { + mux := newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 1234}}) + + agent, err := NewAgentWithOptions( + WithNet(newStubNet(t)), + WithCandidateTypes([]CandidateType{CandidateTypeHost}), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithUDPMux(mux), + WithMulticastDNSMode(MulticastDNSModeDisabled), + WithAddressRewriteRules(AddressRewriteRule{ + External: []string{"203.0.113.1"}, + Local: "10.0.0.1", + AsCandidateType: CandidateTypeHost, + Mode: AddressRewriteReplace, + }), + WithMapPortHandler(func(c Candidate) int { + if c.Type() != CandidateTypeHost { + return c.Port() + } + if c.Port() == 1234 { + return 4321 + } else { + return 12345 + } + }), + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, agent.Close()) + }) + + var ( + mu sync.Mutex + addresses []Candidate + done = make(chan struct{}) + ) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c == nil { + close(done) + + return + } + mu.Lock() + addresses = append(addresses, c) + mu.Unlock() + })) + + require.NoError(t, agent.GatherCandidates()) + select { + case <-done: + case <-time.After(2 * time.Second): + require.FailNow(t, "gather did not complete") + } + + mu.Lock() + defer mu.Unlock() + require.Len(t, addresses, 1) + assert.Equal(t, "203.0.113.1", addresses[0].Address()) + assert.Equal(t, 4321, addresses[0].Port()) + assert.Equal(t, CandidateTypeHost, addresses[0].Type()) + }) + + t.Run("append host via UDPMux", func(t *testing.T) { + mux := newMockUDPMux([]net.Addr{&net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 1234}}) + + agent, err := NewAgentWithOptions( + WithNet(newStubNet(t)), + WithCandidateTypes([]CandidateType{CandidateTypeHost}), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithUDPMux(mux), + WithMulticastDNSMode(MulticastDNSModeDisabled), + WithAddressRewriteRules(AddressRewriteRule{ + External: []string{"203.0.113.2"}, + Local: "10.0.0.1", + AsCandidateType: CandidateTypeHost, + Mode: AddressRewriteAppend, + }), + WithMapPortHandler(func(c Candidate) int { + if c.Type() != CandidateTypeHost { + return c.Port() + } + if c.Port() == 1234 { + return 4321 + } else { + return 12345 + } + }), + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, agent.Close()) + }) + + var ( + mu sync.Mutex + addresses []Candidate + done = make(chan struct{}) + ) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c == nil { + close(done) + + return + } + mu.Lock() + addresses = append(addresses, c) + mu.Unlock() + })) + + require.NoError(t, agent.GatherCandidates()) + select { + case <-done: + case <-time.After(2 * time.Second): + require.FailNow(t, "gather did not complete") + } + + mu.Lock() + defer mu.Unlock() + require.Len(t, addresses, 2) + seenAddrs := []string{addresses[0].Address(), addresses[1].Address()} + assert.ElementsMatch(t, []string{"10.0.0.1", "203.0.113.2"}, seenAddrs) + for _, cand := range addresses { + assert.Equal(t, CandidateTypeHost, cand.Type()) + assert.Equal(t, 4321, cand.Port()) + } + }) +} diff --git a/go.mod b/go.mod index 62e5599c..d70ec2fb 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/crypto v0.33.0 // indirect diff --git a/go.sum b/go.sum index 69d7cdf2..34e06f7d 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -5,8 +6,9 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pion/dtls/v3 v3.0.10 h1:k9ekkq1kaZoxnNEbyLKI8DI37j/Nbk1HWmMuywpQJgg= github.com/pion/dtls/v3 v3.0.10/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=