11import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket" ;
2+ import type { CloseEvent } from "ws" ;
23
34import type { Logger } from "../logging/logger" ;
45
@@ -18,9 +19,6 @@ export type ReconnectingWebSocketOptions = {
1819// 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors
1920const UNRECOVERABLE_CLOSE_CODES = new Set ( [ 403 , 410 , 426 , 1002 , 1003 ] ) ;
2021
21- // Custom close code for intentional reconnection (4000-4999 range is for private use)
22- const CLOSE_CODE_RECONNECTING = 4000 ;
23-
2422export class ReconnectingWebSocket < TData = unknown >
2523 implements UnidirectionalStream < TData >
2624{
@@ -40,12 +38,15 @@ export class ReconnectingWebSocket<TData = unknown>
4038 #reconnectTimeoutId: NodeJS . Timeout | null = null ;
4139 #isDisposed = false ;
4240 #isConnecting = false ;
41+ #pendingReconnect = false ;
42+ readonly #onDispose?: ( ) => void ;
4343
4444 private constructor (
4545 socketFactory : SocketFactory < TData > ,
4646 logger : Logger ,
4747 apiRoute : string ,
4848 options : ReconnectingWebSocketOptions = { } ,
49+ onDispose ?: ( ) => void ,
4950 ) {
5051 this . #socketFactory = socketFactory ;
5152 this . #logger = logger ;
@@ -56,19 +57,22 @@ export class ReconnectingWebSocket<TData = unknown>
5657 jitterFactor : options . jitterFactor ?? 0.1 ,
5758 } ;
5859 this . #backoffMs = this . #options. initialBackoffMs ;
60+ this . #onDispose = onDispose ;
5961 }
6062
6163 static async create < TData > (
6264 socketFactory : SocketFactory < TData > ,
6365 logger : Logger ,
6466 apiRoute : string ,
6567 options : ReconnectingWebSocketOptions = { } ,
68+ onDispose ?: ( ) => void ,
6669 ) : Promise < ReconnectingWebSocket < TData > > {
6770 const instance = new ReconnectingWebSocket < TData > (
6871 socketFactory ,
6972 logger ,
7073 apiRoute ,
7174 options ,
75+ onDispose ,
7276 ) ;
7377 await instance . #connect( ) ;
7478 return instance ;
@@ -85,10 +89,6 @@ export class ReconnectingWebSocket<TData = unknown>
8589 ( this . #eventHandlers[ event ] as Set < EventHandler < TData , TEvent > > ) . add (
8690 callback ,
8791 ) ;
88-
89- if ( this . #currentSocket) {
90- this . #currentSocket. addEventListener ( event , callback ) ;
91- }
9292 }
9393
9494 removeEventListener < TEvent extends WebSocketEventType > (
@@ -98,9 +98,23 @@ export class ReconnectingWebSocket<TData = unknown>
9898 ( this . #eventHandlers[ event ] as Set < EventHandler < TData , TEvent > > ) . delete (
9999 callback ,
100100 ) ;
101+ }
101102
102- if ( this . #currentSocket) {
103- this . #currentSocket. removeEventListener ( event , callback ) ;
103+ #executeHandlers< TEvent extends WebSocketEventType > (
104+ event : TEvent ,
105+ eventData : Parameters < EventHandler < TData , TEvent > > [ 0 ] ,
106+ ) : void {
107+ const handlers = this . #eventHandlers[ event ] as Set <
108+ EventHandler < TData , TEvent >
109+ > ;
110+ for ( const handler of handlers ) {
111+ try {
112+ handler ( eventData ) ;
113+ } catch ( error ) {
114+ this . #logger. error (
115+ `Error in ${ event } handler for ${ this . #apiRoute} : ${ error instanceof Error ? error . message : String ( error ) } ` ,
116+ ) ;
117+ }
104118 }
105119 }
106120
@@ -109,6 +123,25 @@ export class ReconnectingWebSocket<TData = unknown>
109123 return ;
110124 }
111125
126+ // Fire close handlers synchronously before disposing
127+ if ( this . #currentSocket) {
128+ this . #executeHandlers( "close" , {
129+ code : code ?? 1000 ,
130+ reason : reason ?? "" ,
131+ wasClean : true ,
132+ type : "close" ,
133+ target : this . #currentSocket,
134+ } as CloseEvent ) ;
135+ }
136+
137+ this . #dispose( code , reason ) ;
138+ }
139+
140+ #dispose( code ?: number , reason ?: string ) : void {
141+ if ( this . #isDisposed) {
142+ return ;
143+ }
144+
112145 this . #isDisposed = true ;
113146
114147 if ( this . #reconnectTimeoutId !== null ) {
@@ -124,6 +157,8 @@ export class ReconnectingWebSocket<TData = unknown>
124157 for ( const set of Object . values ( this . #eventHandlers) ) {
125158 set . clear ( ) ;
126159 }
160+
161+ this . #onDispose?.( ) ;
127162 }
128163
129164 reconnect ( ) : void {
@@ -136,9 +171,21 @@ export class ReconnectingWebSocket<TData = unknown>
136171 this . #reconnectTimeoutId = null ;
137172 }
138173
139- if ( this . #currentSocket) {
140- this . #currentSocket. close ( CLOSE_CODE_RECONNECTING , "Reconnecting" ) ;
174+ // If already connecting, schedule reconnect after current attempt
175+ if ( this . #isConnecting) {
176+ this . #pendingReconnect = true ;
177+ return ;
141178 }
179+
180+ // #connect() will close any existing socket
181+ this . #connect( ) . catch ( ( error ) => {
182+ if ( ! this . #isDisposed) {
183+ this . #logger. warn (
184+ `Manual reconnection failed for ${ this . #apiRoute} : ${ error instanceof Error ? error . message : String ( error ) } ` ,
185+ ) ;
186+ this . #scheduleReconnect( ) ;
187+ }
188+ } ) ;
142189 }
143190
144191 async #connect( ) : Promise < void > {
@@ -148,45 +195,40 @@ export class ReconnectingWebSocket<TData = unknown>
148195
149196 this . #isConnecting = true ;
150197 try {
198+ // Close any existing socket before creating a new one
199+ if ( this . #currentSocket) {
200+ this . #currentSocket. close ( 1000 , "Replacing connection" ) ;
201+ this . #currentSocket = null ;
202+ }
203+
151204 const socket = await this . #socketFactory( ) ;
152205 this . #currentSocket = socket ;
153206
154- socket . addEventListener ( "open" , ( ) => {
207+ socket . addEventListener ( "open" , ( event ) => {
155208 this . #backoffMs = this . #options. initialBackoffMs ;
209+ this . #executeHandlers( "open" , event ) ;
156210 } ) ;
157211
158- for ( const handler of this . #eventHandlers. open ) {
159- socket . addEventListener ( "open" , handler ) ;
160- }
161-
162- for ( const handler of this . #eventHandlers. message ) {
163- socket . addEventListener ( "message" , handler ) ;
164- }
212+ socket . addEventListener ( "message" , ( event ) => {
213+ this . #executeHandlers( "message" , event ) ;
214+ } ) ;
165215
166- for ( const handler of this . #eventHandlers . error ) {
167- socket . addEventListener ( "error" , handler ) ;
168- }
216+ socket . addEventListener ( " error" , ( event ) => {
217+ this . #executeHandlers ( "error" , event ) ;
218+ } ) ;
169219
170220 socket . addEventListener ( "close" , ( event ) => {
171- for ( const handler of this . #eventHandlers. close ) {
172- handler ( event ) ;
173- }
174-
175221 if ( this . #isDisposed) {
176222 return ;
177223 }
178224
225+ this . #executeHandlers( "close" , event ) ;
226+
179227 if ( UNRECOVERABLE_CLOSE_CODES . has ( event . code ) ) {
180228 this . #logger. error (
181229 `WebSocket connection closed with unrecoverable error code ${ event . code } ` ,
182230 ) ;
183- this . #isDisposed = true ;
184- return ;
185- }
186-
187- // Reconnect if this was an intentional close for reconnection
188- if ( event . code === CLOSE_CODE_RECONNECTING ) {
189- this . #scheduleReconnect( ) ;
231+ this . #dispose( ) ;
190232 return ;
191233 }
192234
@@ -200,6 +242,11 @@ export class ReconnectingWebSocket<TData = unknown>
200242 } ) ;
201243 } finally {
202244 this . #isConnecting = false ;
245+
246+ if ( this . #pendingReconnect) {
247+ this . #pendingReconnect = false ;
248+ this . reconnect ( ) ;
249+ }
203250 }
204251 }
205252
@@ -218,7 +265,6 @@ export class ReconnectingWebSocket<TData = unknown>
218265
219266 this . #reconnectTimeoutId = setTimeout ( ( ) => {
220267 this . #reconnectTimeoutId = null ;
221- // Errors already handled in #connect
222268 this . #connect( ) . catch ( ( error ) => {
223269 if ( ! this . #isDisposed) {
224270 this . #logger. warn (
@@ -231,8 +277,4 @@ export class ReconnectingWebSocket<TData = unknown>
231277
232278 this . #backoffMs = Math . min ( this . #backoffMs * 2 , this . #options. maxBackoffMs ) ;
233279 }
234-
235- isDisposed ( ) : boolean {
236- return this . #isDisposed;
237- }
238280}
0 commit comments