diff --git a/server-packet.go b/server-packet.go index 1347d43..6d4f65b 100644 --- a/server-packet.go +++ b/server-packet.go @@ -46,6 +46,9 @@ type PacketServer struct { // This should only be set to true for debugging purposes. InsecureSkipVerify bool + // Don't block dupplicated requests (retransmissions). + AllowRetransmission bool + // ErrorLog specifies an optional logger for errors // around packet accepting, processing, and validation. // If nil, logging is done via the log package's standard logger. @@ -167,30 +170,31 @@ func (s *PacketServer) Serve(conn net.PacketConn) error { return } - key := requestKey{ - IP: remoteAddr.String(), - Identifier: packet.Identifier, - } - - requestsLock.Lock() - if _, ok := requests[key]; ok { + if !s.AllowRetransmission { + key := requestKey{ + IP: remoteAddr.String(), + Identifier: packet.Identifier, + } + requestsLock.Lock() + if _, ok := requests[key]; ok { + requestsLock.Unlock() + return + } + requests[key] = struct{}{} requestsLock.Unlock() - return - } - requests[key] = struct{}{} - requestsLock.Unlock() + //clean up afterwards + defer func() { + requestsLock.Lock() + delete(requests, key) + requestsLock.Unlock() + }() + } response := packetResponseWriter{ conn: conn, addr: remoteAddr, } - defer func() { - requestsLock.Lock() - delete(requests, key) - requestsLock.Unlock() - }() - request := Request{ LocalAddr: conn.LocalAddr(), RemoteAddr: remoteAddr, diff --git a/server-packet_test.go b/server-packet_test.go index 7497aaa..440234c 100644 --- a/server-packet_test.go +++ b/server-packet_test.go @@ -186,3 +186,116 @@ func TestPacketServer_singleUse(t *testing.T) { t.Fatalf("got err %v; expecting ErrServerShutdown", err) } } + +func TestPacketServer_AllowRetransmission(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + pc, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + + secret := []byte("123456790") + var receivedRequests = 0 + var identifiers = make(map[byte]struct{}) + server := PacketServer{ + SecretSource: StaticSecretSource(secret), + AllowRetransmission: true, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) { + receivedRequests++ + if _, ok := identifiers[r.Identifier]; ok { + return + } + identifiers[r.Identifier] = struct{}{} + time.Sleep(time.Millisecond * 200) + w.Write(r.Response(CodeAccessReject)) + }), + } + + var clientErr error + go func(rr int) { + defer server.Shutdown(context.Background()) + + packet := New(CodeAccessRequest, secret) + client := Client{ + Retry: time.Millisecond * 10, + } + response, err := client.Exchange(context.Background(), packet, pc.LocalAddr().String()) + if err != nil { + clientErr = err + return + } + if response.Code != CodeAccessReject { + clientErr = fmt.Errorf("got response code %v; expecting CodeAccessReject", response.Code) + } + if receivedRequests < 2 { + clientErr = fmt.Errorf("got %d requests; expecting at least 2", receivedRequests) + } + }(receivedRequests) + + if err := server.Serve(pc); err != ErrServerShutdown { + t.Fatal(err) + } + + server.Shutdown(context.Background()) + if clientErr != nil { + t.Fatal(clientErr) + } +} + +func TestPacketServer_BlockRetransmission(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + pc, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + var receivedRequests = 0 + var identifiers = make(map[byte]struct{}) + secret := []byte("123456790") + server := PacketServer{ + SecretSource: StaticSecretSource(secret), + Handler: HandlerFunc(func(w ResponseWriter, r *Request) { + receivedRequests++ + if _, ok := identifiers[r.Identifier]; ok { + return + } + time.Sleep(time.Millisecond * 500) + w.Write(r.Response(CodeAccessReject)) + }), + } + + var clientErr error + go func(rr int) { + defer server.Shutdown(context.Background()) + + packet := New(CodeAccessRequest, secret) + client := Client{ + Retry: time.Millisecond * 10, + } + response, err := client.Exchange(context.Background(), packet, pc.LocalAddr().String()) + if err != nil { + clientErr = err + return + } + if response.Code != CodeAccessReject { + clientErr = fmt.Errorf("got response code %v; expecting CodeAccessReject", response.Code) + } + if receivedRequests != 1 { + clientErr = fmt.Errorf("got %d requests; expecting only 1", receivedRequests) + } + }(receivedRequests) + + if err := server.Serve(pc); err != ErrServerShutdown { + t.Fatal(err) + } + + server.Shutdown(context.Background()) + if clientErr != nil { + t.Fatal(clientErr) + } +}