diff --git a/sped.go b/sped.go new file mode 100644 index 00000000..6143cab3 --- /dev/null +++ b/sped.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + + "github.com/pion/stun/v3" +) + +// DtlsInStunAttribute is a STUN attribute for carrying DTLS embedded in STUN. +type DtlsInStunAttribute []byte + +// AddTo adds DTLS-in-STUN attribute to message. +func (d DtlsInStunAttribute) AddTo(m *stun.Message) error { + m.Add(stun.AttrDtlsInStun, d) + + return nil +} + +// GetFrom decodes DTLS-in-STUN attribute from message. +func (d *DtlsInStunAttribute) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrDtlsInStun) + if err != nil { + return err + } + + *d = v + + return nil +} + +// DtlsInStunAckAttribute is a STUN attribute for acknowledging the receipt +// of DTLS packets (embedded in STUN or without embedding). +type DtlsInStunAckAttribute []uint32 + +// Acks are 32 bit values, the attribute can carry up to four of these. +const ackSizeBytes, ackSizeValues = 32, 4 + +// AddTo adds DTLS-in-STUN-ACK attribute to message. +func (a DtlsInStunAckAttribute) AddTo(m *stun.Message) error { + if len(a) > ackSizeValues { + return stun.ErrAttributeSizeInvalid + } + v := make([]byte, len(a)*4) + for i, ack := range a { + binary.BigEndian.PutUint32(v[i*4:], ack) + } + m.Add(stun.AttrDtlsInStunAck, v) + + return nil +} + +// GetFrom decodes DTLS-in-STUN-ACK attribute from message. +func (a *DtlsInStunAckAttribute) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrDtlsInStunAck) + if err != nil { + return err + } + if len(v) > ackSizeBytes || len(v)%4 != 0 { + return stun.ErrAttributeSizeInvalid + } + u := make([]uint32, len(v)/4) + for i := range u { + u[i] = binary.BigEndian.Uint32(v[i*4 : (i+1)*4]) + } + *a = DtlsInStunAckAttribute(u) + + return nil +} diff --git a/sped_test.go b/sped_test.go new file mode 100644 index 00000000..ae5e2a9c --- /dev/null +++ b/sped_test.go @@ -0,0 +1,87 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + "testing" + + "github.com/pion/stun/v3" + "github.com/stretchr/testify/require" +) + +func TestDtlsInStunAttribute_GetFrom(t *testing.T) { + m := new(stun.Message) + var dtlsInStun DtlsInStunAttribute + require.ErrorIs(t, stun.ErrAttributeNotFound, dtlsInStun.GetFrom(m)) + + expectedValue := []byte{0x01, 0x02, 0x03, 0x04} + m.Add(stun.AttrDtlsInStun, expectedValue) + + var dtlsInStun1 DtlsInStunAttribute + require.NoError(t, dtlsInStun1.GetFrom(m)) + require.Equal(t, expectedValue, []byte(dtlsInStun1)) +} + +func TestDtlsInStunAttribute_AddTo(t *testing.T) { + m := new(stun.Message) + dtlsInStun := DtlsInStunAttribute([]byte{0x05, 0x06, 0x07, 0x08}) + require.NoError(t, dtlsInStun.AddTo(m)) + + v, err := m.Get(stun.AttrDtlsInStun) + require.NoError(t, err) + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, v) +} + +func TestDtlsInStunAckAttribute_GetFrom(t *testing.T) { + m := new(stun.Message) + var dtlsInStunAck DtlsInStunAckAttribute + require.ErrorIs(t, stun.ErrAttributeNotFound, dtlsInStunAck.GetFrom(m)) + + // Test with valid data + expectedValue := []uint32{0x01020304, 0x05060708} + byteValue := make([]byte, 8) + binary.BigEndian.PutUint32(byteValue[0:4], expectedValue[0]) + binary.BigEndian.PutUint32(byteValue[4:8], expectedValue[1]) + m.Add(stun.AttrDtlsInStunAck, byteValue) + + var dtlsInStunAck1 DtlsInStunAckAttribute + require.NoError(t, dtlsInStunAck1.GetFrom(m)) + require.Equal(t, expectedValue, []uint32(dtlsInStunAck1)) + + // Test with invalid size (not multiple of 4) + m2 := new(stun.Message) + m2.Add(stun.AttrDtlsInStunAck, []byte{0x01, 0x02, 0x03}) + var dtlsInStunAck2 DtlsInStunAckAttribute + require.ErrorIs(t, stun.ErrAttributeSizeInvalid, dtlsInStunAck2.GetFrom(m2)) + require.Empty(t, dtlsInStunAck2) + + // Test with invalid size (greater than ackSize) + m3 := new(stun.Message) + m3.Add(stun.AttrDtlsInStunAck, make([]byte, ackSizeBytes+4)) + var dtlsInStunAck3 DtlsInStunAckAttribute + require.ErrorIs(t, stun.ErrAttributeSizeInvalid, dtlsInStunAck3.GetFrom(m3)) + require.Empty(t, dtlsInStunAck3) +} + +func TestDtlsInStunAckAttribute_AddTo(t *testing.T) { + m := new(stun.Message) + dtlsInStunAck := DtlsInStunAckAttribute([]uint32{0x090a0b0c, 0x0d0e0f10}) + require.NoError(t, dtlsInStunAck.AddTo(m)) + + v, err := m.Get(stun.AttrDtlsInStunAck) + require.NoError(t, err) + + expectedByteValue := make([]byte, 8) + binary.BigEndian.PutUint32(expectedByteValue[0:4], 0x090a0b0c) + binary.BigEndian.PutUint32(expectedByteValue[4:8], 0x0d0e0f10) + require.Equal(t, expectedByteValue, v) + + // Test with more than 4 elements (should not add to message) + m2 := new(stun.Message) + dtlsInStunAck2 := DtlsInStunAckAttribute([]uint32{1, 2, 3, 4, 5}) + require.ErrorIs(t, stun.ErrAttributeSizeInvalid, dtlsInStunAck2.AddTo(m2)) + _, err = m2.Get(stun.AttrDtlsInStunAck) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) +}