Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,67 @@ import (
"github.com/pion/stun/v3"
)

// Dial connects to the remote agent, acting as the controlling ice agent.
// AwaitConnect waits until a pair is selected.
func (a *Agent) AwaitConnect(ctx context.Context) error {
select {
case <-a.loop.Done():
return a.loop.Err()
case <-ctx.Done():
return ErrCanceledByCaller
case <-a.onConnected:
}

return nil
}

// StartDial sets the agent up for connecting to the remote agent, acting as the
// controlling agent and returns immediately.
func (a *Agent) StartDial(remoteUfrag, remotePwd string) (*Conn, error) {
conn, err := a.startConnect(true, remoteUfrag, remotePwd)
if err != nil {
return nil, err
}

return conn, nil
}

// Dial blocks until at least one ice candidate pair has successfully connected.
func (a *Agent) Dial(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) {
return a.connect(ctx, true, remoteUfrag, remotePwd)
conn, err := a.StartDial(remoteUfrag, remotePwd) //nolint:contextcheck
if err != nil {
return nil, err
}
err = a.AwaitConnect(ctx)
if err != nil {
return nil, err
}

return conn, nil
}

// StartAccept sets the agent up for connecting to the remote agent, acting as the
// controlled agent and returns immediately.
func (a *Agent) StartAccept(remoteUfrag, remotePwd string) (*Conn, error) {
conn, err := a.startConnect(false, remoteUfrag, remotePwd)
if err != nil {
return nil, err
}

return conn, nil
}

// Accept connects to the remote agent, acting as the controlled ice agent.
// Accept blocks until at least one ice candidate pair has successfully connected.
func (a *Agent) Accept(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) {
return a.connect(ctx, false, remoteUfrag, remotePwd)
conn, err := a.StartAccept(remoteUfrag, remotePwd) //nolint:contextcheck
if err != nil {
return nil, err
}
err = a.AwaitConnect(ctx)
if err != nil {
return nil, err
}

return conn, nil
}

// Conn represents the ICE connection.
Expand All @@ -42,7 +93,7 @@ func (c *Conn) BytesReceived() uint64 {
return c.bytesReceived.Load()
}

func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
func (a *Agent) startConnect(isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
err := a.loop.Err()
if err != nil {
return nil, err
Expand All @@ -52,15 +103,6 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
return nil, err
}

// Block until pair selected
select {
case <-a.loop.Done():
return nil, a.loop.Err()
case <-ctx.Done():
return nil, ErrCanceledByCaller
case <-a.onConnected:
}

return &Conn{
agent: a,
}, nil
Expand Down
9 changes: 6 additions & 3 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,18 +401,21 @@ func TestAgent_connect_ErrEarly(t *testing.T) {
cfg := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
}
a, err := NewAgent(cfg)
agent, err := NewAgent(cfg)
require.NoError(t, err)

require.NoError(t, a.Close())
require.NoError(t, agent.Close())

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

// isControlling = true
conn, cerr := a.connect(ctx, true, "ufragX", "pwdX")
conn, cerr := agent.startConnect(true, "ufragX", "pwdX")
require.Nil(t, conn)
require.Error(t, cerr, "expected error from a.loop.Err() short-circuit")

err2 := agent.AwaitConnect(ctx)
require.Error(t, err2, "the agent is closed")
}

func TestConn_Write_RejectsSTUN(t *testing.T) {
Expand Down
Loading