diff --git a/go.mod b/go.mod index 3285372..3eaca58 100644 --- a/go.mod +++ b/go.mod @@ -11,4 +11,15 @@ require ( golang.org/x/text v0.14.0 ) -go 1.13 +require ( + github.com/sethvargo/go-limiter v1.1.0 // indirect + github.com/yuin/goldmark v1.4.13 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/tools v0.6.0 // indirect + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 // indirect +) + +go 1.22 + +toolchain go1.24.3 diff --git a/go.sum b/go.sum index 83d4fe5..28ad030 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/alexcesaro/log v0.0.0-20150915221235-61e686294e58 h1:MkpmYfld/S8kXqTY github.com/alexcesaro/log v0.0.0-20150915221235-61e686294e58/go.mod h1:YNfsMyWSs+h+PaYkxGeMVmVCX75Zj/pqdjbu12ciCYE= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= +github.com/sethvargo/go-limiter v1.1.0 h1:eLeZVQ2zqJOiEs03GguqmBVG6/T6lsZB+6PP1t7J6fA= +github.com/sethvargo/go-limiter v1.1.0/go.mod h1:01b6tW25Ap+MeLYBuD4aHunMrJoNO5PVUFdS9rac3II= github.com/shazow/rateio v0.0.0-20200113175441-4461efc8bdc4 h1:zwQ1HBo5FYwn1ksMd19qBCKO8JAWE9wmHivEpkw/DvE= github.com/shazow/rateio v0.0.0-20200113175441-4461efc8bdc4/go.mod h1:vt2jWY/3Qw1bIzle5thrJWucsLuuX9iUNnp20CqCciI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/sshd/net.go b/sshd/net.go index 678454b..6723f77 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -1,9 +1,12 @@ package sshd import ( + "context" "net" "time" + "github.com/sethvargo/go-limiter" + "github.com/sethvargo/go-limiter/memorystore" "github.com/shazow/rateio" "golang.org/x/crypto/ssh" ) @@ -15,6 +18,12 @@ type SSHListener struct { RateLimit func() rateio.Limiter HandlerFunc func(term *Terminal) + + // handshakeLimit is a semaphore to limit concurrent handshakes globally + handshakeLimit chan struct{} + + // limiter is the per-IP rate limiter + limiter limiter.Store } // ListenSSH makes an SSH listener socket @@ -23,7 +32,25 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { if err != nil { return nil, err } - l := SSHListener{Listener: socket, config: config} + + // Create a rate limiter: 3 attempts per second per IP? + // The user wanted to "throttle many connections". + // 3 per second is generous for a chat server. + // If an IP connects >3 times in a second, it's likely a bot or flood. + store, err := memorystore.New(&memorystore.Config{ + Tokens: 3, + Interval: time.Second, + }) + if err != nil { + return nil, err + } + + l := SSHListener{ + Listener: socket, + config: config, + handshakeLimit: make(chan struct{}, 20), + limiter: store, + } return &l, nil } @@ -35,7 +62,7 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { // If the connection doesn't write anything back for too long before we get // a valid session, it should be dropped. - var handleTimeout = 20 * time.Second + var handleTimeout = 10 * time.Second conn.SetReadDeadline(time.Now().Add(handleTimeout)) defer conn.SetReadDeadline(time.Time{}) @@ -61,9 +88,39 @@ func (l *SSHListener) Serve() { break } + // Check per-IP limit using go-limiter + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + logger.Printf("Failed to split remote addr: %v", err) + } + + if err == nil { + // Context with timeout is not strictly needed for memory store, but good practice + // Although Take is non-blocking for memory store usually. + _, _, _, ok, err := l.limiter.Take(context.Background(), host) + if err != nil { + // Store error (shouldn't happen with memory store unless closed) + logger.Printf("Rate limiter error: %v", err) + } else if !ok { + logger.Printf("[%s] Rejected connection: rate limit exceeded", conn.RemoteAddr()) + conn.Close() + continue + } + } + + // Acquire global semaphore + l.handshakeLimit <- struct{}{} + // Goroutineify to resume accepting sockets early go func() { term, err := l.handleConn(conn) + + // Handshake is done (success or failure). Release limits. + // Explicit release is required because l.HandlerFunc below + // runs for the duration of the session. We only want to limit + // concurrent handshakes, not concurrent sessions. + <-l.handshakeLimit + if err != nil { logger.Printf("[%s] Failed to handshake: %s", conn.RemoteAddr(), err) conn.Close() // Must be closed to avoid a leak