Skip to content

Commit 8529e15

Browse files
committed
Various improvements to reconnecting websockets + More tests
1 parent 6694117 commit 8529e15

File tree

4 files changed

+242
-171
lines changed

4 files changed

+242
-171
lines changed

src/api/coderApi.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ export class CoderApi extends Api {
240240
socketFactory,
241241
this.output,
242242
configs.apiRoute,
243+
undefined,
244+
() =>
245+
this.reconnectingSockets.delete(
246+
reconnectingSocket as ReconnectingWebSocket<unknown>,
247+
),
243248
);
244249

245250
this.reconnectingSockets.add(

src/websocket/reconnectingWebSocket.ts

Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket";
2+
import type { CloseEvent } from "ws";
23

34
import type { Logger } from "../logging/logger";
45

@@ -18,9 +19,6 @@ export type ReconnectingWebSocketOptions = {
1819
// 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors
1920
const UNRECOVERABLE_CLOSE_CODES = new Set([403, 410, 426, 1002, 1003]);
2021

21-
// Custom close code for intentional reconnection (4000-4999 range is for private use)
22-
const CLOSE_CODE_RECONNECTING = 4000;
23-
2422
export class ReconnectingWebSocket<TData = unknown>
2523
implements UnidirectionalStream<TData>
2624
{
@@ -40,12 +38,15 @@ export class ReconnectingWebSocket<TData = unknown>
4038
#reconnectTimeoutId: NodeJS.Timeout | null = null;
4139
#isDisposed = false;
4240
#isConnecting = false;
41+
#pendingReconnect = false;
42+
readonly #onDispose?: () => void;
4343

4444
private constructor(
4545
socketFactory: SocketFactory<TData>,
4646
logger: Logger,
4747
apiRoute: string,
4848
options: ReconnectingWebSocketOptions = {},
49+
onDispose?: () => void,
4950
) {
5051
this.#socketFactory = socketFactory;
5152
this.#logger = logger;
@@ -56,19 +57,22 @@ export class ReconnectingWebSocket<TData = unknown>
5657
jitterFactor: options.jitterFactor ?? 0.1,
5758
};
5859
this.#backoffMs = this.#options.initialBackoffMs;
60+
this.#onDispose = onDispose;
5961
}
6062

6163
static async create<TData>(
6264
socketFactory: SocketFactory<TData>,
6365
logger: Logger,
6466
apiRoute: string,
6567
options: ReconnectingWebSocketOptions = {},
68+
onDispose?: () => void,
6669
): Promise<ReconnectingWebSocket<TData>> {
6770
const instance = new ReconnectingWebSocket<TData>(
6871
socketFactory,
6972
logger,
7073
apiRoute,
7174
options,
75+
onDispose,
7276
);
7377
await instance.#connect();
7478
return instance;
@@ -85,10 +89,6 @@ export class ReconnectingWebSocket<TData = unknown>
8589
(this.#eventHandlers[event] as Set<EventHandler<TData, TEvent>>).add(
8690
callback,
8791
);
88-
89-
if (this.#currentSocket) {
90-
this.#currentSocket.addEventListener(event, callback);
91-
}
9292
}
9393

9494
removeEventListener<TEvent extends WebSocketEventType>(
@@ -98,9 +98,23 @@ export class ReconnectingWebSocket<TData = unknown>
9898
(this.#eventHandlers[event] as Set<EventHandler<TData, TEvent>>).delete(
9999
callback,
100100
);
101+
}
101102

102-
if (this.#currentSocket) {
103-
this.#currentSocket.removeEventListener(event, callback);
103+
#executeHandlers<TEvent extends WebSocketEventType>(
104+
event: TEvent,
105+
eventData: Parameters<EventHandler<TData, TEvent>>[0],
106+
): void {
107+
const handlers = this.#eventHandlers[event] as Set<
108+
EventHandler<TData, TEvent>
109+
>;
110+
for (const handler of handlers) {
111+
try {
112+
handler(eventData);
113+
} catch (error) {
114+
this.#logger.error(
115+
`Error in ${event} handler for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`,
116+
);
117+
}
104118
}
105119
}
106120

@@ -109,6 +123,25 @@ export class ReconnectingWebSocket<TData = unknown>
109123
return;
110124
}
111125

126+
// Fire close handlers synchronously before disposing
127+
if (this.#currentSocket) {
128+
this.#executeHandlers("close", {
129+
code: code ?? 1000,
130+
reason: reason ?? "",
131+
wasClean: true,
132+
type: "close",
133+
target: this.#currentSocket,
134+
} as CloseEvent);
135+
}
136+
137+
this.#dispose(code, reason);
138+
}
139+
140+
#dispose(code?: number, reason?: string): void {
141+
if (this.#isDisposed) {
142+
return;
143+
}
144+
112145
this.#isDisposed = true;
113146

114147
if (this.#reconnectTimeoutId !== null) {
@@ -124,6 +157,8 @@ export class ReconnectingWebSocket<TData = unknown>
124157
for (const set of Object.values(this.#eventHandlers)) {
125158
set.clear();
126159
}
160+
161+
this.#onDispose?.();
127162
}
128163

129164
reconnect(): void {
@@ -136,9 +171,21 @@ export class ReconnectingWebSocket<TData = unknown>
136171
this.#reconnectTimeoutId = null;
137172
}
138173

139-
if (this.#currentSocket) {
140-
this.#currentSocket.close(CLOSE_CODE_RECONNECTING, "Reconnecting");
174+
// If already connecting, schedule reconnect after current attempt
175+
if (this.#isConnecting) {
176+
this.#pendingReconnect = true;
177+
return;
141178
}
179+
180+
// #connect() will close any existing socket
181+
this.#connect().catch((error) => {
182+
if (!this.#isDisposed) {
183+
this.#logger.warn(
184+
`Manual reconnection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`,
185+
);
186+
this.#scheduleReconnect();
187+
}
188+
});
142189
}
143190

144191
async #connect(): Promise<void> {
@@ -148,45 +195,40 @@ export class ReconnectingWebSocket<TData = unknown>
148195

149196
this.#isConnecting = true;
150197
try {
198+
// Close any existing socket before creating a new one
199+
if (this.#currentSocket) {
200+
this.#currentSocket.close(1000, "Replacing connection");
201+
this.#currentSocket = null;
202+
}
203+
151204
const socket = await this.#socketFactory();
152205
this.#currentSocket = socket;
153206

154-
socket.addEventListener("open", () => {
207+
socket.addEventListener("open", (event) => {
155208
this.#backoffMs = this.#options.initialBackoffMs;
209+
this.#executeHandlers("open", event);
156210
});
157211

158-
for (const handler of this.#eventHandlers.open) {
159-
socket.addEventListener("open", handler);
160-
}
161-
162-
for (const handler of this.#eventHandlers.message) {
163-
socket.addEventListener("message", handler);
164-
}
212+
socket.addEventListener("message", (event) => {
213+
this.#executeHandlers("message", event);
214+
});
165215

166-
for (const handler of this.#eventHandlers.error) {
167-
socket.addEventListener("error", handler);
168-
}
216+
socket.addEventListener("error", (event) => {
217+
this.#executeHandlers("error", event);
218+
});
169219

170220
socket.addEventListener("close", (event) => {
171-
for (const handler of this.#eventHandlers.close) {
172-
handler(event);
173-
}
174-
175221
if (this.#isDisposed) {
176222
return;
177223
}
178224

225+
this.#executeHandlers("close", event);
226+
179227
if (UNRECOVERABLE_CLOSE_CODES.has(event.code)) {
180228
this.#logger.error(
181229
`WebSocket connection closed with unrecoverable error code ${event.code}`,
182230
);
183-
this.#isDisposed = true;
184-
return;
185-
}
186-
187-
// Reconnect if this was an intentional close for reconnection
188-
if (event.code === CLOSE_CODE_RECONNECTING) {
189-
this.#scheduleReconnect();
231+
this.#dispose();
190232
return;
191233
}
192234

@@ -200,6 +242,11 @@ export class ReconnectingWebSocket<TData = unknown>
200242
});
201243
} finally {
202244
this.#isConnecting = false;
245+
246+
if (this.#pendingReconnect) {
247+
this.#pendingReconnect = false;
248+
this.reconnect();
249+
}
203250
}
204251
}
205252

@@ -218,7 +265,6 @@ export class ReconnectingWebSocket<TData = unknown>
218265

219266
this.#reconnectTimeoutId = setTimeout(() => {
220267
this.#reconnectTimeoutId = null;
221-
// Errors already handled in #connect
222268
this.#connect().catch((error) => {
223269
if (!this.#isDisposed) {
224270
this.#logger.warn(
@@ -231,8 +277,4 @@ export class ReconnectingWebSocket<TData = unknown>
231277

232278
this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs);
233279
}
234-
235-
isDisposed(): boolean {
236-
return this.#isDisposed;
237-
}
238280
}

test/unit/api/coderApi.test.ts

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -374,96 +374,75 @@ describe("CoderApi", () => {
374374
});
375375

376376
describe("Reconnection on Host/Token Changes", () => {
377-
it("triggers reconnection when session token changes", async () => {
378-
const mockWs = createMockWebSocket(
379-
`wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`,
380-
{
377+
const setupAutoOpeningWebSocket = () => {
378+
const sockets: Array<Partial<Ws>> = [];
379+
vi.mocked(Ws).mockImplementation((url: string | URL) => {
380+
const mockWs = createMockWebSocket(String(url), {
381381
on: vi.fn((event, handler) => {
382382
if (event === "open") {
383383
setImmediate(() => handler());
384384
}
385385
return mockWs as Ws;
386386
}),
387-
},
388-
);
389-
setupWebSocketMock(mockWs);
387+
});
388+
sockets.push(mockWs);
389+
return mockWs as Ws;
390+
});
391+
return sockets;
392+
};
390393

394+
it("triggers reconnection when session token changes", async () => {
395+
const sockets = setupAutoOpeningWebSocket();
391396
api = createApi(CODER_URL, AXIOS_TOKEN);
392-
const _ws = await api.watchAgentMetadata(AGENT_ID);
397+
await api.watchAgentMetadata(AGENT_ID);
393398

394-
// Change token - should trigger reconnection
395399
api.setSessionToken("new-token");
400+
await new Promise((resolve) => setImmediate(resolve));
396401

397-
expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting");
402+
expect(sockets[0].close).toHaveBeenCalledWith(
403+
1000,
404+
"Replacing connection",
405+
);
406+
expect(sockets).toHaveLength(2);
398407
});
399408

400409
it("triggers reconnection when host changes", async () => {
401-
const mockWs = createMockWebSocket(
402-
`wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`,
403-
{
404-
on: vi.fn((event, handler) => {
405-
if (event === "open") {
406-
setImmediate(() => handler());
407-
}
408-
return mockWs as Ws;
409-
}),
410-
},
411-
);
412-
setupWebSocketMock(mockWs);
413-
410+
const sockets = setupAutoOpeningWebSocket();
414411
api = createApi(CODER_URL, AXIOS_TOKEN);
415-
const _ws = await api.watchAgentMetadata(AGENT_ID);
412+
const wsWrap = await api.watchAgentMetadata(AGENT_ID);
413+
expect(wsWrap.url).toContain(CODER_URL.replace("http", "ws"));
416414

417-
// Change host - should trigger reconnection
418415
api.setHost("https://new-coder.example.com");
416+
await new Promise((resolve) => setImmediate(resolve));
419417

420-
expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting");
418+
expect(sockets[0].close).toHaveBeenCalledWith(
419+
1000,
420+
"Replacing connection",
421+
);
422+
expect(sockets).toHaveLength(2);
423+
expect(wsWrap.url).toContain("wss://new-coder.example.com");
421424
});
422425

423426
it("does not reconnect when token is set to same value", async () => {
424-
const mockWs = createMockWebSocket(
425-
`wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`,
426-
{
427-
on: vi.fn((event, handler) => {
428-
if (event === "open") {
429-
setImmediate(() => handler());
430-
}
431-
return mockWs as Ws;
432-
}),
433-
},
434-
);
435-
setupWebSocketMock(mockWs);
436-
427+
const sockets = setupAutoOpeningWebSocket();
437428
api = createApi(CODER_URL, AXIOS_TOKEN);
438-
const _ws = await api.watchAgentMetadata(AGENT_ID);
429+
await api.watchAgentMetadata(AGENT_ID);
439430

440-
// Set same token - should NOT trigger reconnection
441431
api.setSessionToken(AXIOS_TOKEN);
442432

443-
expect(mockWs.close).not.toHaveBeenCalled();
433+
expect(sockets[0].close).not.toHaveBeenCalled();
434+
expect(sockets).toHaveLength(1);
444435
});
445436

446437
it("does not reconnect when host is set to same value", async () => {
447-
const mockWs = createMockWebSocket(
448-
`wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`,
449-
{
450-
on: vi.fn((event, handler) => {
451-
if (event === "open") {
452-
setImmediate(() => handler());
453-
}
454-
return mockWs as Ws;
455-
}),
456-
},
457-
);
458-
setupWebSocketMock(mockWs);
459-
438+
const sockets = setupAutoOpeningWebSocket();
460439
api = createApi(CODER_URL, AXIOS_TOKEN);
461-
const _ws = await api.watchAgentMetadata(AGENT_ID);
440+
await api.watchAgentMetadata(AGENT_ID);
462441

463-
// Set same host - should NOT trigger reconnection
464442
api.setHost(CODER_URL);
465443

466-
expect(mockWs.close).not.toHaveBeenCalled();
444+
expect(sockets[0].close).not.toHaveBeenCalled();
445+
expect(sockets).toHaveLength(1);
467446
});
468447
});
469448

0 commit comments

Comments
 (0)