diff --git a/gather.go b/gather.go index a1e02471..50656683 100644 --- a/gather.go +++ b/gather.go @@ -836,7 +836,17 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { return } + allocDoneCh := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + client.Close() + case <-allocDoneCh: + } + }() + relayConn, err := client.Allocate() + close(allocDoneCh) if err != nil { client.Close() closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) diff --git a/gather_vnet_test.go b/gather_vnet_test.go index 3020367a..da8982c0 100644 --- a/gather_vnet_test.go +++ b/gather_vnet_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net" "testing" + "time" "github.com/pion/logging" "github.com/pion/stun/v3" @@ -435,3 +436,50 @@ func TestVNetGather_TURNConnectionLeak(t *testing.T) { aAgent.gatherCandidatesRelay(context.Background(), []*stun.URI{turnServerURL}) } + +func TestVNetGather_TURNAllocationAbort(t *testing.T) { + defer test.CheckRoutines(t)() + + // configure unreachable TURN server + turnServerURL := &stun.URI{ + Scheme: stun.SchemeTypeTURN, + Host: vnetSTUNServerIP, + Port: vnetSTUNServerPort, + Username: "user", + Password: "pass", + Proto: stun.ProtoTypeUDP, + } + + loggerFactory := logging.NewDefaultLoggerFactory() + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{}) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + + cfg0 := &AgentConfig{ + Urls: []*stun.URI{ + turnServerURL, + }, + NetworkTypes: supportedNetworkTypes(), + MulticastDNSMode: MulticastDNSModeDisabled, + Net: nw, + } + aAgent, err := NewAgent(cfg0) + require.NoError(t, err, "should succeed") + defer func() { + require.NoError(t, aAgent.Close()) + }() + + // if not canceled, gatherCandidatesRelay() will block for ~7.8s + defer test.TimeOut(time.Second * 1).Stop() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancelFunc() + + aAgent.gatherCandidatesRelay(ctx, []*stun.URI{turnServerURL}) +}