From 5f0a26acdedaf0cf857fbfab1c52ba2044c72f55 Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 03:38:54 +0300 Subject: [PATCH 1/6] feat(socks5): implement SOCKS5 protocol support with tests --- socks5/consts.go | 48 ++++++ socks5/handshake_reply.go | 78 ++++++++++ socks5/handshake_reply_test.go | 78 ++++++++++ socks5/handshake_request.go | 93 +++++++++++ socks5/handshake_request_test.go | 98 ++++++++++++ socks5/reply.go | 244 +++++++++++++++++++++++++++++ socks5/reply_test.go | 171 +++++++++++++++++++++ socks5/request.go | 252 ++++++++++++++++++++++++++++++ socks5/request_test.go | 140 +++++++++++++++++ socks5/udp_packet.go | 254 +++++++++++++++++++++++++++++++ socks5/udp_packet_test.go | 196 ++++++++++++++++++++++++ socks5/user_pass_reply.go | 77 ++++++++++ socks5/user_pass_reply_test.go | 98 ++++++++++++ socks5/user_pass_request.go | 146 ++++++++++++++++++ socks5/user_pass_request_test.go | 115 ++++++++++++++ 15 files changed, 2088 insertions(+) create mode 100644 socks5/handshake_reply.go create mode 100644 socks5/handshake_reply_test.go create mode 100644 socks5/handshake_request.go create mode 100644 socks5/handshake_request_test.go create mode 100644 socks5/reply.go create mode 100644 socks5/reply_test.go create mode 100644 socks5/request.go create mode 100644 socks5/request_test.go create mode 100644 socks5/udp_packet.go create mode 100644 socks5/udp_packet_test.go create mode 100644 socks5/user_pass_reply.go create mode 100644 socks5/user_pass_reply_test.go create mode 100644 socks5/user_pass_request.go create mode 100644 socks5/user_pass_request_test.go diff --git a/socks5/consts.go b/socks5/consts.go index e69de29..2d4d01b 100644 --- a/socks5/consts.go +++ b/socks5/consts.go @@ -0,0 +1,48 @@ +package socks5 + +// Protocol version. +const ( + SocksVersion = 5 +) + +// Command codes (CMD) for client requests. +const ( + CmdConnect = 1 + CmdBind = 2 + CmdUDPAssociate = 3 + CmdResolve = 0xF0 + CmdResolvePTR = 0xF1 +) + +// Address types (ATYP) used in requests and responses. +const ( + AddrTypeIPv4 = 1 + AddrTypeDomain = 3 + AddrTypeIPv6 = 4 +) + +// Reply codes (REP) for server responses. +const ( + RepSuccess = 0 + RepGeneralFailure = 1 + RepConnectionNotAllowed = 2 + RepNetworkUnreachable = 3 + RepHostUnreachable = 4 + RepConnectionRefused = 5 + RepTTLExpired = 6 + RepCommandNotSupported = 7 + RepAddrTypeNotSupported = 8 +) + +// Authentication methods (METHOD) for initial greeting. +const ( + MethodNoAuth = 0x00 + MethodGSSAPI = 0x01 + MethodUserPass = 0x02 + MethodNoAcceptable = 0xFF +) + +// Authentication sub-negotiation versions. +const ( + AuthVersionUserPass = 1 +) diff --git a/socks5/handshake_reply.go b/socks5/handshake_reply.go new file mode 100644 index 0000000..df3bb80 --- /dev/null +++ b/socks5/handshake_reply.go @@ -0,0 +1,78 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" +) + +// Errors for SOCKS5 handshake replies. +var ( + ErrInvalidHandshakeReplyVersion = errors.New("invalid SOCKS version in handshake reply (must be 5)") +) + +// HandshakeReply represents the server’s response to a SOCKS5 handshake request. +type HandshakeReply struct { + Version byte // VER (should always be 0x05) + Method byte // METHOD; selected authentication method +} + +// Init initializes a handshake reply with the given method. +func (h *HandshakeReply) Init(method byte) { + h.Version = SocksVersion + h.Method = method +} + +// Validate ensures the handshake reply is valid. +func (h *HandshakeReply) Validate() error { + if h.Version != SocksVersion { + return ErrInvalidHandshakeReplyVersion + } + return nil +} + +// ReadFrom reads a SOCKS5 handshake reply from an io.Reader. +// Implements io.ReaderFrom. +func (h *HandshakeReply) ReadFrom(src io.Reader) (int64, error) { + var buf [2]byte + + n, err := io.ReadFull(src, buf[:]) + if err != nil { + return int64(n), err + } + + h.Version = buf[0] + h.Method = buf[1] + + return int64(n), h.Validate() +} + +// WriteTo writes the handshake reply to an io.Writer. +// Implements io.WriterTo. +func (h *HandshakeReply) WriteTo(dst io.Writer) (int64, error) { + buf := [2]byte{h.Version, h.Method} + n, err := dst.Write(buf[:]) + return int64(n), err +} + +// String returns a human-readable representation of the handshake reply. +func (h *HandshakeReply) String() string { + var method string + switch h.Method { + case MethodNoAuth: + method = "NoAuth" + case MethodGSSAPI: + method = "GSSAPI" + case MethodUserPass: + method = "UserPass" + case MethodNoAcceptable: + method = "NoAcceptable" + default: + method = fmt.Sprintf("Unknown(0x%02x)", h.Method) + } + + return fmt.Sprintf( + "SOCKS5 HandshakeReply{Version=%d, Method=%s}", + h.Version, method, + ) +} diff --git a/socks5/handshake_reply_test.go b/socks5/handshake_reply_test.go new file mode 100644 index 0000000..96e7b22 --- /dev/null +++ b/socks5/handshake_reply_test.go @@ -0,0 +1,78 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_HandshakeReply_Init_And_Validate(t *testing.T) { + h := &socks5.HandshakeReply{} + h.Init(socks5.MethodUserPass) + + if err := h.Validate(); err != nil { + t.Fatalf("expected valid reply, got %v", err) + } + + h.Version = 4 + if err := h.Validate(); !errors.Is(err, socks5.ErrInvalidHandshakeReplyVersion) { + t.Errorf("expected ErrInvalidHandshakeReplyVersion, got %v", err) + } +} + +func Test_HandshakeReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.HandshakeReply{} + orig.Init(socks5.MethodNoAuth) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.HandshakeReply + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes, got %d", n1, n2) + } + if parsed.Method != orig.Method { + t.Errorf("expected method 0x%02x, got 0x%02x", orig.Method, parsed.Method) + } +} + +func Test_HandshakeReply_ReadFrom_Truncated(t *testing.T) { + data := []byte{5} // incomplete + var h socks5.HandshakeReply + if _, err := h.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected EOF for truncated reply") + } +} + +func Test_HandshakeReply_WriteTo_ErrorPropagation(t *testing.T) { + h := &socks5.HandshakeReply{} + h.Init(socks5.MethodUserPass) + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := h.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} + +func Test_HandshakeReply_String(t *testing.T) { + h := &socks5.HandshakeReply{} + h.Init(socks5.MethodNoAuth) + + if s := h.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/handshake_request.go b/socks5/handshake_request.go new file mode 100644 index 0000000..bef2d5d --- /dev/null +++ b/socks5/handshake_request.go @@ -0,0 +1,93 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" +) + +// Errors for SOCKS5 handshake requests. +var ( + ErrInvalidHandshakeVersion = errors.New("invalid SOCKS version (must be 5)") + ErrTooManyMethods = errors.New("too many authentication methods") + ErrNoMethodsProvided = errors.New("no authentication methods provided") +) + +// HandshakeRequest represents the initial SOCKS5 client handshake (method negotiation). +type HandshakeRequest struct { + Version byte // VER (should always be 0x05) + NMethods byte // NMETHODS; number of methods + Methods []byte // METHODS; list of supported methods +} + +// Init initializes a handshake request with the given methods. +func (h *HandshakeRequest) Init(methods ...byte) { + h.Version = SocksVersion + h.NMethods = byte(len(methods)) + h.Methods = append([]byte(nil), methods...) // copy +} + +// Validate ensures the handshake request is structurally valid. +func (h *HandshakeRequest) Validate() error { + if h.Version != SocksVersion { + return ErrInvalidHandshakeVersion + } + if h.NMethods == 0 { + return ErrNoMethodsProvided + } + if len(h.Methods) != int(h.NMethods) { + return ErrTooManyMethods + } + return nil +} + +// ReadFrom reads a SOCKS5 handshake request from an io.Reader. +// Implements io.ReaderFrom. +func (h *HandshakeRequest) ReadFrom(src io.Reader) (int64, error) { + var hdr [2]byte + + n, err := io.ReadFull(src, hdr[:]) + if err != nil { + return int64(n), err + } + + h.Version = hdr[0] + h.NMethods = hdr[1] + + if h.NMethods == 0 { + return int64(n), ErrNoMethodsProvided + } + + methods := make([]byte, h.NMethods) + n2, err := io.ReadFull(src, methods) + total := int64(n + n2) + if err != nil { + return total, err + } + + h.Methods = methods + return total, h.Validate() +} + +// WriteTo writes the handshake request to an io.Writer. +// Implements io.WriterTo. +func (h *HandshakeRequest) WriteTo(dst io.Writer) (int64, error) { + buf := []byte{h.Version, h.NMethods} + n, err := dst.Write(buf) + total := int64(n) + if err != nil { + return total, err + } + + n2, err := dst.Write(h.Methods) + total += int64(n2) + return total, err +} + +// String returns a human-readable representation of the handshake request. +func (h *HandshakeRequest) String() string { + return fmt.Sprintf( + "SOCKS5 HandshakeRequest{Version=%d, Methods=%v}", + h.Version, h.Methods, + ) +} diff --git a/socks5/handshake_request_test.go b/socks5/handshake_request_test.go new file mode 100644 index 0000000..7e2f629 --- /dev/null +++ b/socks5/handshake_request_test.go @@ -0,0 +1,98 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_HandshakeRequest_Init_And_Validate(t *testing.T) { + r := &socks5.HandshakeRequest{} + r.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid request, got %v", err) + } + + r.Version = 4 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidHandshakeVersion) { + t.Errorf("expected ErrInvalidHandshakeVersion, got %v", err) + } + + r.Version = socks5.SocksVersion + r.NMethods = 0 + if err := r.Validate(); !errors.Is(err, socks5.ErrNoMethodsProvided) { + t.Errorf("expected ErrNoMethodsProvided, got %v", err) + } +} + +func Test_HandshakeRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.HandshakeRequest{} + orig.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.HandshakeRequest + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if parsed.Version != socks5.SocksVersion { + t.Errorf("expected version %d, got %d", socks5.SocksVersion, parsed.Version) + } + if len(parsed.Methods) != len(orig.Methods) { + t.Fatalf("expected %d methods, got %d", len(orig.Methods), len(parsed.Methods)) + } + for i, m := range parsed.Methods { + if m != orig.Methods[i] { + t.Errorf("method[%d]: expected 0x%02x, got 0x%02x", i, orig.Methods[i], m) + } + } +} + +func Test_HandshakeRequest_ReadFrom_Truncated(t *testing.T) { + data := []byte{5, 2, 0x00} // NMETHODS=2 but only 1 method byte present + r := &socks5.HandshakeRequest{} + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected error for truncated handshake") + } +} + +func Test_HandshakeRequest_WriteTo_ErrorPropagation(t *testing.T) { + r := &socks5.HandshakeRequest{} + r.Init(socks5.MethodNoAuth) + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := r.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} + +func Test_HandshakeRequest_String(t *testing.T) { + r := &socks5.HandshakeRequest{} + r.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} + +// helper type to simulate write errors. + +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } diff --git a/socks5/reply.go b/socks5/reply.go new file mode 100644 index 0000000..90f0898 --- /dev/null +++ b/socks5/reply.go @@ -0,0 +1,244 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" +) + +// Common validation errors for replies. +var ( + ErrInvalidReplyVersion = errors.New("invalid SOCKS version in reply (must be 5)") + ErrInvalidReplyRSV = errors.New("invalid reserved byte in reply (must be 0x00)") + ErrInvalidReplyAddr = errors.New("invalid address or address type in reply") + ErrInvalidReplyDomain = errors.New("invalid domain in reply (empty or too long)") +) + +// Reply represents a SOCKS5 server reply. +type Reply struct { + Version byte // VER; SOCKS protocol version (always 5) + Reply byte // REP; reply code + Reserved byte // RSV; must be 0x00 + AddrType byte // ATYP; address type (IPv4, DOMAIN, IPv6) + IP net.IP // BND.ADDR; Bound IP (if IPv4/IPv6) + Domain string // BND.ADDR; Bound domain (if ATYP=DOMAIN) + Port uint16 // BND.PORT; Bound port +} + +// Init initializes a SOCKS5 reply. +func (r *Reply) Init(version, rep, reserved, addrType byte, ip net.IP, domain string, port uint16) { + r.Version = version + r.Reply = rep + r.Reserved = reserved + r.AddrType = addrType + r.IP = ip + r.Domain = domain + r.Port = port +} + +// GetHost returns the bound host (domain or IP string). +func (r *Reply) GetHost() string { + if r.AddrType == AddrTypeDomain { + return r.Domain + } + return r.IP.String() +} + +// Addr returns a combined "host:port" string. +func (r *Reply) Addr() string { + return net.JoinHostPort(r.GetHost(), fmt.Sprint(r.Port)) +} + +// ValidateHeader validates the reply header fields. +func (r *Reply) ValidateHeader() error { + if r.Version != SocksVersion { + return ErrInvalidReplyVersion + } + if r.Reserved != 0x00 { + return ErrInvalidReplyRSV + } + switch r.AddrType { + case AddrTypeIPv4, AddrTypeDomain, AddrTypeIPv6: + default: + return ErrInvalidReplyAddr + } + return nil +} + +// Validate validates the full reply. +func (r *Reply) Validate() error { + if err := r.ValidateHeader(); err != nil { + return err + } + switch r.AddrType { + case AddrTypeDomain: + if len(r.Domain) == 0 || len(r.Domain) > 255 { + return ErrInvalidReplyDomain + } + case AddrTypeIPv4, AddrTypeIPv6: + if r.IP == nil { + return ErrInvalidReplyAddr + } + } + return nil +} + +// ReadFrom reads a SOCKS5 reply from a Reader. +// Implements io.ReaderFrom. +func (r *Reply) ReadFrom(src io.Reader) (int64, error) { + var ( + hdr [4]byte + total int64 + ) + + n, err := io.ReadFull(src, hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + r.Version = hdr[0] + r.Reply = hdr[1] + r.Reserved = hdr[2] + r.AddrType = hdr[3] + + if err := r.ValidateHeader(); err != nil { + return total, err + } + + switch r.AddrType { + case AddrTypeIPv4: + var ip [4]byte + n, err = io.ReadFull(src, ip[:]) + total += int64(n) + if err != nil { + return total, err + } + r.IP = net.IP(ip[:]) + + case AddrTypeIPv6: + var ip [16]byte + n, err = io.ReadFull(src, ip[:]) + total += int64(n) + if err != nil { + return total, err + } + r.IP = net.IP(ip[:]) + + case AddrTypeDomain: + var ln [1]byte + n, err = io.ReadFull(src, ln[:]) + total += int64(n) + if err != nil { + return total, err + } + buf := make([]byte, ln[0]) + n, err = io.ReadFull(src, buf) + total += int64(n) + if err != nil { + return total, err + } + r.Domain = string(buf) + } + + var portBuf [2]byte + n, err = io.ReadFull(src, portBuf[:]) + total += int64(n) + if err != nil { + return total, err + } + r.Port = binary.BigEndian.Uint16(portBuf[:]) + + return total, r.Validate() +} + +// WriteTo writes a SOCKS5 reply to a Writer. +// Implements io.WriterTo. +// Note: returns error if domain length is invalid. +func (r *Reply) WriteTo(dst io.Writer) (int64, error) { + if r.AddrType == AddrTypeDomain { + domainLen := len(r.Domain) + if domainLen == 0 || domainLen > 255 { + return 0, ErrInvalidReplyDomain + } + } + + hdr := [4]byte{r.Version, r.Reply, r.Reserved, r.AddrType} + total := int64(0) + + n, err := dst.Write(hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + switch r.AddrType { + case AddrTypeIPv4: + n, err = dst.Write(r.IP.To4()) + case AddrTypeIPv6: + n, err = dst.Write(r.IP.To16()) + case AddrTypeDomain: + n, err = dst.Write([]byte{byte(len(r.Domain))}) + total += int64(n) + if err == nil { + n, err = io.WriteString(dst, r.Domain) + } + } + total += int64(n) + if err != nil { + return total, err + } + + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], r.Port) + n, err = dst.Write(portBuf[:]) + total += int64(n) + + return total, err +} + +// String returns a human-readable representation of the reply. +func (r *Reply) String() string { + var rep string + switch r.Reply { + case RepSuccess: + rep = "SUCCESS" + case RepGeneralFailure: + rep = "GENERAL_FAILURE" + case RepConnectionNotAllowed: + rep = "CONNECTION_NOT_ALLOWED" + case RepNetworkUnreachable: + rep = "NETWORK_UNREACHABLE" + case RepHostUnreachable: + rep = "HOST_UNREACHABLE" + case RepConnectionRefused: + rep = "CONNECTION_REFUSED" + case RepTTLExpired: + rep = "TTL_EXPIRED" + case RepCommandNotSupported: + rep = "COMMAND_NOT_SUPPORTED" + case RepAddrTypeNotSupported: + rep = "ADDR_TYPE_NOT_SUPPORTED" + default: + rep = fmt.Sprintf("UNKNOWN(0x%02X)", r.Reply) + } + + var atype string + switch r.AddrType { + case AddrTypeIPv4: + atype = "IPv4" + case AddrTypeDomain: + atype = "DOMAIN" + case AddrTypeIPv6: + atype = "IPv6" + default: + atype = fmt.Sprintf("0x%02X", r.AddrType) + } + + return fmt.Sprintf( + "SOCKS5 Reply{Reply=%s, AddrType=%s, Host=%s, Port=%d, Version=%d, RSV=%#02x}", + rep, atype, r.GetHost(), r.Port, r.Version, r.Reserved, + ) +} diff --git a/socks5/reply_test.go b/socks5/reply_test.go new file mode 100644 index 0000000..2c0833b --- /dev/null +++ b/socks5/reply_test.go @@ -0,0 +1,171 @@ +package socks5_test + +import ( + "bytes" + "net" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_Reply_Init_Validate(t *testing.T) { + tests := []struct { + name string + reply socks5.Reply + wantErr bool + }{ + { + name: "valid success IPv4", + reply: func() socks5.Reply { + var r socks5.Reply + r.Init(socks5.SocksVersion, socks5.RepSuccess, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 1080) + return r + }(), + wantErr: false, + }, + { + name: "invalid version", + reply: func() socks5.Reply { + var r socks5.Reply + r.Init(4, socks5.RepSuccess, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 1080) + return r + }(), + wantErr: true, + }, + { + name: "invalid RSV", + reply: func() socks5.Reply { + var r socks5.Reply + r.Init(5, socks5.RepSuccess, 0x01, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 1080) + return r + }(), + wantErr: true, + }, + { + name: "invalid ATYP", + reply: func() socks5.Reply { + var r socks5.Reply + r.Init(5, socks5.RepSuccess, 0x00, 0x99, net.IPv4(127, 0, 0, 1), "", 1080) + return r + }(), + wantErr: true, + }, + { + name: "invalid domain length", + reply: func() socks5.Reply { + var r socks5.Reply + r.Init(5, socks5.RepSuccess, 0x00, socks5.AddrTypeDomain, nil, "", 1080) + return r + }(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.reply.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr = %v", err, tt.wantErr) + } + }) + } +} + +func Test_Reply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + tests := []struct { + name string + init func() socks5.Reply + }{ + { + name: "IPv4", + init: func() socks5.Reply { + var r socks5.Reply + r.Init(socks5.SocksVersion, socks5.RepSuccess, 0x00, socks5.AddrTypeIPv4, net.IPv4(192, 168, 1, 10), "", 1080) + return r + }, + }, + { + name: "Domain", + init: func() socks5.Reply { + var r socks5.Reply + r.Init(socks5.SocksVersion, socks5.RepSuccess, 0x00, socks5.AddrTypeDomain, nil, "example.org", 443) + return r + }, + }, + { + name: "IPv6", + init: func() socks5.Reply { + var r socks5.Reply + ip := net.ParseIP("2001:db8::1") + r.Init(socks5.SocksVersion, socks5.RepSuccess, 0x00, socks5.AddrTypeIPv6, ip, "", 9050) + return r + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + orig := tt.init() + var buf bytes.Buffer + + nw, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo() failed: %v", err) + } + + var got socks5.Reply + nr, err := got.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom() failed: %v", err) + } + + if nw != nr { + t.Errorf("expected %d bytes written == %d bytes read", nw, nr) + } + if got.Reply != orig.Reply { + t.Errorf("reply mismatch: got %d, want %d", got.Reply, orig.Reply) + } + if got.Port != orig.Port { + t.Errorf("port mismatch: got %d, want %d", got.Port, orig.Port) + } + if got.AddrType == socks5.AddrTypeDomain && got.Domain != orig.Domain { + t.Errorf("domain mismatch: got %q, want %q", got.Domain, orig.Domain) + } + if (got.AddrType == socks5.AddrTypeIPv4 || got.AddrType == socks5.AddrTypeIPv6) && !got.IP.Equal(orig.IP) { + t.Errorf("IP mismatch: got %v, want %v", got.IP, orig.IP) + } + }) + } +} + +func Test_Reply_ReadFrom_InvalidData(t *testing.T) { + // incomplete 4-byte header + data := []byte{5, socks5.RepSuccess, 0x00} + var r socks5.Reply + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected error for truncated header") + } +} + +func Test_Reply_WriteTo_InvalidDomain(t *testing.T) { + var r socks5.Reply + longDomain := make([]byte, 300) + for i := range longDomain { + longDomain[i] = 'a' + } + r.Init(5, socks5.RepSuccess, 0x00, socks5.AddrTypeDomain, nil, string(longDomain), 1080) + + var buf bytes.Buffer + if _, err := r.WriteTo(&buf); err == nil { + t.Errorf("expected ErrInvalidReplyDomain, got nil") + } +} + +func Test_Reply_String(t *testing.T) { + r := &socks5.Reply{} + r.Init(5, socks5.RepHostUnreachable, 0x00, socks5.AddrTypeIPv4, net.IPv4(10, 0, 0, 2), "", 9999) + + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/request.go b/socks5/request.go new file mode 100644 index 0000000..a487bb1 --- /dev/null +++ b/socks5/request.go @@ -0,0 +1,252 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" +) + +// Common validation errors. +var ( + ErrInvalidVersion = errors.New("invalid SOCKS version (must be 5)") + ErrInvalidCommand = errors.New("invalid command (must be 1=CONNECT, 2=BIND, 3=UDP ASSOCIATE, F0=RESOLVE, or F1=RESOLVE_PTR)") + ErrInvalidAddr = errors.New("invalid address or address type") + ErrInvalidDomain = errors.New("invalid domain (empty or too long)") + ErrInvalidRSV = errors.New("invalid reserved byte (must be 0x00)") +) + +// Request represents a SOCKS5 CONNECT/BIND/UDP ASSOCIATE/RESOLVE request. +type Request struct { + Version byte // VER; SOCKS protocol version (always 5) + Command byte // CMD; CONNECT, BIND, UDP ASSOCIATE, RESOLVE, etc. + Reserved byte // RSV; reserved byte (must be 0x00) + AddrType byte // ATYP; address type (IPv4, DOMAIN, IPv6) + IP net.IP // BND.ADDR; Destination IP (IPv4 or IPv6) + Domain string // BND.ADDR; Destination domain (if ATYP=DOMAIN) + Port uint16 // BND.PORT; Destination port (big-endian) +} + +// GetHost returns the destination hostname or IP string. +func (r *Request) GetHost() string { + if r.AddrType == AddrTypeDomain { + return r.Domain + } + return r.IP.String() +} + +// Addr returns the full "host:port" string form. +func (r *Request) Addr() string { + return net.JoinHostPort(r.GetHost(), fmt.Sprint(r.Port)) +} + +// Init initializes a SOCKS5 request. +func (r *Request) Init( + version byte, + command byte, + reserved byte, + addrType byte, + ip net.IP, + domain string, + port uint16, +) { + r.Version = version + r.Command = command + r.Reserved = reserved + r.AddrType = addrType + r.IP = ip + r.Domain = domain + r.Port = port +} + +// ValidateHeader validates the SOCKS5 request header. +func (r *Request) ValidateHeader() error { + if r.Version != SocksVersion { + return ErrInvalidVersion + } + if r.Reserved != 0x00 { + return ErrInvalidRSV + } + switch r.Command { + case CmdConnect, CmdBind, CmdUDPAssociate, CmdResolve, CmdResolvePTR: + default: + return ErrInvalidCommand + } + switch r.AddrType { + case AddrTypeIPv4, AddrTypeDomain, AddrTypeIPv6: + default: + return ErrInvalidAddr + } + return nil +} + +// Validate validates the full SOCKS5 request. +func (r *Request) Validate() error { + if err := r.ValidateHeader(); err != nil { + return err + } + + switch r.AddrType { + case AddrTypeDomain: + if len(r.Domain) == 0 || len(r.Domain) > 255 { + return ErrInvalidDomain + } + return nil // domain is valid, IP may be nil + case AddrTypeIPv4, AddrTypeIPv6: + if r.IP == nil { + return ErrInvalidAddr + } + } + return nil +} + +// ReadFrom reads a SOCKS5 request from a Reader. +// Implements the io.ReaderFrom interface. +func (r *Request) ReadFrom(src io.Reader) (int64, error) { + var ( + total int64 + hdr [4]byte + ) + + n, err := io.ReadFull(src, hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + r.Version = hdr[0] + r.Command = hdr[1] + r.Reserved = hdr[2] + r.AddrType = hdr[3] + + if err := r.ValidateHeader(); err != nil { + return total, err + } + + switch r.AddrType { + case AddrTypeIPv4: + var buf [4]byte + n, err = io.ReadFull(src, buf[:]) + total += int64(n) + if err != nil { + return total, err + } + r.IP = net.IP(buf[:]) + + case AddrTypeIPv6: + var buf [16]byte + n, err = io.ReadFull(src, buf[:]) + total += int64(n) + if err != nil { + return total, err + } + r.IP = net.IP(buf[:]) + + case AddrTypeDomain: + var ln [1]byte + n, err = io.ReadFull(src, ln[:]) + total += int64(n) + if err != nil { + return total, err + } + buf := make([]byte, ln[0]) + n, err = io.ReadFull(src, buf) + total += int64(n) + if err != nil { + return total, err + } + r.Domain = string(buf) + } + + var portBuf [2]byte + n, err = io.ReadFull(src, portBuf[:]) + total += int64(n) + if err != nil { + return total, err + } + r.Port = binary.BigEndian.Uint16(portBuf[:]) + + return total, r.Validate() +} + +// WriteTo writes a SOCKS5 request to a Writer. +// Implements the io.WriterTo interface. +// Note: returns error if domain is too long. +func (r *Request) WriteTo(dst io.Writer) (int64, error) { + if r.AddrType == AddrTypeDomain { + domainLen := len(r.Domain) + if domainLen == 0 || domainLen > 255 { + return 0, ErrInvalidReplyDomain + } + } + + var total int64 + hdr := [4]byte{r.Version, r.Command, r.Reserved, r.AddrType} + + n, err := dst.Write(hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + switch r.AddrType { + case AddrTypeIPv4: + n, err = dst.Write(r.IP.To4()) + case AddrTypeIPv6: + n, err = dst.Write(r.IP.To16()) + case AddrTypeDomain: + n, err = dst.Write([]byte{byte(len(r.Domain))}) + total += int64(n) + if err == nil { + n, err = io.WriteString(dst, r.Domain) + } + } + total += int64(n) + if err != nil { + return total, err + } + + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], r.Port) + n, err = dst.Write(portBuf[:]) + total += int64(n) + + return total, err +} + +// String returns a string representation of the SOCKS5 Request. +func (r *Request) String() string { + var cmd string + switch r.Command { + case CmdConnect: + cmd = "CONNECT" + case CmdBind: + cmd = "BIND" + case CmdUDPAssociate: + cmd = "UDP_ASSOCIATE" + case CmdResolve: + cmd = "RESOLVE" + case CmdResolvePTR: + cmd = "RESOLVE_PTR" + default: + cmd = fmt.Sprintf("UNKNOWN(0x%02X)", r.Command) + } + + var atype string + switch r.AddrType { + case AddrTypeIPv4: + atype = "IPv4" + case AddrTypeDomain: + atype = "DOMAIN" + case AddrTypeIPv6: + atype = "IPv6" + default: + atype = fmt.Sprintf("0x%02X", r.AddrType) + } + + return fmt.Sprintf( + "SOCKS5 Request{Cmd=%s, AddrType=%s, Host=%s, Port=%d, Version=%d, RSV=%#02x}", + cmd, atype, r.GetHost(), r.Port, r.Version, r.Reserved, + ) +} diff --git a/socks5/request_test.go b/socks5/request_test.go new file mode 100644 index 0000000..9df7e31 --- /dev/null +++ b/socks5/request_test.go @@ -0,0 +1,140 @@ +package socks5_test + +import ( + "bytes" + "errors" + "net" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_Request_Init_And_Validate(t *testing.T) { + r := &socks5.Request{} + r.Init(socks5.SocksVersion, socks5.CmdConnect, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 8080) + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid request, got %v", err) + } + + r.Version = 4 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidVersion) { + t.Errorf("expected ErrInvalidVersion, got %v", err) + } +} + +func Test_Request_WriteTo_ReadFrom_RoundTrip_IPv4(t *testing.T) { + orig := &socks5.Request{} + orig.Init(socks5.SocksVersion, socks5.CmdConnect, 0x00, socks5.AddrTypeIPv4, net.IPv4(192, 168, 0, 10), "", 1080) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.Request + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if !parsed.IP.Equal(orig.IP) { + t.Errorf("expected IP %v, got %v", orig.IP, parsed.IP) + } + if parsed.Port != orig.Port { + t.Errorf("expected port %d, got %d", orig.Port, parsed.Port) + } +} + +func Test_Request_WriteTo_ReadFrom_RoundTrip_Domain(t *testing.T) { + orig := &socks5.Request{} + orig.Init(socks5.SocksVersion, socks5.CmdConnect, 0x00, socks5.AddrTypeDomain, nil, "example.com", 443) + + var buf bytes.Buffer + _, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.Request + _, err = parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if parsed.Domain != orig.Domain { + t.Errorf("expected domain %q, got %q", orig.Domain, parsed.Domain) + } + if parsed.Port != orig.Port { + t.Errorf("expected port %d, got %d", orig.Port, parsed.Port) + } +} + +func Test_Request_WriteTo_ReadFrom_RoundTrip_IPv6(t *testing.T) { + ip := net.ParseIP("2001:db8::1") + orig := &socks5.Request{} + orig.Init(socks5.SocksVersion, socks5.CmdUDPAssociate, 0x00, socks5.AddrTypeIPv6, ip, "", 9050) + + var buf bytes.Buffer + _, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.Request + _, err = parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if !parsed.IP.Equal(ip) { + t.Errorf("expected IP %v, got %v", ip, parsed.IP) + } + if parsed.Port != 9050 { + t.Errorf("expected port 9050, got %d", parsed.Port) + } +} + +func Test_Request_Validate_Invalid(t *testing.T) { + r := &socks5.Request{} + r.Init(5, 0x99, 0x00, socks5.AddrTypeIPv4, net.IPv4(1, 1, 1, 1), "", 80) + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidCommand) { + t.Errorf("expected ErrInvalidCommand, got %v", err) + } + + r.Init(5, socks5.CmdConnect, 0x01, socks5.AddrTypeIPv4, net.IPv4(1, 1, 1, 1), "", 80) + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidRSV) { + t.Errorf("expected ErrInvalidRSV, got %v", err) + } + + r.Init(5, socks5.CmdConnect, 0x00, socks5.AddrTypeDomain, nil, "", 80) + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidDomain) { + t.Errorf("expected ErrInvalidDomain, got %v", err) + } +} + +func Test_Request_ResolveCommands(t *testing.T) { + r := &socks5.Request{} + r.Init(5, socks5.CmdResolve, 0x00, socks5.AddrTypeDomain, nil, "example.com", 0) + if err := r.Validate(); err != nil { + t.Fatalf("expected valid RESOLVE request, got %v", err) + } + + r.Init(5, socks5.CmdResolvePTR, 0x00, socks5.AddrTypeIPv4, net.IPv4(8, 8, 8, 8), "", 0) + if err := r.Validate(); err != nil { + t.Fatalf("expected valid RESOLVE_PTR request, got %v", err) + } +} + +func Test_Request_String(t *testing.T) { + r := &socks5.Request{} + r.Init(socks5.SocksVersion, socks5.CmdConnect, 0x00, socks5.AddrTypeDomain, nil, "user.example.com", 8080) + + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/udp_packet.go b/socks5/udp_packet.go new file mode 100644 index 0000000..c2ea267 --- /dev/null +++ b/socks5/udp_packet.go @@ -0,0 +1,254 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" +) + +// Common validation errors for UDP packets. +var ( + ErrInvalidUDPReserved = errors.New("invalid UDP reserved bytes (must be 0x0000)") + ErrUnsupportedFrag = errors.New("unsupported UDP fragmentation (FRAG must be 0x00)") + ErrInvalidUDPAddrType = errors.New("invalid UDP address type") + ErrInvalidUDPDomain = errors.New("invalid UDP domain (empty or too long)") + ErrMissingUDPData = errors.New("missing UDP payload data") +) + +// UDPPacket represents a SOCKS5 UDP ASSOCIATE packet. +// Wire format (RFC 1928 §7): +// +// +----+------+------+----------+----------+----------+ +// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +// +----+------+------+----------+----------+----------+ +// | 2 | 1 | 1 | Variable | 2 | Variable | +type UDPPacket struct { + Reserved [2]byte // RSV; must be 0x0000 + Frag byte // FRAG; must be 0x00 (no fragmentation) + AddrType byte // ATYP; IPv4, DOMAIN, or IPv6 + IP net.IP // Destination IP (if ATYP=IPv4 or IPv6) + Domain string // Destination domain (if ATYP=DOMAIN) + Port uint16 // Destination port + Data []byte // UDP payload data +} + +// Init initializes a UDPPacket with given values. +func (p *UDPPacket) Init( + reserved [2]byte, + frag byte, + addrType byte, + ip net.IP, + domain string, + port uint16, + data []byte, +) { + p.Reserved = reserved + p.Frag = frag + p.AddrType = addrType + p.IP = ip + p.Domain = domain + p.Port = port + p.Data = data +} + +// Validate checks for protocol correctness. +func (p *UDPPacket) Validate() error { + if p.Reserved != [2]byte{0x00, 0x00} { + return ErrInvalidUDPReserved + } + if p.Frag != 0x00 { + return ErrUnsupportedFrag + } + + switch p.AddrType { + case AddrTypeIPv4, AddrTypeIPv6: + if p.IP == nil { + return ErrInvalidUDPAddrType + } + case AddrTypeDomain: + if len(p.Domain) == 0 || len(p.Domain) > 255 { + return ErrInvalidUDPDomain + } + default: + return ErrInvalidUDPAddrType + } + + if len(p.Data) == 0 { + return ErrMissingUDPData + } + + return nil +} + +// ReadFrom reads a UDP ASSOCIATE packet from a Reader. +// Implements io.ReaderFrom. +func (p *UDPPacket) ReadFrom(src io.Reader) (int64, error) { + var total int64 + + // Read RSV + FRAG + ATYP + var hdr [4]byte + n, err := io.ReadFull(src, hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + copy(p.Reserved[:], hdr[0:2]) + p.Frag = hdr[2] + p.AddrType = hdr[3] + + if err := p.ValidateHeader(); err != nil { + return total, err + } + + // Read DST.ADDR + switch p.AddrType { + case AddrTypeIPv4: + var ip [4]byte + n, err = io.ReadFull(src, ip[:]) + total += int64(n) + if err != nil { + return total, err + } + p.IP = net.IP(ip[:]) + + case AddrTypeIPv6: + var ip [16]byte + n, err = io.ReadFull(src, ip[:]) + total += int64(n) + if err != nil { + return total, err + } + p.IP = net.IP(ip[:]) + + case AddrTypeDomain: + var ln [1]byte + n, err = io.ReadFull(src, ln[:]) + total += int64(n) + if err != nil { + return total, err + } + buf := make([]byte, ln[0]) + n, err = io.ReadFull(src, buf) + total += int64(n) + if err != nil { + return total, err + } + p.Domain = string(buf) + } + + // Read DST.PORT + var portBuf [2]byte + n, err = io.ReadFull(src, portBuf[:]) + total += int64(n) + if err != nil { + return total, err + } + p.Port = binary.BigEndian.Uint16(portBuf[:]) + + // Remaining bytes are DATA + data, err := io.ReadAll(src) + total += int64(len(data)) + if err != nil { + return total, err + } + p.Data = data + + return total, p.Validate() +} + +// ValidateHeader checks RSV/FRAG/ATYP fields before full read. +func (p *UDPPacket) ValidateHeader() error { + if p.Reserved != [2]byte{0x00, 0x00} { + return ErrInvalidUDPReserved + } + if p.Frag != 0x00 { + return ErrUnsupportedFrag + } + switch p.AddrType { + case AddrTypeIPv4, AddrTypeIPv6, AddrTypeDomain: + default: + return ErrInvalidUDPAddrType + } + return nil +} + +// WriteTo writes a UDP ASSOCIATE packet to a Writer. +// Implements io.WriterTo. +func (p *UDPPacket) WriteTo(dst io.Writer) (int64, error) { + if err := p.Validate(); err != nil { + return 0, err + } + + var total int64 + + // Write RSV + FRAG + ATYP + hdr := [4]byte{p.Reserved[0], p.Reserved[1], p.Frag, p.AddrType} + n, err := dst.Write(hdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + switch p.AddrType { + case AddrTypeIPv4: + n, err = dst.Write(p.IP.To4()) + case AddrTypeIPv6: + n, err = dst.Write(p.IP.To16()) + case AddrTypeDomain: + dlen := len(p.Domain) + n, err = dst.Write([]byte{byte(dlen)}) + total += int64(n) + if err == nil { + n, err = io.WriteString(dst, p.Domain) + } + } + total += int64(n) + if err != nil { + return total, err + } + + // Write DST.PORT + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], p.Port) + n, err = dst.Write(portBuf[:]) + total += int64(n) + if err != nil { + return total, err + } + + // Write DATA + n, err = dst.Write(p.Data) + total += int64(n) + return total, err +} + +// String returns a human-readable representation. +func (p *UDPPacket) String() string { + var atype string + switch p.AddrType { + case AddrTypeIPv4: + atype = "IPv4" + case AddrTypeDomain: + atype = "DOMAIN" + case AddrTypeIPv6: + atype = "IPv6" + default: + atype = fmt.Sprintf("0x%02X", p.AddrType) + } + + return fmt.Sprintf( + "UDPPacket{AddrType=%s, Host=%s, Port=%d, DataLen=%d, Frag=%d, RSV=%#02x%#02x}", + atype, p.hostString(), p.Port, len(p.Data), p.Frag, p.Reserved[0], p.Reserved[1], + ) +} + +// hostString returns the effective destination host string. +func (p *UDPPacket) hostString() string { + if p.AddrType == AddrTypeDomain { + return p.Domain + } + return p.IP.String() +} diff --git a/socks5/udp_packet_test.go b/socks5/udp_packet_test.go new file mode 100644 index 0000000..6d875dc --- /dev/null +++ b/socks5/udp_packet_test.go @@ -0,0 +1,196 @@ +package socks5_test + +import ( + "bytes" + "errors" + "net" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_UDPPacket_Init_Validate(t *testing.T) { + tests := []struct { + name string + packet socks5.UDPPacket + wantErr bool + }{ + { + name: "valid IPv4 packet", + packet: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 9000, []byte("data")) + return p + }(), + wantErr: false, + }, + { + name: "invalid reserved bytes", + packet: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{1, 0}, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 9000, []byte("data")) + return p + }(), + wantErr: true, + }, + { + name: "invalid frag byte", + packet: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0x01, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 9000, []byte("data")) + return p + }(), + wantErr: true, + }, + { + name: "invalid address type", + packet: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0x00, 0x99, net.IPv4(127, 0, 0, 1), "", 9000, []byte("data")) + return p + }(), + wantErr: true, + }, + { + name: "missing data", + packet: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0x00, socks5.AddrTypeIPv4, net.IPv4(127, 0, 0, 1), "", 9000, nil) + return p + }(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.packet.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr = %v", err, tt.wantErr) + } + }) + } +} + +func Test_UDPPacket_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + tests := []struct { + name string + init func() socks5.UDPPacket + }{ + { + name: "IPv4", + init: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0, socks5.AddrTypeIPv4, net.IPv4(192, 168, 1, 100), "", 8080, []byte("hello")) + return p + }, + }, + { + name: "IPv6", + init: func() socks5.UDPPacket { + var p socks5.UDPPacket + ip := net.ParseIP("2001:db8::1") + p.Init([2]byte{0, 0}, 0, socks5.AddrTypeIPv6, ip, "", 9000, []byte("payload")) + return p + }, + }, + { + name: "Domain", + init: func() socks5.UDPPacket { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0, socks5.AddrTypeDomain, nil, "example.org", 53, []byte{0xaa, 0xbb}) + return p + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + orig := tt.init() + var buf bytes.Buffer + + nw, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo() failed: %v", err) + } + + var got socks5.UDPPacket + nr, err := got.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom() failed: %v", err) + } + + if nw != nr { + t.Errorf("expected %d bytes written == %d bytes read", nw, nr) + } + if got.Port != orig.Port { + t.Errorf("port mismatch: got %d, want %d", got.Port, orig.Port) + } + if got.AddrType == socks5.AddrTypeDomain { + if got.Domain != orig.Domain { + t.Errorf("domain mismatch: got %q, want %q", got.Domain, orig.Domain) + } + } else if !got.IP.Equal(orig.IP) { + t.Errorf("IP mismatch: got %v, want %v", got.IP, orig.IP) + } + if !bytes.Equal(got.Data, orig.Data) { + t.Errorf("data mismatch: got %x, want %x", got.Data, orig.Data) + } + }) + } +} + +func Test_UDPPacket_ReadFrom_InvalidRSV(t *testing.T) { + b := []byte{ + 0x01, 0x00, // RSV (invalid) + 0x00, // FRAG + socks5.AddrTypeIPv4, + 127, 0, 0, 1, + 0x1F, 0x90, // port 8080 + 'h', 'i', + } + + var p socks5.UDPPacket + if _, err := p.ReadFrom(bytes.NewReader(b)); !errors.Is(err, socks5.ErrInvalidUDPReserved) { + t.Errorf("expected ErrInvalidUDPReserved, got %v", err) + } +} + +func Test_UDPPacket_ReadFrom_InvalidFrag(t *testing.T) { + b := []byte{ + 0x00, 0x00, // RSV + 0x01, // FRAG (invalid) + socks5.AddrTypeIPv4, + 127, 0, 0, 1, + 0x1F, 0x90, + 'd', 'a', 't', 'a', + } + + var p socks5.UDPPacket + if _, err := p.ReadFrom(bytes.NewReader(b)); !errors.Is(err, socks5.ErrUnsupportedFrag) { + t.Errorf("expected ErrUnsupportedFrag, got %v", err) + } +} + +func Test_UDPPacket_ReadFrom_InvalidAddrType(t *testing.T) { + b := []byte{ + 0x00, 0x00, + 0x00, + 0x99, // invalid ATYP + 0x1F, 0x90, + } + + var p socks5.UDPPacket + if _, err := p.ReadFrom(bytes.NewReader(b)); !errors.Is(err, socks5.ErrInvalidUDPAddrType) { + t.Errorf("expected ErrInvalidUDPAddrType, got %v", err) + } +} + +func Test_UDPPacket_String(t *testing.T) { + var p socks5.UDPPacket + p.Init([2]byte{0, 0}, 0, socks5.AddrTypeIPv4, net.IPv4(8, 8, 8, 8), "", 53, []byte{0xaa, 0xbb}) + + if s := p.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/user_pass_reply.go b/socks5/user_pass_reply.go new file mode 100644 index 0000000..3165808 --- /dev/null +++ b/socks5/user_pass_reply.go @@ -0,0 +1,77 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" +) + +// Errors for username/password authentication replies. +var ( + ErrInvalidUserPassReplyVersion = errors.New("invalid user/password reply version (must be 1)") +) + +// UserPassReply represents a username/password authentication reply. +type UserPassReply struct { + Version byte // VER (should be AuthVersionUserPass = 0x01) + Status byte // STATUS (0x00 = success, otherwise failure) +} + +// Init initializes a user/password authentication reply with the given version and status. +func (r *UserPassReply) Init(version, status byte) { + r.Version = version + r.Status = status +} + +// Validate ensures the reply is structurally valid. +func (r *UserPassReply) Validate() error { + if r.Version != AuthVersionUserPass { + return ErrInvalidUserPassReplyVersion + } + return nil +} + +// ReadFrom reads a username/password authentication reply from an io.Reader. +// Implements io.ReaderFrom. +func (r *UserPassReply) ReadFrom(src io.Reader) (int64, error) { + var buf [2]byte + + n, err := io.ReadFull(src, buf[:]) + if err != nil { + return int64(n), err + } + + r.Version = buf[0] + r.Status = buf[1] + + return int64(n), r.Validate() +} + +// WriteTo writes the authentication reply to an io.Writer. +// Implements io.WriterTo. +// Note: assumes the struct is already valid. +func (r *UserPassReply) WriteTo(dst io.Writer) (int64, error) { + buf := [2]byte{r.Version, r.Status} + n, err := dst.Write(buf[:]) + return int64(n), err +} + +// Success returns true if STATUS == 0x00. +func (r *UserPassReply) Success() bool { + return r.Status == 0x00 +} + +// String returns a human-readable representation. +func (r *UserPassReply) String() string { + var status string + if r.Status == 0x00 { + status = "success" + } else { + status = fmt.Sprintf("failure(0x%02x)", r.Status) + } + + return fmt.Sprintf( + "UserPassReply{Version=%d, Status=%s}", + r.Version, status, + ) +} diff --git a/socks5/user_pass_reply_test.go b/socks5/user_pass_reply_test.go new file mode 100644 index 0000000..b3508e2 --- /dev/null +++ b/socks5/user_pass_reply_test.go @@ -0,0 +1,98 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_UserPassReply_Init_And_Validate(t *testing.T) { + r := &socks5.UserPassReply{} + r.Init(socks5.AuthVersionUserPass, 0x00) + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid reply, got %v", err) + } + + r.Version = 0x02 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidUserPassReplyVersion) { + t.Errorf("expected ErrInvalidUserPassReplyVersion, got %v", err) + } +} + +func Test_UserPassReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.UserPassReply{} + orig.Init(socks5.AuthVersionUserPass, 0x00) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.UserPassReply + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if parsed.Version != socks5.AuthVersionUserPass { + t.Errorf("expected version %d, got %d", socks5.AuthVersionUserPass, parsed.Version) + } + if parsed.Status != 0x00 { + t.Errorf("expected status 0x00, got 0x%02x", parsed.Status) + } + if !parsed.Success() { + t.Errorf("expected Success() to be true") + } +} + +func Test_UserPassReply_FailureStatus(t *testing.T) { + r := &socks5.UserPassReply{} + r.Init(socks5.AuthVersionUserPass, 0xFF) + + if r.Success() { + t.Errorf("expected Success() to be false for failure status") + } + + str := r.String() + if want := "failure(0xff)"; !bytes.Contains([]byte(str), []byte(want)) { + t.Errorf("expected String() to contain %q, got %q", want, str) + } +} + +func Test_UserPassReply_ReadFrom_Truncated(t *testing.T) { + data := []byte{1} // incomplete (missing STATUS) + var r socks5.UserPassReply + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected EOF for truncated reply") + } +} + +func Test_UserPassReply_WriteTo_ErrorPropagation(t *testing.T) { + r := &socks5.UserPassReply{} + r.Init(socks5.AuthVersionUserPass, 0x00) + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := r.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} + +func Test_UserPassReply_String(t *testing.T) { + r := &socks5.UserPassReply{} + r.Init(socks5.AuthVersionUserPass, 0x00) + + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/user_pass_request.go b/socks5/user_pass_request.go new file mode 100644 index 0000000..32a4989 --- /dev/null +++ b/socks5/user_pass_request.go @@ -0,0 +1,146 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" +) + +// Errors for username/password authentication requests. +var ( + ErrInvalidUserPassVersion = errors.New("invalid user/password auth version (must be 1)") + ErrEmptyUserPassUsername = errors.New("username cannot be empty") + ErrEmptyUserPassPassword = errors.New("password cannot be empty") + ErrUserPassTooLong = errors.New("username or password too long (max 255)") +) + +// UserPassRequest represents a username/password authentication request. +type UserPassRequest struct { + Version byte // VER (should always be AuthVersionUserPass = 0x01) + Username string // UNAME (1–255 bytes) + Password string // PASSWD (1–255 bytes) +} + +// Init initializes the authentication request with username and password. +func (r *UserPassRequest) Init(version byte, username, password string) { + r.Version = version + r.Username = username + r.Password = password +} + +// Validate checks for protocol correctness. +func (r *UserPassRequest) Validate() error { + if r.Version != AuthVersionUserPass { + return ErrInvalidUserPassVersion + } + if len(r.Username) == 0 { + return ErrEmptyUserPassUsername + } + if len(r.Password) == 0 { + return ErrEmptyUserPassPassword + } + if len(r.Username) > 255 || len(r.Password) > 255 { + return ErrUserPassTooLong + } + return nil +} + +// ReadFrom reads a username/password authentication request from a reader. +// Implements io.ReaderFrom. +func (r *UserPassRequest) ReadFrom(src io.Reader) (int64, error) { + var hdr [2]byte + + // Read VER and ULEN + n, err := io.ReadFull(src, hdr[:]) + if err != nil { + return int64(n), err + } + + r.Version = hdr[0] + ulen := int(hdr[1]) + if ulen == 0 { + return int64(n), ErrEmptyUserPassUsername + } + + // Read username + username := make([]byte, ulen) + n2, err := io.ReadFull(src, username) + total := int64(n + n2) + if err != nil { + return total, err + } + r.Username = string(username) + + // Read PLEN + var plen [1]byte + n3, err := io.ReadFull(src, plen[:]) + total += int64(n3) + if err != nil { + return total, err + } + + // Read password + pwlen := int(plen[0]) + if pwlen == 0 { + return total, ErrEmptyUserPassPassword + } + + password := make([]byte, pwlen) + n4, err := io.ReadFull(src, password) + total += int64(n4) + if err != nil { + return total, err + } + r.Password = string(password) + + return total, r.Validate() +} + +// WriteTo writes the username/password request to a writer. +// Implements io.WriterTo. +// Note: returns error if user or pass is too long. +func (r *UserPassRequest) WriteTo(dst io.Writer) (int64, error) { + if len(r.Username) > 255 || len(r.Password) > 255 { + return 0, ErrUserPassTooLong + } + + buf := []byte{ + r.Version, + byte(len(r.Username)), + } + total := int64(0) + + // Write version + ULEN + n, err := dst.Write(buf) + total += int64(n) + if err != nil { + return total, err + } + + // Write username + n, err = io.WriteString(dst, r.Username) + total += int64(n) + if err != nil { + return total, err + } + + // Write password header and body + pwHdr := [1]byte{byte(len(r.Password))} + n, err = dst.Write(pwHdr[:]) + total += int64(n) + if err != nil { + return total, err + } + + n, err = io.WriteString(dst, r.Password) + total += int64(n) + return total, err +} + +// String returns a human-readable representation. +func (r *UserPassRequest) String() string { + return fmt.Sprintf( + "UserPassRequest{Version=%d, Username=%q, PasswordLen=%d}", + r.Version, r.Username, len(r.Password), + ) +} diff --git a/socks5/user_pass_request_test.go b/socks5/user_pass_request_test.go new file mode 100644 index 0000000..38f239f --- /dev/null +++ b/socks5/user_pass_request_test.go @@ -0,0 +1,115 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_UserPassRequest_Init_And_Validate(t *testing.T) { + r := &socks5.UserPassRequest{} + r.Init(socks5.AuthVersionUserPass, "alice", "secret") + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid request, got %v", err) + } + + r.Version = 0x02 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidUserPassVersion) { + t.Errorf("expected ErrInvalidUserPassVersion, got %v", err) + } + + r.Version = socks5.AuthVersionUserPass + r.Username = "" + if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyUserPassUsername) { + t.Errorf("expected ErrEmptyUserPassUsername, got %v", err) + } + + r.Username = "bob" + r.Password = "" + if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyUserPassPassword) { + t.Errorf("expected ErrEmptyUserPassPassword, got %v", err) + } +} + +func Test_UserPassRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.UserPassRequest{} + orig.Init(socks5.AuthVersionUserPass, "admin", "hunter2") + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.UserPassRequest + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if parsed.Username != orig.Username { + t.Errorf("expected username %q, got %q", orig.Username, parsed.Username) + } + if parsed.Password != orig.Password { + t.Errorf("expected password %q, got %q", orig.Password, parsed.Password) + } + if parsed.Version != socks5.AuthVersionUserPass { + t.Errorf("expected version %d, got %d", socks5.AuthVersionUserPass, parsed.Version) + } +} + +func Test_UserPassRequest_ReadFrom_Truncated(t *testing.T) { + // missing password bytes + data := []byte{ + socks5.AuthVersionUserPass, + 3, 'b', 'o', 'b', + 5, 'p', 'a', 's', + } + r := &socks5.UserPassRequest{} + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected error for truncated payload") + } +} + +func Test_UserPassRequest_ReadFrom_EmptyUsernameOrPassword(t *testing.T) { + // empty username (ULEN = 0) + data := []byte{1, 0} + r := &socks5.UserPassRequest{} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyUserPassUsername) { + t.Errorf("expected ErrEmptyUserPassUsername, got %v", err) + } + + // empty password (PLEN = 0) + data = []byte{1, 3, 'b', 'o', 'b', 0} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyUserPassPassword) { + t.Errorf("expected ErrEmptyUserPassPassword, got %v", err) + } +} + +func Test_UserPassRequest_WriteTo_ErrorPropagation(t *testing.T) { + r := &socks5.UserPassRequest{} + r.Init(socks5.AuthVersionUserPass, "foo", "bar") + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := r.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} +func Test_UserPassRequest_String(t *testing.T) { + r := &socks5.UserPassRequest{} + r.Init(socks5.AuthVersionUserPass, "user", "pass") + + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} From cdee77b028afc35f355c1f69c943902155a53801 Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 04:05:22 +0300 Subject: [PATCH 2/6] gssapi and constants for socks5 --- socks5/consts.go | 51 +++++++------ socks5/gssapi_reply.go | 117 ++++++++++++++++++++++++++++ socks5/gssapi_reply_test.go | 138 ++++++++++++++++++++++++++++++++++ socks5/gssapi_request.go | 120 +++++++++++++++++++++++++++++ socks5/gssapi_request_test.go | 136 +++++++++++++++++++++++++++++++++ 5 files changed, 540 insertions(+), 22 deletions(-) create mode 100644 socks5/gssapi_reply.go create mode 100644 socks5/gssapi_reply_test.go create mode 100644 socks5/gssapi_request.go create mode 100644 socks5/gssapi_request_test.go diff --git a/socks5/consts.go b/socks5/consts.go index 2d4d01b..17598ab 100644 --- a/socks5/consts.go +++ b/socks5/consts.go @@ -7,42 +7,49 @@ const ( // Command codes (CMD) for client requests. const ( - CmdConnect = 1 - CmdBind = 2 - CmdUDPAssociate = 3 - CmdResolve = 0xF0 - CmdResolvePTR = 0xF1 + CmdConnect = 1 // Establish a TCP/IP stream connection + CmdBind = 2 // Establish a TCP/IP port binding + CmdUDPAssociate = 3 // Associate UDP relay + CmdResolve = 0xF0 // Name resolution (non-standard) + CmdResolvePTR = 0xF1 // Reverse lookup (non-standard) ) // Address types (ATYP) used in requests and responses. const ( - AddrTypeIPv4 = 1 - AddrTypeDomain = 3 - AddrTypeIPv6 = 4 + AddrTypeIPv4 = 1 // IPv4 address + AddrTypeDomain = 3 // Domain name + AddrTypeIPv6 = 4 // IPv6 address ) // Reply codes (REP) for server responses. const ( - RepSuccess = 0 - RepGeneralFailure = 1 - RepConnectionNotAllowed = 2 - RepNetworkUnreachable = 3 - RepHostUnreachable = 4 - RepConnectionRefused = 5 - RepTTLExpired = 6 - RepCommandNotSupported = 7 - RepAddrTypeNotSupported = 8 + RepSuccess = 0 // Request granted + RepGeneralFailure = 1 // General SOCKS server failure + RepConnectionNotAllowed = 2 // Connection not allowed by ruleset + RepNetworkUnreachable = 3 // Network unreachable + RepHostUnreachable = 4 // Host unreachable + RepConnectionRefused = 5 // Connection refused + RepTTLExpired = 6 // TTL expired + RepCommandNotSupported = 7 // Command not supported + RepAddrTypeNotSupported = 8 // Address type not supported ) // Authentication methods (METHOD) for initial greeting. const ( - MethodNoAuth = 0x00 - MethodGSSAPI = 0x01 - MethodUserPass = 0x02 - MethodNoAcceptable = 0xFF + MethodNoAuth = 0x00 // No authentication required + MethodGSSAPI = 0x01 // GSS-API authentication + MethodUserPass = 0x02 // Username/password authentication + MethodNoAcceptable = 0xFF // No acceptable methods ) // Authentication sub-negotiation versions. const ( - AuthVersionUserPass = 1 + AuthVersionUserPass = 1 // Username/password sub-negotiation version +) + +// GSS-API message types (MTYP) per RFC 1961. +const ( + GSSAPITypeInit = 0x01 // Client initial token + GSSAPITypeReply = 0x02 // Server reply token + GSSAPITypeAbort = 0xFF // Abort / failure ) diff --git a/socks5/gssapi_reply.go b/socks5/gssapi_reply.go new file mode 100644 index 0000000..001b9c1 --- /dev/null +++ b/socks5/gssapi_reply.go @@ -0,0 +1,117 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// Errors for GSSAPI authentication replies. +var ( + ErrInvalidGSSAPIReplyVersion = errors.New("invalid GSSAPI reply version (must be 1)") + ErrEmptyGSSAPIReplyToken = errors.New("GSSAPI reply token cannot be empty") + ErrGSSAPIReplyTooLong = errors.New("GSSAPI reply token too long (max 65535)") +) + +// GSSAPIReply represents a GSSAPI authentication reply message (RFC 1961 §3.7). +type GSSAPIReply struct { + Version byte // VER (should always be 0x01) + MsgType byte // MTYP (0x02 = reply token, 0xFF = failure) + Token []byte // TOKEN (optional; none if MTYP=0xFF) +} + +// Init initializes the GSSAPI reply. +func (r *GSSAPIReply) Init(version, msgType byte, token []byte) { + r.Version = version + r.MsgType = msgType + r.Token = token +} + +// Validate checks for protocol correctness. +func (r *GSSAPIReply) Validate() error { + if r.Version != 0x01 { + return ErrInvalidGSSAPIReplyVersion + } + if r.MsgType == GSSAPITypeAbort { + return nil + } + if len(r.Token) == 0 { + return ErrEmptyGSSAPIReplyToken + } + if len(r.Token) > 65535 { + return ErrGSSAPIReplyTooLong + } + return nil +} + +// ReadFrom reads a GSSAPI reply from a reader. +func (r *GSSAPIReply) ReadFrom(src io.Reader) (int64, error) { + var hdr [4]byte + n, err := io.ReadFull(src, hdr[:2]) + if err != nil { + return int64(n), err + } + + r.Version = hdr[0] + r.MsgType = hdr[1] + if r.MsgType == GSSAPITypeAbort { + return int64(n), nil + } + + n2, err := io.ReadFull(src, hdr[2:4]) + n += n2 + if err != nil { + return int64(n), err + } + length := binary.BigEndian.Uint16(hdr[2:4]) + if length == 0 { + return int64(n), ErrEmptyGSSAPIReplyToken + } + + token := make([]byte, length) + n3, err := io.ReadFull(src, token) + total := int64(n + n3) + if err != nil { + return total, err + } + r.Token = token + return total, r.Validate() +} + +// WriteTo writes the GSSAPI reply to a writer. +// NOTE: returns error if token length is too long. +func (r *GSSAPIReply) WriteTo(dst io.Writer) (int64, error) { + if len(r.Token) > 65535 { + return 0, ErrGSSAPIReplyTooLong + } + + if r.MsgType == GSSAPITypeAbort { + buf := [2]byte{r.Version, r.MsgType} + n, err := dst.Write(buf[:]) + return int64(n), err + } + + var hdr [4]byte + hdr[0] = r.Version + hdr[1] = r.MsgType + binary.BigEndian.PutUint16(hdr[2:], uint16(len(r.Token))) + + n, err := dst.Write(hdr[:]) + total := int64(n) + if err != nil { + return total, err + } + + n2, err := dst.Write(r.Token) + total += int64(n2) + return total, err +} + +// String returns a human-readable representation. +func (r *GSSAPIReply) String() string { + return fmt.Sprintf( + "GSSAPIReply{Version=%d, MsgType=0x%02x, TokenLen=%d}", + r.Version, r.MsgType, len(r.Token), + ) +} diff --git a/socks5/gssapi_reply_test.go b/socks5/gssapi_reply_test.go new file mode 100644 index 0000000..8d11753 --- /dev/null +++ b/socks5/gssapi_reply_test.go @@ -0,0 +1,138 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_GSSAPIReply_Init_And_Validate(t *testing.T) { + r := &socks5.GSSAPIReply{} + r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xca, 0xfe, 0xba, 0xbe}) + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid reply, got %v", err) + } + + r.Version = 0x02 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidGSSAPIReplyVersion) { + t.Errorf("expected ErrInvalidGSSAPIReplyVersion, got %v", err) + } + + // Empty token (non-abort) + r.Version = 0x01 + r.MsgType = socks5.GSSAPITypeReply + r.Token = nil + if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyGSSAPIReplyToken) { + t.Errorf("expected ErrEmptyGSSAPIReplyToken, got %v", err) + } + + // Abort message (should skip token validation) + r.MsgType = socks5.GSSAPITypeAbort + r.Token = nil + if err := r.Validate(); err != nil { + t.Errorf("abort message should be valid, got %v", err) + } + + // Token too long + r.MsgType = socks5.GSSAPITypeReply + r.Token = make([]byte, 70000) + if err := r.Validate(); !errors.Is(err, socks5.ErrGSSAPIReplyTooLong) { + t.Errorf("expected ErrGSSAPIReplyTooLong, got %v", err) + } +} + +func Test_GSSAPIReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.GSSAPIReply{} + orig.Init(0x01, socks5.GSSAPITypeReply, []byte{0xde, 0xad, 0xbe, 0xef}) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.GSSAPIReply + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if parsed.Version != 0x01 { + t.Errorf("expected version 1, got %d", parsed.Version) + } + if parsed.MsgType != socks5.GSSAPITypeReply { + t.Errorf("expected msgType 0x02, got %#02x", parsed.MsgType) + } + if !bytes.Equal(parsed.Token, orig.Token) { + t.Errorf("token mismatch: got %x, want %x", parsed.Token, orig.Token) + } +} + +func Test_GSSAPIReply_ReadFrom_Truncated(t *testing.T) { + data := []byte{ + 0x01, socks5.GSSAPITypeReply, 0x00, 0x04, // header: ver, mtyp, len=4 + 0xde, 0xad, // incomplete token + } + r := &socks5.GSSAPIReply{} + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected error for truncated payload") + } +} + +func Test_GSSAPIReply_ReadFrom_Abort(t *testing.T) { + data := []byte{0x01, socks5.GSSAPITypeAbort} + r := &socks5.GSSAPIReply{} + n, err := r.ReadFrom(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != 2 { + t.Errorf("expected 2 bytes read, got %d", n) + } + if r.MsgType != socks5.GSSAPITypeAbort { + t.Errorf("expected abort msgType 0xFF, got %#02x", r.MsgType) + } +} + +func Test_GSSAPIReply_ReadFrom_EmptyOrTooLong(t *testing.T) { + // empty token (len=0) + data := []byte{0x01, socks5.GSSAPITypeReply, 0x00, 0x00} + r := &socks5.GSSAPIReply{} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyGSSAPIReplyToken) { + t.Errorf("expected ErrEmptyGSSAPIReplyToken, got %v", err) + } + + // invalid version + data = []byte{0x05, socks5.GSSAPITypeReply, 0x00, 0x01, 0xff} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrInvalidGSSAPIReplyVersion) && err != nil { + t.Errorf("expected ErrInvalidGSSAPIReplyVersion, got %v", err) + } +} + +func Test_GSSAPIReply_WriteTo_ErrorPropagation(t *testing.T) { + r := &socks5.GSSAPIReply{} + r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xaa, 0xbb}) + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := r.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} + +func Test_GSSAPIReply_String(t *testing.T) { + r := &socks5.GSSAPIReply{} + r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xde, 0xad}) + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} diff --git a/socks5/gssapi_request.go b/socks5/gssapi_request.go new file mode 100644 index 0000000..4ead7a7 --- /dev/null +++ b/socks5/gssapi_request.go @@ -0,0 +1,120 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// Errors for GSSAPI authentication requests. +var ( + ErrInvalidGSSAPIVersion = errors.New("invalid GSSAPI version (must be 1)") + ErrEmptyGSSAPIToken = errors.New("GSSAPI token cannot be empty") + ErrGSSAPITokenTooLong = errors.New("GSSAPI token too long (max 65535)") +) + +// GSSAPIRequest represents a GSSAPI authentication request (RFC 1961 §3.4). +type GSSAPIRequest struct { + Version byte // VER (should always be 0x01) + MsgType byte // MTYP (0x01 = initial token) + Token []byte // TOKEN (opaque GSSAPI token) +} + +// Init initializes a GSSAPI authentication request. +func (r *GSSAPIRequest) Init(version, msgType byte, token []byte) { + r.Version = version + r.MsgType = msgType + r.Token = token +} + +// Validate checks for protocol correctness. +func (r *GSSAPIRequest) Validate() error { + if r.Version != 0x01 { + return ErrInvalidGSSAPIVersion + } + if r.MsgType == GSSAPITypeAbort { + // Abort messages have no token + return nil + } + if len(r.Token) == 0 { + return ErrEmptyGSSAPIToken + } + if len(r.Token) > 65535 { + return ErrGSSAPITokenTooLong + } + return nil +} + +// ReadFrom reads a GSSAPI authentication request from a reader. +func (r *GSSAPIRequest) ReadFrom(src io.Reader) (int64, error) { + var hdr [4]byte + n, err := io.ReadFull(src, hdr[:2]) + if err != nil { + return int64(n), err + } + + r.Version = hdr[0] + r.MsgType = hdr[1] + if r.MsgType == GSSAPITypeAbort { + return int64(n), nil + } + + // Read length + n2, err := io.ReadFull(src, hdr[2:4]) + n += n2 + if err != nil { + return int64(n), err + } + length := binary.BigEndian.Uint16(hdr[2:4]) + if length == 0 { + return int64(n), ErrEmptyGSSAPIToken + } + + token := make([]byte, length) + n3, err := io.ReadFull(src, token) + total := int64(n + n3) + if err != nil { + return total, err + } + r.Token = token + return total, r.Validate() +} + +// WriteTo writes the GSSAPI authentication request to a writer. +// NOTE: returns error if token length is too long. +func (r *GSSAPIRequest) WriteTo(dst io.Writer) (int64, error) { + if len(r.Token) > 65535 { + return 0, ErrGSSAPITokenTooLong + } + + if r.MsgType == GSSAPITypeAbort { + // Only version + abort message type + buf := [2]byte{r.Version, r.MsgType} + n, err := dst.Write(buf[:]) + return int64(n), err + } + + var hdr [4]byte + hdr[0] = r.Version + hdr[1] = r.MsgType + binary.BigEndian.PutUint16(hdr[2:], uint16(len(r.Token))) + + n, err := dst.Write(hdr[:]) + total := int64(n) + if err != nil { + return total, err + } + + n2, err := dst.Write(r.Token) + total += int64(n2) + return total, err +} + +// String returns a human-readable representation. +func (r *GSSAPIRequest) String() string { + return fmt.Sprintf( + "GSSAPIRequest{Version=%d, MsgType=0x%02x, TokenLen=%d}", + r.Version, r.MsgType, len(r.Token), + ) +} diff --git a/socks5/gssapi_request_test.go b/socks5/gssapi_request_test.go new file mode 100644 index 0000000..c104601 --- /dev/null +++ b/socks5/gssapi_request_test.go @@ -0,0 +1,136 @@ +package socks5_test + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/33TU/socks/socks5" +) + +func Test_GSSAPIRequest_Init_And_Validate(t *testing.T) { + r := &socks5.GSSAPIRequest{} + r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xde, 0xad, 0xbe, 0xef}) + + if err := r.Validate(); err != nil { + t.Fatalf("expected valid request, got %v", err) + } + + r.Version = 0x05 + if err := r.Validate(); !errors.Is(err, socks5.ErrInvalidGSSAPIVersion) { + t.Errorf("expected ErrInvalidGSSAPIVersion, got %v", err) + } + + r.Version = 0x01 + r.MsgType = socks5.GSSAPITypeInit + r.Token = nil + if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyGSSAPIToken) { + t.Errorf("expected ErrEmptyGSSAPIToken, got %v", err) + } + + // Abort messages skip token validation + r.MsgType = socks5.GSSAPITypeAbort + r.Token = nil + if err := r.Validate(); err != nil { + t.Errorf("abort message should be valid, got %v", err) + } + + r.MsgType = socks5.GSSAPITypeInit + r.Token = make([]byte, 70000) + if err := r.Validate(); !errors.Is(err, socks5.ErrGSSAPITokenTooLong) { + t.Errorf("expected ErrGSSAPITokenTooLong, got %v", err) + } +} + +func Test_GSSAPIRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { + orig := &socks5.GSSAPIRequest{} + orig.Init(0x01, socks5.GSSAPITypeInit, []byte{0x11, 0x22, 0x33, 0x44}) + + var buf bytes.Buffer + n1, err := orig.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + var parsed socks5.GSSAPIRequest + n2, err := parsed.ReadFrom(&buf) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + + if n1 != n2 { + t.Errorf("expected %d bytes read, got %d", n1, n2) + } + if parsed.Version != 0x01 { + t.Errorf("expected version 1, got %d", parsed.Version) + } + if parsed.MsgType != socks5.GSSAPITypeInit { + t.Errorf("expected msgType 0x01, got %#02x", parsed.MsgType) + } + if !bytes.Equal(parsed.Token, orig.Token) { + t.Errorf("token mismatch: got %x, want %x", parsed.Token, orig.Token) + } +} + +func Test_GSSAPIRequest_ReadFrom_Abort(t *testing.T) { + data := []byte{0x01, socks5.GSSAPITypeAbort} + r := &socks5.GSSAPIRequest{} + n, err := r.ReadFrom(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != 2 { + t.Errorf("expected 2 bytes read, got %d", n) + } + if r.MsgType != socks5.GSSAPITypeAbort { + t.Errorf("expected abort msgType 0xFF, got %#02x", r.MsgType) + } +} + +func Test_GSSAPIRequest_ReadFrom_Truncated(t *testing.T) { + data := []byte{ + 0x01, socks5.GSSAPITypeInit, 0x00, 0x04, // header + 0xde, 0xad, // only 2 of 4 bytes + } + r := &socks5.GSSAPIRequest{} + if _, err := r.ReadFrom(bytes.NewReader(data)); err == nil { + t.Errorf("expected error for truncated payload") + } +} + +func Test_GSSAPIRequest_ReadFrom_EmptyOrTooLong(t *testing.T) { + // empty token (len=0) + data := []byte{0x01, socks5.GSSAPITypeInit, 0x00, 0x00} + r := &socks5.GSSAPIRequest{} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyGSSAPIToken) { + t.Errorf("expected ErrEmptyGSSAPIToken, got %v", err) + } + + // invalid version + data = []byte{0x05, socks5.GSSAPITypeInit, 0x00, 0x01, 0xff} + if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrInvalidGSSAPIVersion) && err != nil { + t.Errorf("expected ErrInvalidGSSAPIVersion, got %v", err) + } +} + +func Test_GSSAPIRequest_WriteTo_ErrorPropagation(t *testing.T) { + r := &socks5.GSSAPIRequest{} + r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xaa, 0xbb}) + + failWriter := writerFunc(func(p []byte) (int, error) { + return 0, io.ErrClosedPipe + }) + + if _, err := r.WriteTo(failWriter); err == nil { + t.Errorf("expected write error") + } +} + +func Test_GSSAPIRequest_String(t *testing.T) { + r := &socks5.GSSAPIRequest{} + r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xde, 0xad}) + if s := r.String(); s == "" { + t.Errorf("expected non-empty String() output") + } +} From 446da82514d687cc45dd32a5e280176008b67d16 Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 04:08:51 +0300 Subject: [PATCH 3/6] socks5 comments cleanup --- socks5/consts.go | 50 ++++++++++++++++++++++---------------------- socks5/udp_packet.go | 6 ------ 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/socks5/consts.go b/socks5/consts.go index 17598ab..9dd20b9 100644 --- a/socks5/consts.go +++ b/socks5/consts.go @@ -7,49 +7,49 @@ const ( // Command codes (CMD) for client requests. const ( - CmdConnect = 1 // Establish a TCP/IP stream connection - CmdBind = 2 // Establish a TCP/IP port binding - CmdUDPAssociate = 3 // Associate UDP relay - CmdResolve = 0xF0 // Name resolution (non-standard) - CmdResolvePTR = 0xF1 // Reverse lookup (non-standard) + CmdConnect = 1 + CmdBind = 2 + CmdUDPAssociate = 3 + CmdResolve = 0xF0 + CmdResolvePTR = 0xF1 ) // Address types (ATYP) used in requests and responses. const ( - AddrTypeIPv4 = 1 // IPv4 address - AddrTypeDomain = 3 // Domain name - AddrTypeIPv6 = 4 // IPv6 address + AddrTypeIPv4 = 1 + AddrTypeDomain = 3 + AddrTypeIPv6 = 4 ) // Reply codes (REP) for server responses. const ( - RepSuccess = 0 // Request granted - RepGeneralFailure = 1 // General SOCKS server failure - RepConnectionNotAllowed = 2 // Connection not allowed by ruleset - RepNetworkUnreachable = 3 // Network unreachable - RepHostUnreachable = 4 // Host unreachable - RepConnectionRefused = 5 // Connection refused - RepTTLExpired = 6 // TTL expired - RepCommandNotSupported = 7 // Command not supported - RepAddrTypeNotSupported = 8 // Address type not supported + RepSuccess = 0 + RepGeneralFailure = 1 + RepConnectionNotAllowed = 2 + RepNetworkUnreachable = 3 + RepHostUnreachable = 4 + RepConnectionRefused = 5 + RepTTLExpired = 6 + RepCommandNotSupported = 7 + RepAddrTypeNotSupported = 8 ) // Authentication methods (METHOD) for initial greeting. const ( - MethodNoAuth = 0x00 // No authentication required - MethodGSSAPI = 0x01 // GSS-API authentication - MethodUserPass = 0x02 // Username/password authentication - MethodNoAcceptable = 0xFF // No acceptable methods + MethodNoAuth = 0x00 + MethodGSSAPI = 0x01 + MethodUserPass = 0x02 + MethodNoAcceptable = 0xFF ) // Authentication sub-negotiation versions. const ( - AuthVersionUserPass = 1 // Username/password sub-negotiation version + AuthVersionUserPass = 1 ) // GSS-API message types (MTYP) per RFC 1961. const ( - GSSAPITypeInit = 0x01 // Client initial token - GSSAPITypeReply = 0x02 // Server reply token - GSSAPITypeAbort = 0xFF // Abort / failure + GSSAPITypeInit = 0x01 + GSSAPITypeReply = 0x02 + GSSAPITypeAbort = 0xFF ) diff --git a/socks5/udp_packet.go b/socks5/udp_packet.go index c2ea267..260be64 100644 --- a/socks5/udp_packet.go +++ b/socks5/udp_packet.go @@ -18,12 +18,6 @@ var ( ) // UDPPacket represents a SOCKS5 UDP ASSOCIATE packet. -// Wire format (RFC 1928 §7): -// -// +----+------+------+----------+----------+----------+ -// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | -// +----+------+------+----------+----------+----------+ -// | 2 | 1 | 1 | Variable | 2 | Variable | type UDPPacket struct { Reserved [2]byte // RSV; must be 0x0000 Frag byte // FRAG; must be 0x00 (no fragmentation) From 48306be7e6cb2d947983c7a723f69e7555062a62 Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 13:32:16 +0300 Subject: [PATCH 4/6] gssapi version --- socks5/consts.go | 7 ++++++- socks5/gssapi_reply_test.go | 18 +++++++++--------- socks5/gssapi_request_test.go | 18 +++++++++--------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/socks5/consts.go b/socks5/consts.go index 9dd20b9..905f86d 100644 --- a/socks5/consts.go +++ b/socks5/consts.go @@ -47,9 +47,14 @@ const ( AuthVersionUserPass = 1 ) -// GSS-API message types (MTYP) per RFC 1961. +// GSS-API message types (MTYP) const ( GSSAPITypeInit = 0x01 GSSAPITypeReply = 0x02 GSSAPITypeAbort = 0xFF ) + +// GSS-API protocol version. (VER) +const ( + GSSAPIVersion = 1 +) diff --git a/socks5/gssapi_reply_test.go b/socks5/gssapi_reply_test.go index 8d11753..7f34a95 100644 --- a/socks5/gssapi_reply_test.go +++ b/socks5/gssapi_reply_test.go @@ -11,7 +11,7 @@ import ( func Test_GSSAPIReply_Init_And_Validate(t *testing.T) { r := &socks5.GSSAPIReply{} - r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xca, 0xfe, 0xba, 0xbe}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeReply, []byte{0xca, 0xfe, 0xba, 0xbe}) if err := r.Validate(); err != nil { t.Fatalf("expected valid reply, got %v", err) @@ -23,7 +23,7 @@ func Test_GSSAPIReply_Init_And_Validate(t *testing.T) { } // Empty token (non-abort) - r.Version = 0x01 + r.Version = socks5.GSSAPIVersion r.MsgType = socks5.GSSAPITypeReply r.Token = nil if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyGSSAPIReplyToken) { @@ -47,7 +47,7 @@ func Test_GSSAPIReply_Init_And_Validate(t *testing.T) { func Test_GSSAPIReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { orig := &socks5.GSSAPIReply{} - orig.Init(0x01, socks5.GSSAPITypeReply, []byte{0xde, 0xad, 0xbe, 0xef}) + orig.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeReply, []byte{0xde, 0xad, 0xbe, 0xef}) var buf bytes.Buffer n1, err := orig.WriteTo(&buf) @@ -64,7 +64,7 @@ func Test_GSSAPIReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { if n1 != n2 { t.Errorf("expected %d bytes read, got %d", n1, n2) } - if parsed.Version != 0x01 { + if parsed.Version != socks5.GSSAPIVersion { t.Errorf("expected version 1, got %d", parsed.Version) } if parsed.MsgType != socks5.GSSAPITypeReply { @@ -77,7 +77,7 @@ func Test_GSSAPIReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { func Test_GSSAPIReply_ReadFrom_Truncated(t *testing.T) { data := []byte{ - 0x01, socks5.GSSAPITypeReply, 0x00, 0x04, // header: ver, mtyp, len=4 + socks5.GSSAPIVersion, socks5.GSSAPITypeReply, 0x00, 0x04, // header: ver, mtyp, len=4 0xde, 0xad, // incomplete token } r := &socks5.GSSAPIReply{} @@ -87,7 +87,7 @@ func Test_GSSAPIReply_ReadFrom_Truncated(t *testing.T) { } func Test_GSSAPIReply_ReadFrom_Abort(t *testing.T) { - data := []byte{0x01, socks5.GSSAPITypeAbort} + data := []byte{socks5.GSSAPIVersion, socks5.GSSAPITypeAbort} r := &socks5.GSSAPIReply{} n, err := r.ReadFrom(bytes.NewReader(data)) if err != nil { @@ -103,7 +103,7 @@ func Test_GSSAPIReply_ReadFrom_Abort(t *testing.T) { func Test_GSSAPIReply_ReadFrom_EmptyOrTooLong(t *testing.T) { // empty token (len=0) - data := []byte{0x01, socks5.GSSAPITypeReply, 0x00, 0x00} + data := []byte{socks5.GSSAPIVersion, socks5.GSSAPITypeReply, 0x00, 0x00} r := &socks5.GSSAPIReply{} if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyGSSAPIReplyToken) { t.Errorf("expected ErrEmptyGSSAPIReplyToken, got %v", err) @@ -118,7 +118,7 @@ func Test_GSSAPIReply_ReadFrom_EmptyOrTooLong(t *testing.T) { func Test_GSSAPIReply_WriteTo_ErrorPropagation(t *testing.T) { r := &socks5.GSSAPIReply{} - r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xaa, 0xbb}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeReply, []byte{0xaa, 0xbb}) failWriter := writerFunc(func(p []byte) (int, error) { return 0, io.ErrClosedPipe @@ -131,7 +131,7 @@ func Test_GSSAPIReply_WriteTo_ErrorPropagation(t *testing.T) { func Test_GSSAPIReply_String(t *testing.T) { r := &socks5.GSSAPIReply{} - r.Init(0x01, socks5.GSSAPITypeReply, []byte{0xde, 0xad}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeReply, []byte{0xde, 0xad}) if s := r.String(); s == "" { t.Errorf("expected non-empty String() output") } diff --git a/socks5/gssapi_request_test.go b/socks5/gssapi_request_test.go index c104601..04a1797 100644 --- a/socks5/gssapi_request_test.go +++ b/socks5/gssapi_request_test.go @@ -11,7 +11,7 @@ import ( func Test_GSSAPIRequest_Init_And_Validate(t *testing.T) { r := &socks5.GSSAPIRequest{} - r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xde, 0xad, 0xbe, 0xef}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeInit, []byte{0xde, 0xad, 0xbe, 0xef}) if err := r.Validate(); err != nil { t.Fatalf("expected valid request, got %v", err) @@ -22,7 +22,7 @@ func Test_GSSAPIRequest_Init_And_Validate(t *testing.T) { t.Errorf("expected ErrInvalidGSSAPIVersion, got %v", err) } - r.Version = 0x01 + r.Version = socks5.GSSAPIVersion r.MsgType = socks5.GSSAPITypeInit r.Token = nil if err := r.Validate(); !errors.Is(err, socks5.ErrEmptyGSSAPIToken) { @@ -45,7 +45,7 @@ func Test_GSSAPIRequest_Init_And_Validate(t *testing.T) { func Test_GSSAPIRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { orig := &socks5.GSSAPIRequest{} - orig.Init(0x01, socks5.GSSAPITypeInit, []byte{0x11, 0x22, 0x33, 0x44}) + orig.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeInit, []byte{0x11, 0x22, 0x33, 0x44}) var buf bytes.Buffer n1, err := orig.WriteTo(&buf) @@ -62,7 +62,7 @@ func Test_GSSAPIRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { if n1 != n2 { t.Errorf("expected %d bytes read, got %d", n1, n2) } - if parsed.Version != 0x01 { + if parsed.Version != socks5.GSSAPIVersion { t.Errorf("expected version 1, got %d", parsed.Version) } if parsed.MsgType != socks5.GSSAPITypeInit { @@ -74,7 +74,7 @@ func Test_GSSAPIRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { } func Test_GSSAPIRequest_ReadFrom_Abort(t *testing.T) { - data := []byte{0x01, socks5.GSSAPITypeAbort} + data := []byte{socks5.GSSAPIVersion, socks5.GSSAPITypeAbort} r := &socks5.GSSAPIRequest{} n, err := r.ReadFrom(bytes.NewReader(data)) if err != nil { @@ -90,7 +90,7 @@ func Test_GSSAPIRequest_ReadFrom_Abort(t *testing.T) { func Test_GSSAPIRequest_ReadFrom_Truncated(t *testing.T) { data := []byte{ - 0x01, socks5.GSSAPITypeInit, 0x00, 0x04, // header + socks5.GSSAPIVersion, socks5.GSSAPITypeInit, 0x00, 0x04, // header 0xde, 0xad, // only 2 of 4 bytes } r := &socks5.GSSAPIRequest{} @@ -101,7 +101,7 @@ func Test_GSSAPIRequest_ReadFrom_Truncated(t *testing.T) { func Test_GSSAPIRequest_ReadFrom_EmptyOrTooLong(t *testing.T) { // empty token (len=0) - data := []byte{0x01, socks5.GSSAPITypeInit, 0x00, 0x00} + data := []byte{socks5.GSSAPIVersion, socks5.GSSAPITypeInit, 0x00, 0x00} r := &socks5.GSSAPIRequest{} if _, err := r.ReadFrom(bytes.NewReader(data)); !errors.Is(err, socks5.ErrEmptyGSSAPIToken) { t.Errorf("expected ErrEmptyGSSAPIToken, got %v", err) @@ -116,7 +116,7 @@ func Test_GSSAPIRequest_ReadFrom_EmptyOrTooLong(t *testing.T) { func Test_GSSAPIRequest_WriteTo_ErrorPropagation(t *testing.T) { r := &socks5.GSSAPIRequest{} - r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xaa, 0xbb}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeInit, []byte{0xaa, 0xbb}) failWriter := writerFunc(func(p []byte) (int, error) { return 0, io.ErrClosedPipe @@ -129,7 +129,7 @@ func Test_GSSAPIRequest_WriteTo_ErrorPropagation(t *testing.T) { func Test_GSSAPIRequest_String(t *testing.T) { r := &socks5.GSSAPIRequest{} - r.Init(0x01, socks5.GSSAPITypeInit, []byte{0xde, 0xad}) + r.Init(socks5.GSSAPIVersion, socks5.GSSAPITypeInit, []byte{0xde, 0xad}) if s := r.String(); s == "" { t.Errorf("expected non-empty String() output") } From cccf3f0662064fc2e1ccec3855d5f2f01b750fe5 Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 13:35:06 +0300 Subject: [PATCH 5/6] handshake version --- socks5/handshake_reply.go | 4 ++-- socks5/handshake_reply_test.go | 8 ++++---- socks5/handshake_request.go | 4 ++-- socks5/handshake_request_test.go | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/socks5/handshake_reply.go b/socks5/handshake_reply.go index df3bb80..9fded1e 100644 --- a/socks5/handshake_reply.go +++ b/socks5/handshake_reply.go @@ -18,8 +18,8 @@ type HandshakeReply struct { } // Init initializes a handshake reply with the given method. -func (h *HandshakeReply) Init(method byte) { - h.Version = SocksVersion +func (h *HandshakeReply) Init(version byte, method byte) { + h.Version = version h.Method = method } diff --git a/socks5/handshake_reply_test.go b/socks5/handshake_reply_test.go index 96e7b22..2cdd667 100644 --- a/socks5/handshake_reply_test.go +++ b/socks5/handshake_reply_test.go @@ -11,7 +11,7 @@ import ( func Test_HandshakeReply_Init_And_Validate(t *testing.T) { h := &socks5.HandshakeReply{} - h.Init(socks5.MethodUserPass) + h.Init(socks5.SocksVersion, socks5.MethodUserPass) if err := h.Validate(); err != nil { t.Fatalf("expected valid reply, got %v", err) @@ -25,7 +25,7 @@ func Test_HandshakeReply_Init_And_Validate(t *testing.T) { func Test_HandshakeReply_WriteTo_ReadFrom_RoundTrip(t *testing.T) { orig := &socks5.HandshakeReply{} - orig.Init(socks5.MethodNoAuth) + orig.Init(socks5.SocksVersion, socks5.MethodNoAuth) var buf bytes.Buffer n1, err := orig.WriteTo(&buf) @@ -57,7 +57,7 @@ func Test_HandshakeReply_ReadFrom_Truncated(t *testing.T) { func Test_HandshakeReply_WriteTo_ErrorPropagation(t *testing.T) { h := &socks5.HandshakeReply{} - h.Init(socks5.MethodUserPass) + h.Init(socks5.SocksVersion, socks5.MethodUserPass) failWriter := writerFunc(func(p []byte) (int, error) { return 0, io.ErrClosedPipe @@ -70,7 +70,7 @@ func Test_HandshakeReply_WriteTo_ErrorPropagation(t *testing.T) { func Test_HandshakeReply_String(t *testing.T) { h := &socks5.HandshakeReply{} - h.Init(socks5.MethodNoAuth) + h.Init(socks5.SocksVersion, socks5.MethodNoAuth) if s := h.String(); s == "" { t.Errorf("expected non-empty String() output") diff --git a/socks5/handshake_request.go b/socks5/handshake_request.go index bef2d5d..52d1000 100644 --- a/socks5/handshake_request.go +++ b/socks5/handshake_request.go @@ -21,8 +21,8 @@ type HandshakeRequest struct { } // Init initializes a handshake request with the given methods. -func (h *HandshakeRequest) Init(methods ...byte) { - h.Version = SocksVersion +func (h *HandshakeRequest) Init(version byte, methods ...byte) { + h.Version = version h.NMethods = byte(len(methods)) h.Methods = append([]byte(nil), methods...) // copy } diff --git a/socks5/handshake_request_test.go b/socks5/handshake_request_test.go index 7e2f629..a666697 100644 --- a/socks5/handshake_request_test.go +++ b/socks5/handshake_request_test.go @@ -11,7 +11,7 @@ import ( func Test_HandshakeRequest_Init_And_Validate(t *testing.T) { r := &socks5.HandshakeRequest{} - r.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + r.Init(socks5.SocksVersion, socks5.MethodNoAuth, socks5.MethodUserPass) if err := r.Validate(); err != nil { t.Fatalf("expected valid request, got %v", err) @@ -31,7 +31,7 @@ func Test_HandshakeRequest_Init_And_Validate(t *testing.T) { func Test_HandshakeRequest_WriteTo_ReadFrom_RoundTrip(t *testing.T) { orig := &socks5.HandshakeRequest{} - orig.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + orig.Init(socks5.SocksVersion, socks5.MethodNoAuth, socks5.MethodUserPass) var buf bytes.Buffer n1, err := orig.WriteTo(&buf) @@ -71,7 +71,7 @@ func Test_HandshakeRequest_ReadFrom_Truncated(t *testing.T) { func Test_HandshakeRequest_WriteTo_ErrorPropagation(t *testing.T) { r := &socks5.HandshakeRequest{} - r.Init(socks5.MethodNoAuth) + r.Init(socks5.SocksVersion, socks5.MethodNoAuth) failWriter := writerFunc(func(p []byte) (int, error) { return 0, io.ErrClosedPipe @@ -84,7 +84,7 @@ func Test_HandshakeRequest_WriteTo_ErrorPropagation(t *testing.T) { func Test_HandshakeRequest_String(t *testing.T) { r := &socks5.HandshakeRequest{} - r.Init(socks5.MethodNoAuth, socks5.MethodUserPass) + r.Init(socks5.SocksVersion, socks5.MethodNoAuth, socks5.MethodUserPass) if s := r.String(); s == "" { t.Errorf("expected non-empty String() output") From 2e54d855f88bd292067a4dc7ed1e5324b1a36c2f Mon Sep 17 00:00:00 2001 From: 33TU Date: Sat, 18 Oct 2025 13:39:29 +0300 Subject: [PATCH 6/6] add validate for variable len WriteTo funcs --- socks5/gssapi_reply.go | 5 ++--- socks5/gssapi_request.go | 8 +++----- socks5/reply.go | 8 ++------ socks5/request.go | 8 ++------ socks5/user_pass_request.go | 5 ++--- 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/socks5/gssapi_reply.go b/socks5/gssapi_reply.go index 001b9c1..9e2c975 100644 --- a/socks5/gssapi_reply.go +++ b/socks5/gssapi_reply.go @@ -80,10 +80,9 @@ func (r *GSSAPIReply) ReadFrom(src io.Reader) (int64, error) { } // WriteTo writes the GSSAPI reply to a writer. -// NOTE: returns error if token length is too long. func (r *GSSAPIReply) WriteTo(dst io.Writer) (int64, error) { - if len(r.Token) > 65535 { - return 0, ErrGSSAPIReplyTooLong + if err := r.Validate(); err != nil { + return 0, err } if r.MsgType == GSSAPITypeAbort { diff --git a/socks5/gssapi_request.go b/socks5/gssapi_request.go index 4ead7a7..ef6f2ea 100644 --- a/socks5/gssapi_request.go +++ b/socks5/gssapi_request.go @@ -34,8 +34,7 @@ func (r *GSSAPIRequest) Validate() error { return ErrInvalidGSSAPIVersion } if r.MsgType == GSSAPITypeAbort { - // Abort messages have no token - return nil + return nil // Abort messages have no token } if len(r.Token) == 0 { return ErrEmptyGSSAPIToken @@ -82,10 +81,9 @@ func (r *GSSAPIRequest) ReadFrom(src io.Reader) (int64, error) { } // WriteTo writes the GSSAPI authentication request to a writer. -// NOTE: returns error if token length is too long. func (r *GSSAPIRequest) WriteTo(dst io.Writer) (int64, error) { - if len(r.Token) > 65535 { - return 0, ErrGSSAPITokenTooLong + if err := r.Validate(); err != nil { + return 0, err } if r.MsgType == GSSAPITypeAbort { diff --git a/socks5/reply.go b/socks5/reply.go index 90f0898..9f837f3 100644 --- a/socks5/reply.go +++ b/socks5/reply.go @@ -156,13 +156,9 @@ func (r *Reply) ReadFrom(src io.Reader) (int64, error) { // WriteTo writes a SOCKS5 reply to a Writer. // Implements io.WriterTo. -// Note: returns error if domain length is invalid. func (r *Reply) WriteTo(dst io.Writer) (int64, error) { - if r.AddrType == AddrTypeDomain { - domainLen := len(r.Domain) - if domainLen == 0 || domainLen > 255 { - return 0, ErrInvalidReplyDomain - } + if err := r.Validate(); err != nil { + return 0, err } hdr := [4]byte{r.Version, r.Reply, r.Reserved, r.AddrType} diff --git a/socks5/request.go b/socks5/request.go index a487bb1..2eaedc7 100644 --- a/socks5/request.go +++ b/socks5/request.go @@ -172,13 +172,9 @@ func (r *Request) ReadFrom(src io.Reader) (int64, error) { // WriteTo writes a SOCKS5 request to a Writer. // Implements the io.WriterTo interface. -// Note: returns error if domain is too long. func (r *Request) WriteTo(dst io.Writer) (int64, error) { - if r.AddrType == AddrTypeDomain { - domainLen := len(r.Domain) - if domainLen == 0 || domainLen > 255 { - return 0, ErrInvalidReplyDomain - } + if err := r.Validate(); err != nil { + return 0, err } var total int64 diff --git a/socks5/user_pass_request.go b/socks5/user_pass_request.go index 32a4989..88e7db9 100644 --- a/socks5/user_pass_request.go +++ b/socks5/user_pass_request.go @@ -98,10 +98,9 @@ func (r *UserPassRequest) ReadFrom(src io.Reader) (int64, error) { // WriteTo writes the username/password request to a writer. // Implements io.WriterTo. -// Note: returns error if user or pass is too long. func (r *UserPassRequest) WriteTo(dst io.Writer) (int64, error) { - if len(r.Username) > 255 || len(r.Password) > 255 { - return 0, ErrUserPassTooLong + if err := r.Validate(); err != nil { + return 0, err } buf := []byte{