From 057fa01b0cb5a69d7707a5bf03b2cf2f709dc24d Mon Sep 17 00:00:00 2001 From: Moses Narrow <36607567+0pcom@users.noreply.github.com> Date: Fri, 20 Mar 2026 12:28:55 -0500 Subject: [PATCH] Add comprehensive tests for dmsg packages e2e-style tests (pkg/dmsgtest/e2e_test.go): - TestBidirectionalStreams: bidirectional data transfer at 32B/4KB/64KB - TestMultiServerStreams: streams across multiple servers and clients - TestConcurrentStreams: 20 simultaneous streams with data integrity - TestSessionReconnect: client reconnects after server shutdown - TestListenerAcceptAll: listener accepts multiple connections - TestPortOccupied: duplicate listen returns ErrPortOccupied - TestDialNonexistentClient: dial unknown PK returns ErrDiscEntryNotFound direct client tests (pkg/direct/client_test.go): - Entry lookup, post, delete, put operations - AvailableServers/AllServers filtering - AllEntries enumeration - ClientsByServer/AllClientsByServer grouping - GetClientEntry and GetAllEntries utility functions ioutil tests (pkg/ioutil/buf_read_test.go): - BufRead with exact fit, short buffer, empty data, large data noise nonce tests (pkg/noise/nonce_test.go): - DecryptWithNonceMap replay prevention - Out-of-order decryption with nonce map - Encrypt/decrypt roundtrip - Large payload (64KB) roundtrip --- pkg/direct/client_test.go | 273 ++++++++++++++++++++ pkg/dmsgtest/e2e_test.go | 502 ++++++++++++++++++++++++++++++++++++ pkg/ioutil/buf_read_test.go | 71 +++++ pkg/noise/nonce_test.go | 123 +++++++++ 4 files changed, 969 insertions(+) create mode 100644 pkg/direct/client_test.go create mode 100644 pkg/dmsgtest/e2e_test.go create mode 100644 pkg/ioutil/buf_read_test.go create mode 100644 pkg/noise/nonce_test.go diff --git a/pkg/direct/client_test.go b/pkg/direct/client_test.go new file mode 100644 index 000000000..ddccdaf5d --- /dev/null +++ b/pkg/direct/client_test.go @@ -0,0 +1,273 @@ +package direct + +import ( + "context" + "testing" + + "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/skycoin/dmsg/pkg/disc" +) + +func TestDirectClient_Entry(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + pk1, _ := cipher.GenerateKeyPair() + pk2, _ := cipher.GenerateKeyPair() + pkMissing, _ := cipher.GenerateKeyPair() + + entries := []*disc.Entry{ + {Static: pk1, Server: &disc.Server{Address: "addr1"}}, + {Static: pk2, Server: &disc.Server{Address: "addr2"}}, + } + + client := NewClient(entries, log) + ctx := context.Background() + + // Existing entries should be found. + e1, err := client.Entry(ctx, pk1) + require.NoError(t, err) + assert.Equal(t, pk1, e1.Static) + + e2, err := client.Entry(ctx, pk2) + require.NoError(t, err) + assert.Equal(t, pk2, e2.Static) + + // Non-existent entry should return ErrKeyNotFound. + _, err = client.Entry(ctx, pkMissing) + assert.ErrorIs(t, err, disc.ErrKeyNotFound) +} + +func TestDirectClient_PostAndDelete(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + client := NewClient(nil, log) + ctx := context.Background() + + pk, _ := cipher.GenerateKeyPair() + entry := &disc.Entry{ + Static: pk, + Server: &disc.Server{Address: "addr1"}, + } + + // Entry should not exist yet. + _, err := client.Entry(ctx, pk) + assert.ErrorIs(t, err, disc.ErrKeyNotFound) + + // Post entry. + require.NoError(t, client.PostEntry(ctx, entry)) + + // Entry should now be retrievable. + got, err := client.Entry(ctx, pk) + require.NoError(t, err) + assert.Equal(t, pk, got.Static) + + // Delete entry. + require.NoError(t, client.DelEntry(ctx, entry)) + + // Entry should be gone. + _, err = client.Entry(ctx, pk) + assert.ErrorIs(t, err, disc.ErrKeyNotFound) +} + +func TestDirectClient_PutEntry(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + pk, sk := cipher.GenerateKeyPair() + original := &disc.Entry{ + Static: pk, + Server: &disc.Server{Address: "old-addr"}, + } + + client := NewClient([]*disc.Entry{original}, log) + ctx := context.Background() + + // Verify original. + got, err := client.Entry(ctx, pk) + require.NoError(t, err) + assert.Equal(t, "old-addr", got.Server.Address) + + // Update via PutEntry. + updated := &disc.Entry{ + Static: pk, + Server: &disc.Server{Address: "new-addr"}, + } + require.NoError(t, client.PutEntry(ctx, sk, updated)) + + // Verify update persists. + got, err = client.Entry(ctx, pk) + require.NoError(t, err) + assert.Equal(t, "new-addr", got.Server.Address) +} + +func TestDirectClient_AvailableServers(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + pkSrv1, _ := cipher.GenerateKeyPair() + pkSrv2, _ := cipher.GenerateKeyPair() + pkClient, _ := cipher.GenerateKeyPair() + + entries := []*disc.Entry{ + {Static: pkSrv1, Server: &disc.Server{Address: "srv1"}}, + {Static: pkSrv2, Server: &disc.Server{Address: "srv2"}}, + {Static: pkClient, Client: &disc.Client{DelegatedServers: []cipher.PubKey{pkSrv1}}}, + } + + client := NewClient(entries, log) + ctx := context.Background() + + // AvailableServers should return only server entries. + servers, err := client.AvailableServers(ctx) + require.NoError(t, err) + assert.Len(t, servers, 2) + for _, s := range servers { + assert.NotNil(t, s.Server) + } + + // AllServers should return the same set. + allSrv, err := client.AllServers(ctx) + require.NoError(t, err) + assert.Len(t, allSrv, 2) + for _, s := range allSrv { + assert.NotNil(t, s.Server) + } +} + +func TestDirectClient_AllEntries(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + pk1, _ := cipher.GenerateKeyPair() + pk2, _ := cipher.GenerateKeyPair() + pk3, _ := cipher.GenerateKeyPair() + + entries := []*disc.Entry{ + {Static: pk1, Server: &disc.Server{Address: "srv1"}}, + {Static: pk2, Client: &disc.Client{}}, + {Static: pk3, Server: &disc.Server{Address: "srv2"}}, + } + + client := NewClient(entries, log) + ctx := context.Background() + + all, err := client.AllEntries(ctx) + require.NoError(t, err) + assert.Len(t, all, 3) + + // Verify all PKs are present. + hexSet := make(map[string]bool) + for _, h := range all { + hexSet[h] = true + } + assert.True(t, hexSet[pk1.Hex()]) + assert.True(t, hexSet[pk2.Hex()]) + assert.True(t, hexSet[pk3.Hex()]) +} + +func TestDirectClient_ClientsByServer(t *testing.T) { + log := logging.MustGetLogger("direct_test") + + pkSrv1, _ := cipher.GenerateKeyPair() + pkSrv2, _ := cipher.GenerateKeyPair() + pkC1, _ := cipher.GenerateKeyPair() + pkC2, _ := cipher.GenerateKeyPair() + pkC3, _ := cipher.GenerateKeyPair() + + entries := []*disc.Entry{ + {Static: pkSrv1, Server: &disc.Server{Address: "srv1"}}, + {Static: pkSrv2, Server: &disc.Server{Address: "srv2"}}, + {Static: pkC1, Client: &disc.Client{DelegatedServers: []cipher.PubKey{pkSrv1}}}, + {Static: pkC2, Client: &disc.Client{DelegatedServers: []cipher.PubKey{pkSrv1, pkSrv2}}}, + {Static: pkC3, Client: &disc.Client{DelegatedServers: []cipher.PubKey{pkSrv2}}}, + } + + client := NewClient(entries, log) + ctx := context.Background() + + // ClientsByServer for srv1 should return pkC1 and pkC2. + srv1Clients, err := client.ClientsByServer(ctx, pkSrv1) + require.NoError(t, err) + assert.Len(t, srv1Clients, 2) + pks := make(map[cipher.PubKey]bool) + for _, e := range srv1Clients { + pks[e.Static] = true + } + assert.True(t, pks[pkC1]) + assert.True(t, pks[pkC2]) + + // ClientsByServer for srv2 should return pkC2 and pkC3. + srv2Clients, err := client.ClientsByServer(ctx, pkSrv2) + require.NoError(t, err) + assert.Len(t, srv2Clients, 2) + pks = make(map[cipher.PubKey]bool) + for _, e := range srv2Clients { + pks[e.Static] = true + } + assert.True(t, pks[pkC2]) + assert.True(t, pks[pkC3]) + + // AllClientsByServer should group correctly. + allByServer, err := client.AllClientsByServer(ctx) + require.NoError(t, err) + assert.Len(t, allByServer[pkSrv1.Hex()], 2) + assert.Len(t, allByServer[pkSrv2.Hex()], 2) +} + +func TestGetClientEntry(t *testing.T) { + pkSrv1, _ := cipher.GenerateKeyPair() + pkSrv2, _ := cipher.GenerateKeyPair() + + servers := []*disc.Entry{ + {Static: pkSrv1, Server: &disc.Server{Address: "srv1"}}, + {Static: pkSrv2, Server: &disc.Server{Address: "srv2"}}, + } + + pkC1, _ := cipher.GenerateKeyPair() + pkC2, _ := cipher.GenerateKeyPair() + pks := cipher.PubKeys{pkC1, pkC2} + + clients := GetClientEntry(pks, servers) + require.Len(t, clients, 2) + + for _, c := range clients { + require.NotNil(t, c.Client) + assert.Len(t, c.Client.DelegatedServers, 2) + assert.Contains(t, c.Client.DelegatedServers, pkSrv1) + assert.Contains(t, c.Client.DelegatedServers, pkSrv2) + assert.Equal(t, "0.0.1", c.Version) + } + + // Verify the correct static keys are assigned. + staticSet := make(map[cipher.PubKey]bool) + for _, c := range clients { + staticSet[c.Static] = true + } + assert.True(t, staticSet[pkC1]) + assert.True(t, staticSet[pkC2]) +} + +func TestGetAllEntries(t *testing.T) { + pkSrv, _ := cipher.GenerateKeyPair() + servers := []*disc.Entry{ + {Static: pkSrv, Server: &disc.Server{Address: "srv1"}}, + } + + pkC1, _ := cipher.GenerateKeyPair() + pkC2, _ := cipher.GenerateKeyPair() + pks := cipher.PubKeys{pkC1, pkC2} + + all := GetAllEntries(pks, servers) + + // Should contain 2 clients + 1 server = 3 entries. + require.Len(t, all, 3) + + staticSet := make(map[cipher.PubKey]bool) + for _, e := range all { + staticSet[e.Static] = true + } + assert.True(t, staticSet[pkC1]) + assert.True(t, staticSet[pkC2]) + assert.True(t, staticSet[pkSrv]) +} diff --git a/pkg/dmsgtest/e2e_test.go b/pkg/dmsgtest/e2e_test.go new file mode 100644 index 000000000..4656d95b2 --- /dev/null +++ b/pkg/dmsgtest/e2e_test.go @@ -0,0 +1,502 @@ +// Package dmsgtest pkg/dmsgtest/e2e_test.go +// +//nolint:errcheck +package dmsgtest + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + "testing" + "time" + + "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + dmsg "github.com/skycoin/dmsg/pkg/dmsg" +) + +func TestBidirectionalStreams(t *testing.T) { + const timeout = time.Second * 30 + + payloads := []struct { + name string + size int + }{ + {"small_32B", 32}, + {"medium_4KB", 4 * 1024}, + {"large_64KB", 64 * 1024}, + } + + for _, p := range payloads { + p := p + t.Run(p.name, func(t *testing.T) { + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 1, 2, &dmsg.Config{MinSessions: 1})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 2) + clientA := clients[0] + clientB := clients[1] + + const port = uint16(100) + + lisA, err := clientA.Listen(port) + require.NoError(t, err) + t.Cleanup(func() { _ = lisA.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Client B dials Client A. + streamB, err := clientB.DialStream(ctx, dmsg.Addr{PK: clientA.LocalPK(), Port: port}) + require.NoError(t, err) + t.Cleanup(func() { _ = streamB.Close() }) + + connA, err := lisA.AcceptStream() + require.NoError(t, err) + t.Cleanup(func() { _ = connA.Close() }) + + // Test A -> B + dataAtoB := cipher.RandByte(p.size) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := connA.Write(dataAtoB) + assert.NoError(t, wErr) + }() + + recvAtoB := make([]byte, p.size) + _, err = io.ReadFull(streamB, recvAtoB) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(dataAtoB, recvAtoB), "A->B data mismatch") + + // Test B -> A + dataBtoA := cipher.RandByte(p.size) + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := streamB.Write(dataBtoA) + assert.NoError(t, wErr) + }() + + recvBtoA := make([]byte, p.size) + _, err = io.ReadFull(connA, recvBtoA) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(dataBtoA, recvBtoA), "B->A data mismatch") + }) + } +} + +func TestMultiServerStreams(t *testing.T) { + const timeout = time.Second * 30 + + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 2, 3, &dmsg.Config{MinSessions: 2})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 3) + clientA := clients[0] + clientB := clients[1] + clientC := clients[2] + + const ( + portB = uint16(100) + portC = uint16(101) + ) + + // Client B and C listen. + lisB, err := clientB.Listen(portB) + require.NoError(t, err) + t.Cleanup(func() { _ = lisB.Close() }) + + lisC, err := clientC.Listen(portC) + require.NoError(t, err) + t.Cleanup(func() { _ = lisC.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Client A dials Client B. + streamAB, err := clientA.DialStream(ctx, dmsg.Addr{PK: clientB.LocalPK(), Port: portB}) + require.NoError(t, err) + t.Cleanup(func() { _ = streamAB.Close() }) + + connBA, err := lisB.AcceptStream() + require.NoError(t, err) + t.Cleanup(func() { _ = connBA.Close() }) + + // Client B dials Client C. + streamBC, err := clientB.DialStream(ctx, dmsg.Addr{PK: clientC.LocalPK(), Port: portC}) + require.NoError(t, err) + t.Cleanup(func() { _ = streamBC.Close() }) + + connCB, err := lisC.AcceptStream() + require.NoError(t, err) + t.Cleanup(func() { _ = connCB.Close() }) + + // Verify data flows on A->B stream. + dataAB := cipher.RandByte(512) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := streamAB.Write(dataAB) + assert.NoError(t, wErr) + }() + + recvAB := make([]byte, 512) + _, err = io.ReadFull(connBA, recvAB) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(dataAB, recvAB), "A->B data mismatch") + + // Verify data flows on B->C stream. + dataBC := cipher.RandByte(512) + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := streamBC.Write(dataBC) + assert.NoError(t, wErr) + }() + + recvBC := make([]byte, 512) + _, err = io.ReadFull(connCB, recvBC) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(dataBC, recvBC), "B->C data mismatch") + + // Verify each client can maintain multiple simultaneous streams. + // Open a second stream from A to B. + streamAB2, err := clientA.DialStream(ctx, dmsg.Addr{PK: clientB.LocalPK(), Port: portB}) + require.NoError(t, err) + t.Cleanup(func() { _ = streamAB2.Close() }) + + connBA2, err := lisB.AcceptStream() + require.NoError(t, err) + t.Cleanup(func() { _ = connBA2.Close() }) + + dataAB2 := cipher.RandByte(256) + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := streamAB2.Write(dataAB2) + assert.NoError(t, wErr) + }() + + recvAB2 := make([]byte, 256) + _, err = io.ReadFull(connBA2, recvAB2) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(dataAB2, recvAB2), "A->B second stream data mismatch") +} + +func TestConcurrentStreams(t *testing.T) { + const timeout = time.Second * 30 + const numStreams = 20 + const payloadSize = 256 + + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 1, 2, &dmsg.Config{MinSessions: 1})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 2) + clientA := clients[0] + clientB := clients[1] + + const port = uint16(100) + + lisA, err := clientA.Listen(port) + require.NoError(t, err) + t.Cleanup(func() { _ = lisA.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Accept connections in the background, collecting accepted streams. + type acceptedStream struct { + stream *dmsg.Stream + err error + } + acceptCh := make(chan acceptedStream, numStreams) + + go func() { + for i := 0; i < numStreams; i++ { + s, aErr := lisA.AcceptStream() + acceptCh <- acceptedStream{stream: s, err: aErr} + } + }() + + // Dial numStreams connections concurrently, each with unique data. + type streamResult struct { + idx int + sent []byte + stream *dmsg.Stream + err error + } + + results := make([]streamResult, numStreams) + var wg sync.WaitGroup + + for i := 0; i < numStreams; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + s, dErr := clientB.DialStream(ctx, dmsg.Addr{PK: clientA.LocalPK(), Port: port}) + if dErr != nil { + results[i] = streamResult{idx: i, err: dErr} + return + } + results[i] = streamResult{idx: i, stream: s, sent: cipher.RandByte(payloadSize)} + }() + } + wg.Wait() + + // Collect accepted streams. + accepted := make([]*dmsg.Stream, 0, numStreams) + for i := 0; i < numStreams; i++ { + a := <-acceptCh + require.NoError(t, a.err) + accepted = append(accepted, a.stream) + } + + // Send data on each dialed stream and read on the accepted side. + var sendWg sync.WaitGroup + for i := 0; i < numStreams; i++ { + require.NoError(t, results[i].err, "stream %d dial failed", i) + i := i + sendWg.Add(1) + go func() { + defer sendWg.Done() + _, wErr := results[i].stream.Write(results[i].sent) + assert.NoError(t, wErr, "stream %d write failed", i) + }() + } + + // Read from each accepted stream. Since we do not know which accepted stream + // corresponds to which dialed stream, we collect all received data and match. + recvData := make([][]byte, numStreams) + var recvWg sync.WaitGroup + for i := 0; i < numStreams; i++ { + i := i + recvWg.Add(1) + go func() { + defer recvWg.Done() + buf := make([]byte, payloadSize) + _, rErr := io.ReadFull(accepted[i], buf) + assert.NoError(t, rErr, "stream %d read failed", i) + recvData[i] = buf + }() + } + + sendWg.Wait() + recvWg.Wait() + + // Verify each sent payload appears exactly once in received data. + for i := 0; i < numStreams; i++ { + found := false + for j := 0; j < numStreams; j++ { + if bytes.Equal(results[i].sent, recvData[j]) { + found = true + break + } + } + require.True(t, found, "stream %d: sent data not found in any received data", i) + } + + // Cleanup streams. + for i := 0; i < numStreams; i++ { + _ = results[i].stream.Close() + _ = accepted[i].Close() + } +} + +func TestSessionReconnect(t *testing.T) { + const timeout = time.Second * 30 + + env := NewEnv(t, timeout) + // Start with 2 servers, 0 clients initially so we can track servers. + require.NoError(t, env.Startup(0, 2, 0, nil)) + t.Cleanup(env.Shutdown) + + servers := env.AllServers() + require.Len(t, servers, 2) + + // Create client that connects to both servers. + clientA, err := env.NewClient(&dmsg.Config{MinSessions: 2}) + require.NoError(t, err) + + clientB, err := env.NewClient(&dmsg.Config{MinSessions: 2}) + require.NoError(t, err) + + const port = uint16(100) + + lisB, err := clientB.Listen(port) + require.NoError(t, err) + t.Cleanup(func() { _ = lisB.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Verify streams work before closing a server. + stream1, err := clientA.DialStream(ctx, dmsg.Addr{PK: clientB.LocalPK(), Port: port}) + require.NoError(t, err) + + conn1, err := lisB.AcceptStream() + require.NoError(t, err) + + data1 := cipher.RandByte(64) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := stream1.Write(data1) + assert.NoError(t, wErr) + }() + + recv1 := make([]byte, 64) + _, err = io.ReadFull(conn1, recv1) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(data1, recv1)) + + _ = stream1.Close() + _ = conn1.Close() + + // Close one server. + err = servers[0].Close() + require.NoError(t, err) + + // Wait a moment for the client to handle the disconnection. + time.Sleep(time.Second * 2) + + // Verify streams still work via the second server. + stream2, err := clientA.DialStream(ctx, dmsg.Addr{PK: clientB.LocalPK(), Port: port}) + require.NoError(t, err) + t.Cleanup(func() { _ = stream2.Close() }) + + conn2, err := lisB.AcceptStream() + require.NoError(t, err) + t.Cleanup(func() { _ = conn2.Close() }) + + data2 := cipher.RandByte(64) + + wg.Add(1) + go func() { + defer wg.Done() + _, wErr := stream2.Write(data2) + assert.NoError(t, wErr) + }() + + recv2 := make([]byte, 64) + _, err = io.ReadFull(conn2, recv2) + require.NoError(t, err) + wg.Wait() + require.True(t, bytes.Equal(data2, recv2), "data mismatch after server closure") +} + +func TestListenerAcceptAll(t *testing.T) { + const timeout = time.Second * 30 + + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 1, 2, &dmsg.Config{MinSessions: 1})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 2) + clientA := clients[0] + clientB := clients[1] + + const port = uint16(100) + const numConns = 10 + + lisA, err := clientA.Listen(port) + require.NoError(t, err) + t.Cleanup(func() { _ = lisA.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Dial numConns connections and accept them all. + dialedStreams := make([]*dmsg.Stream, 0, numConns) + for i := 0; i < numConns; i++ { + s, dErr := clientB.DialStream(ctx, dmsg.Addr{PK: clientA.LocalPK(), Port: port}) + require.NoError(t, dErr, "dial %d should succeed", i) + dialedStreams = append(dialedStreams, s) + } + + // Accept all connections. + for i := 0; i < numConns; i++ { + conn, aErr := lisA.AcceptStream() + require.NoError(t, aErr, "accept %d should succeed", i) + require.NotNil(t, conn) + _ = conn.Close() + } + + // Cleanup dialed streams. + for _, s := range dialedStreams { + _ = s.Close() + } +} + +func TestPortOccupied(t *testing.T) { + const timeout = time.Second * 30 + + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 1, 1, &dmsg.Config{MinSessions: 1})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 1) + client := clients[0] + + const port = uint16(100) + + lis1, err := client.Listen(port) + require.NoError(t, err) + t.Cleanup(func() { _ = lis1.Close() }) + + // Second listen on same port should return ErrPortOccupied. + _, err = client.Listen(port) + require.Error(t, err) + require.ErrorIs(t, err, dmsg.ErrPortOccupied) +} + +func TestDialNonexistentClient(t *testing.T) { + const timeout = time.Second * 30 + + env := NewEnv(t, timeout) + require.NoError(t, env.Startup(0, 1, 1, &dmsg.Config{MinSessions: 1})) + t.Cleanup(env.Shutdown) + + clients := env.AllClients() + require.Len(t, clients, 1) + client := clients[0] + + // Generate a random PK that does not exist in discovery. + randomPK, _ := cipher.GenerateKeyPair() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, err := client.DialStream(ctx, dmsg.Addr{PK: randomPK, Port: 100}) + require.Error(t, err) + require.ErrorIs(t, err, dmsg.ErrDiscEntryNotFound, + fmt.Sprintf("expected ErrDiscEntryNotFound, got: %v", err)) +} diff --git a/pkg/ioutil/buf_read_test.go b/pkg/ioutil/buf_read_test.go new file mode 100644 index 000000000..867291b3f --- /dev/null +++ b/pkg/ioutil/buf_read_test.go @@ -0,0 +1,71 @@ +package ioutil + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBufRead_ExactFit(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte("hello") + p := make([]byte, 5) + + n, err := BufRead(buf, data, p) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, []byte("hello"), p) + assert.Equal(t, 0, buf.Len()) +} + +func TestBufRead_ShortP(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte("hello world") + p := make([]byte, 5) + + n, err := BufRead(buf, data, p) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, []byte("hello"), p) + assert.Equal(t, " world", buf.String()) +} + +func TestBufRead_EmptyData(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte{} + p := make([]byte, 10) + + n, err := BufRead(buf, data, p) + require.NoError(t, err) + assert.Equal(t, 0, n) + assert.Equal(t, 0, buf.Len()) +} + +func TestBufRead_EmptyP(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte("hello") + p := make([]byte, 0) + + n, err := BufRead(buf, data, p) + require.NoError(t, err) + assert.Equal(t, 0, n) + assert.Equal(t, "hello", buf.String()) +} + +func TestBufRead_LargeData(t *testing.T) { + buf := new(bytes.Buffer) + data := make([]byte, 10000) + for i := range data { + data[i] = byte(i % 256) + } + p := make([]byte, 100) + + n, err := BufRead(buf, data, p) + require.NoError(t, err) + assert.Equal(t, 100, n) + assert.Equal(t, data[:100], p) + assert.Equal(t, 9900, buf.Len()) + assert.Equal(t, data[100:], buf.Bytes()) +} diff --git a/pkg/noise/nonce_test.go b/pkg/noise/nonce_test.go new file mode 100644 index 000000000..fa416512c --- /dev/null +++ b/pkg/noise/nonce_test.go @@ -0,0 +1,123 @@ +package noise + +import ( + "testing" + + "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newHandshakedPair creates an initiator/responder Noise pair that have +// completed the KK handshake and are ready for encrypt/decrypt. +func newHandshakedPair(t *testing.T) (initiator, responder *Noise) { + t.Helper() + + pkI, skI := cipher.GenerateKeyPair() + pkR, skR := cipher.GenerateKeyPair() + + var err error + initiator, err = KKAndSecp256k1(Config{ + LocalPK: pkI, + LocalSK: skI, + RemotePK: pkR, + Initiator: true, + }) + require.NoError(t, err) + + responder, err = KKAndSecp256k1(Config{ + LocalPK: pkR, + LocalSK: skR, + RemotePK: pkI, + Initiator: false, + }) + require.NoError(t, err) + + // -> e, es + msg, err := initiator.MakeHandshakeMessage() + require.NoError(t, err) + require.NoError(t, responder.ProcessHandshakeMessage(msg)) + + // <- e, ee + msg, err = responder.MakeHandshakeMessage() + require.NoError(t, err) + require.NoError(t, initiator.ProcessHandshakeMessage(msg)) + + require.True(t, initiator.HandshakeFinished()) + require.True(t, responder.HandshakeFinished()) + + return initiator, responder +} + +func TestDecryptWithNonceMap_ReplayPrevention(t *testing.T) { + nI, nR := newHandshakedPair(t) + + nm := make(NonceMap) + ciphertext := nI.EncryptUnsafe([]byte("secret")) + + // First decryption should succeed. + plaintext, err := nR.DecryptWithNonceMap(nm, ciphertext) + require.NoError(t, err) + assert.Equal(t, []byte("secret"), plaintext) + + // Replaying the same ciphertext should fail with "repeated" error. + _, err = nR.DecryptWithNonceMap(nm, ciphertext) + require.Error(t, err) + assert.Contains(t, err.Error(), "repeated") +} + +func TestDecryptWithNonceMap_OutOfOrder(t *testing.T) { + nI, nR := newHandshakedPair(t) + + nm := make(NonceMap) + + ct1 := nI.EncryptUnsafe([]byte("msg1")) + ct2 := nI.EncryptUnsafe([]byte("msg2")) + ct3 := nI.EncryptUnsafe([]byte("msg3")) + + // Decrypt out of order: 3, 1, 2. + pt, err := nR.DecryptWithNonceMap(nm, ct3) + require.NoError(t, err) + assert.Equal(t, []byte("msg3"), pt) + + pt, err = nR.DecryptWithNonceMap(nm, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("msg1"), pt) + + pt, err = nR.DecryptWithNonceMap(nm, ct2) + require.NoError(t, err) + assert.Equal(t, []byte("msg2"), pt) +} + +func TestEncryptDecrypt_Roundtrip(t *testing.T) { + nI, nR := newHandshakedPair(t) + + messages := []string{"hello", "world", "a]b[c{d}e"} + for _, msg := range messages { + ct := nI.EncryptUnsafe([]byte(msg)) + pt, err := nR.DecryptUnsafe(ct) + require.NoError(t, err) + assert.Equal(t, []byte(msg), pt) + } + + // Also test responder -> initiator direction. + ct := nR.EncryptUnsafe([]byte("reverse")) + pt, err := nI.DecryptUnsafe(ct) + require.NoError(t, err) + assert.Equal(t, []byte("reverse"), pt) +} + +func TestEncryptDecrypt_LargePayload(t *testing.T) { + nI, nR := newHandshakedPair(t) + + // 64 KiB payload. + payload := make([]byte, 64*1024) + for i := range payload { + payload[i] = byte(i % 251) + } + + ct := nI.EncryptUnsafe(payload) + pt, err := nR.DecryptUnsafe(ct) + require.NoError(t, err) + assert.Equal(t, payload, pt) +}