From ed2073cddea6b84cb3db56dfa34555cce2b37505 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 03:01:02 +0000 Subject: [PATCH 1/2] fix: replace deprecated golang.org/x/net/websocket with gorilla/websocket Agent-Logs-Url: https://github.com/Automattic/cron-control-runner/sessions/c1de8e80-9bd6-4c70-8b70-8f15cae35c63 Co-authored-by: sjinks <7810770+sjinks@users.noreply.github.com> --- go.mod | 3 +- go.sum | 2 ++ remote/remote.go | 75 ++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index f32f7f4..880c9a5 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.23.0 require ( github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/creack/pty v1.1.18 + github.com/gorilla/websocket v1.5.3 github.com/hashicorp/go-retryablehttp v0.7.1 github.com/howeyc/fsnotify v0.9.0 github.com/lthibault/jitterbug/v2 v2.2.2 github.com/prometheus/client_golang v1.11.0 github.com/yookoala/gofast v0.6.0 - golang.org/x/net v0.38.0 golang.org/x/sys v0.31.0 golang.org/x/term v0.30.0 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 @@ -25,6 +25,7 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/tools v0.0.0-20200908211811-12e1bf57a112 // indirect google.golang.org/protobuf v1.26.0-rc.1 // indirect ) diff --git a/go.sum b/go.sum index 3ee3d2e..245e379 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= diff --git a/remote/remote.go b/remote/remote.go index 5fdbbbe..f19a6bf 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -27,10 +27,10 @@ import ( "time" "unicode" + gorilla "github.com/gorilla/websocket" "github.com/creack/pty" "github.com/hashicorp/go-retryablehttp" "github.com/howeyc/fsnotify" - "golang.org/x/net/websocket" "golang.org/x/sys/unix" "golang.org/x/term" ) @@ -39,6 +39,59 @@ const ( shutdownErrorCode = 4001 // WebSocket close code when a shutdown signal is detected ) +// wsNetConn wraps a *gorilla.Conn and implements the net.Conn interface so that +// the WebSocket connection can be used wherever a plain net.Conn is expected. +// gorilla/websocket is message-oriented, so Read buffers an entire message and +// returns chunks of it on successive calls. +type wsNetConn struct { + conn *gorilla.Conn + mu sync.Mutex // serialises concurrent writes (gorilla requires one writer at a time) + readBuf []byte +} + +func newWSNetConn(conn *gorilla.Conn) *wsNetConn { + return &wsNetConn{conn: conn} +} + +func (c *wsNetConn) Read(b []byte) (int, error) { + for len(c.readBuf) == 0 { + _, msg, err := c.conn.ReadMessage() + if err != nil { + return 0, err + } + c.readBuf = msg + } + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *wsNetConn) Write(b []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.conn.WriteMessage(gorilla.BinaryMessage, b); err != nil { + return 0, err + } + return len(b), nil +} + +func (c *wsNetConn) writeClose(code int) error { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.WriteControl( + gorilla.CloseMessage, + gorilla.FormatCloseMessage(code, ""), + time.Now().Add(time.Second), + ) +} + +func (c *wsNetConn) Close() error { return c.conn.Close() } +func (c *wsNetConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *wsNetConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *wsNetConn) SetDeadline(t time.Time) error { return c.conn.UnderlyingConn().SetDeadline(t) } +func (c *wsNetConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *wsNetConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + var nonUTF8Replacement = []byte(string(unicode.ReplacementChar)) // Holds info related to a specific remote CLI that is running. @@ -115,6 +168,10 @@ func ListenForConnections() { listenAddr := "0.0.0.0:22122" if remoteConfig.useWebsockets { + upgrader := &gorilla.Upgrader{ + // Allow all origins since this is an internal service, not a browser-facing endpoint. + CheckOrigin: func(r *http.Request) bool { return true }, + } s := &http.Server{ Addr: listenAddr, ConnContext: func(ctx context.Context, c net.Conn) context.Context { @@ -125,9 +182,15 @@ func ListenForConnections() { } return ctx }, - Handler: websocket.Handler(func(wsConn *websocket.Conn) { - log.Printf("websocket connection from %s\n", wsConn.RemoteAddr().String()) - authConn(wsConn) + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("websocket upgrade error: %v\n", err) + return + } + netConn := newWSNetConn(wsConn) + log.Printf("websocket connection from %s\n", netConn.RemoteAddr().String()) + authConn(netConn) }), } log.Printf("Listening for websocket protocol on %q...", listenAddr) @@ -408,9 +471,9 @@ func processShutdown(conn net.Conn, wpcli *wpCLIProcess) { wpcli.padlock.Lock() - wsConn, ok := conn.(*websocket.Conn) + wsConn, ok := conn.(*wsNetConn) if ok { - wsConn.WriteClose(shutdownErrorCode) + wsConn.writeClose(shutdownErrorCode) } conn.Close() From e06889a475597e0e64124b890966ae30532f0cf9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 03:11:16 +0000 Subject: [PATCH 2/2] test(remote): add wsNetConn test coverage Agent-Logs-Url: https://github.com/Automattic/cron-control-runner/sessions/92142be9-6f84-4c2f-8bbe-4b29a1ca3652 Co-authored-by: sjinks <7810770+sjinks@users.noreply.github.com> --- remote/wsnetconn_test.go | 277 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100644 remote/wsnetconn_test.go diff --git a/remote/wsnetconn_test.go b/remote/wsnetconn_test.go new file mode 100644 index 0000000..5b9a39a --- /dev/null +++ b/remote/wsnetconn_test.go @@ -0,0 +1,277 @@ +package remote + +import ( + "bytes" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + gorilla "github.com/gorilla/websocket" +) + +// newTestWSPair creates a matched server-side *wsNetConn and a raw client-side +// *gorilla.Conn backed by an httptest server. The caller must invoke cleanup() +// when done to release all resources. +func newTestWSPair(t *testing.T) (server *wsNetConn, client *gorilla.Conn, cleanup func()) { + t.Helper() + + upgrader := gorilla.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + serverConnCh := make(chan *gorilla.Conn, 1) + handlerDone := make(chan struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + serverConnCh <- ws + <-handlerDone // keep the handler alive until cleanup + })) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/" + clientWS, _, err := gorilla.DefaultDialer.Dial(wsURL, nil) + if err != nil { + close(handlerDone) + srv.Close() + t.Fatalf("dial: %v", err) + } + + serverWS := <-serverConnCh + + return newWSNetConn(serverWS), clientWS, func() { + close(handlerDone) + clientWS.Close() + serverWS.Close() + srv.Close() + } +} + +// TestWSNetConnImplementsNetConn verifies at compile time that *wsNetConn +// satisfies the net.Conn interface. +func TestWSNetConnImplementsNetConn(t *testing.T) { + t.Parallel() + var _ net.Conn = (*wsNetConn)(nil) +} + +// TestWSNetConnRead_SingleMessage verifies that a single message sent by the +// client is fully returned by one (or more) Read calls on the server side. +func TestWSNetConnRead_SingleMessage(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + want := []byte("hello world") + if err := client.WriteMessage(gorilla.BinaryMessage, want); err != nil { + t.Fatalf("client write: %v", err) + } + + got := make([]byte, len(want)) + n, err := server.Read(got) + if err != nil { + t.Fatalf("server Read: %v", err) + } + if n != len(want) { + t.Fatalf("Read returned %d bytes, want %d", n, len(want)) + } + if !bytes.Equal(got[:n], want) { + t.Fatalf("Read returned %q, want %q", got[:n], want) + } +} + +// TestWSNetConnRead_Chunked verifies that a large message is returned +// correctly when the caller provides a buffer smaller than the message. +func TestWSNetConnRead_Chunked(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + want := bytes.Repeat([]byte("abcde"), 100) // 500 bytes + if err := client.WriteMessage(gorilla.BinaryMessage, want); err != nil { + t.Fatalf("client write: %v", err) + } + + // Read in small chunks of 64 bytes. + var got []byte + chunk := make([]byte, 64) + for len(got) < len(want) { + n, err := server.Read(chunk) + if err != nil { + t.Fatalf("Read error after %d bytes: %v", len(got), err) + } + got = append(got, chunk[:n]...) + } + + if !bytes.Equal(got, want) { + t.Fatalf("chunked Read returned unexpected data (len %d, want %d)", len(got), len(want)) + } +} + +// TestWSNetConnRead_MultipleMessages verifies that successive Read calls +// drain messages in order. +func TestWSNetConnRead_MultipleMessages(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + messages := [][]byte{ + []byte("first"), + []byte("second"), + []byte("third"), + } + for _, m := range messages { + if err := client.WriteMessage(gorilla.BinaryMessage, m); err != nil { + t.Fatalf("client write: %v", err) + } + } + + for _, want := range messages { + buf := make([]byte, 256) + n, err := server.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if !bytes.Equal(buf[:n], want) { + t.Fatalf("got %q, want %q", buf[:n], want) + } + } +} + +// TestWSNetConnWrite verifies that data written via wsNetConn.Write is +// received intact by the client. +func TestWSNetConnWrite(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + want := []byte("server says hello") + n, err := server.Write(want) + if err != nil { + t.Fatalf("Write: %v", err) + } + if n != len(want) { + t.Fatalf("Write returned n=%d, want %d", n, len(want)) + } + + _, got, err := client.ReadMessage() + if err != nil { + t.Fatalf("client ReadMessage: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("client received %q, want %q", got, want) + } +} + +// TestWSNetConnWrite_ConcurrentSafe verifies that concurrent Write calls do +// not panic or return errors (gorilla requires serialised writes; the mutex +// inside wsNetConn should guarantee this). +func TestWSNetConnWrite_ConcurrentSafe(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + const goroutines = 20 + var wg sync.WaitGroup + + // Drain client messages concurrently so the server is never blocked. + stop := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + client.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + client.ReadMessage() //nolint:errcheck // best-effort drain + } + } + }() + + var writeWg sync.WaitGroup + for i := 0; i < goroutines; i++ { + writeWg.Add(1) + go func() { + defer writeWg.Done() + if _, err := server.Write([]byte("ping")); err != nil { + t.Errorf("concurrent Write: %v", err) + } + }() + } + writeWg.Wait() + close(stop) + wg.Wait() +} + +// TestWSNetConnWriteClose verifies that writeClose sends a well-formed +// WebSocket close frame carrying the expected status code. +func TestWSNetConnWriteClose(t *testing.T) { + t.Parallel() + server, client, cleanup := newTestWSPair(t) + defer cleanup() + + const code = 4001 + if err := server.writeClose(code); err != nil { + t.Fatalf("writeClose: %v", err) + } + + // The client's next read should surface the close message. + client.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _, err := client.ReadMessage() + if err == nil { + t.Fatal("expected close error from client ReadMessage, got nil") + } + + closeErr, ok := err.(*gorilla.CloseError) + if !ok { + t.Fatalf("expected *gorilla.CloseError, got %T: %v", err, err) + } + if closeErr.Code != code { + t.Fatalf("close code = %d, want %d", closeErr.Code, code) + } +} + +// TestWSNetConnAddrs verifies that LocalAddr and RemoteAddr return non-nil +// addresses (the exact values depend on the OS). +func TestWSNetConnAddrs(t *testing.T) { + t.Parallel() + server, _, cleanup := newTestWSPair(t) + defer cleanup() + + if server.LocalAddr() == nil { + t.Error("LocalAddr() returned nil") + } + if server.RemoteAddr() == nil { + t.Error("RemoteAddr() returned nil") + } +} + +// TestWSNetConnSetReadDeadline verifies that SetReadDeadline causes a +// subsequent Read to time out rather than block forever. +func TestWSNetConnSetReadDeadline(t *testing.T) { + t.Parallel() + server, _, cleanup := newTestWSPair(t) + defer cleanup() + + if err := server.SetReadDeadline(time.Now().Add(50 * time.Millisecond)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + + buf := make([]byte, 16) + _, err := server.Read(buf) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + // The error should indicate a timeout. + if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + t.Fatalf("expected net.Error with Timeout()=true, got %T: %v", err, err) + } +}