diff --git a/transport.go b/transport.go index 40ac049b..19b30eff 100644 --- a/transport.go +++ b/transport.go @@ -12,16 +12,67 @@ import ( "github.com/pion/stun/v3" ) -// Dial connects to the remote agent, acting as the controlling ice agent. +// AwaitConnect waits until a pair is selected. +func (a *Agent) AwaitConnect(ctx context.Context) error { + select { + case <-a.loop.Done(): + return a.loop.Err() + case <-ctx.Done(): + return ErrCanceledByCaller + case <-a.onConnected: + } + + return nil +} + +// StartDial sets the agent up for connecting to the remote agent, acting as the +// controlling agent and returns immediately. +func (a *Agent) StartDial(remoteUfrag, remotePwd string) (*Conn, error) { + conn, err := a.startConnect(true, remoteUfrag, remotePwd) + if err != nil { + return nil, err + } + + return conn, nil +} + // Dial blocks until at least one ice candidate pair has successfully connected. func (a *Agent) Dial(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { - return a.connect(ctx, true, remoteUfrag, remotePwd) + conn, err := a.StartDial(remoteUfrag, remotePwd) //nolint:contextcheck + if err != nil { + return nil, err + } + err = a.AwaitConnect(ctx) + if err != nil { + return nil, err + } + + return conn, nil +} + +// StartAccept sets the agent up for connecting to the remote agent, acting as the +// controlled agent and returns immediately. +func (a *Agent) StartAccept(remoteUfrag, remotePwd string) (*Conn, error) { + conn, err := a.startConnect(false, remoteUfrag, remotePwd) + if err != nil { + return nil, err + } + + return conn, nil } -// Accept connects to the remote agent, acting as the controlled ice agent. // Accept blocks until at least one ice candidate pair has successfully connected. func (a *Agent) Accept(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { - return a.connect(ctx, false, remoteUfrag, remotePwd) + conn, err := a.StartAccept(remoteUfrag, remotePwd) //nolint:contextcheck + if err != nil { + return nil, err + } + err = a.AwaitConnect(ctx) + if err != nil { + return nil, err + } + + return conn, nil } // Conn represents the ICE connection. @@ -42,7 +93,7 @@ func (c *Conn) BytesReceived() uint64 { return c.bytesReceived.Load() } -func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) { +func (a *Agent) startConnect(isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) { err := a.loop.Err() if err != nil { return nil, err @@ -52,15 +103,6 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re return nil, err } - // Block until pair selected - select { - case <-a.loop.Done(): - return nil, a.loop.Err() - case <-ctx.Done(): - return nil, ErrCanceledByCaller - case <-a.onConnected: - } - return &Conn{ agent: a, }, nil diff --git a/transport_test.go b/transport_test.go index 5886f90d..f379c81a 100644 --- a/transport_test.go +++ b/transport_test.go @@ -401,18 +401,21 @@ func TestAgent_connect_ErrEarly(t *testing.T) { cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), } - a, err := NewAgent(cfg) + agent, err := NewAgent(cfg) require.NoError(t, err) - require.NoError(t, a.Close()) + require.NoError(t, agent.Close()) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // isControlling = true - conn, cerr := a.connect(ctx, true, "ufragX", "pwdX") + conn, cerr := agent.startConnect(true, "ufragX", "pwdX") require.Nil(t, conn) require.Error(t, cerr, "expected error from a.loop.Err() short-circuit") + + err2 := agent.AwaitConnect(ctx) + require.Error(t, err2, "the agent is closed") } func TestConn_Write_RejectsSTUN(t *testing.T) {