@@ -14,7 +14,7 @@ import {
1414 type WorkspaceAgentLog ,
1515} from "coder/site/src/api/typesGenerated" ;
1616import * as vscode from "vscode" ;
17- import { type ClientOptions , type CloseEvent , type ErrorEvent } from "ws" ;
17+ import { type ClientOptions } from "ws" ;
1818
1919import { CertificateError } from "../error" ;
2020import { getHeaderCommand , getHeaders } from "../headers" ;
@@ -31,11 +31,20 @@ import {
3131 HttpClientLogLevel ,
3232} from "../logging/types" ;
3333import { sizeOf } from "../logging/utils" ;
34- import { type UnidirectionalStream } from "../websocket/eventStreamConnection" ;
34+ import { HttpStatusCode } from "../websocket/codes" ;
35+ import {
36+ type UnidirectionalStream ,
37+ type CloseEvent ,
38+ type ErrorEvent ,
39+ } from "../websocket/eventStreamConnection" ;
3540import {
3641 OneWayWebSocket ,
3742 type OneWayWebSocketInit ,
3843} from "../websocket/oneWayWebSocket" ;
44+ import {
45+ ReconnectingWebSocket ,
46+ type SocketFactory ,
47+ } from "../websocket/reconnectingWebSocket" ;
3948import { SseConnection } from "../websocket/sseConnection" ;
4049
4150import { createHttpAgent } from "./utils" ;
@@ -47,6 +56,10 @@ const coderSessionTokenHeader = "Coder-Session-Token";
4756 * and WebSocket methods for real-time functionality.
4857 */
4958export class CoderApi extends Api {
59+ private readonly reconnectingSockets = new Set <
60+ ReconnectingWebSocket < unknown >
61+ > ( ) ;
62+
5063 private constructor ( private readonly output : Logger ) {
5164 super ( ) ;
5265 }
@@ -66,10 +79,34 @@ export class CoderApi extends Api {
6679 client . setSessionToken ( token ) ;
6780 }
6881
69- setupInterceptors ( client , baseUrl , output ) ;
82+ setupInterceptors ( client , output ) ;
7083 return client ;
7184 }
7285
86+ setSessionToken = ( token : string ) : void => {
87+ const defaultHeaders = this . getAxiosInstance ( ) . defaults . headers . common ;
88+ const currentToken = defaultHeaders [ coderSessionTokenHeader ] ;
89+ defaultHeaders [ coderSessionTokenHeader ] = token ;
90+
91+ if ( currentToken !== token ) {
92+ for ( const socket of this . reconnectingSockets ) {
93+ socket . reconnect ( ) ;
94+ }
95+ }
96+ } ;
97+
98+ setHost = ( host : string | undefined ) : void => {
99+ const defaults = this . getAxiosInstance ( ) . defaults ;
100+ const currentHost = defaults . baseURL ;
101+ defaults . baseURL = host ;
102+
103+ if ( currentHost !== host ) {
104+ for ( const socket of this . reconnectingSockets ) {
105+ socket . reconnect ( ) ;
106+ }
107+ }
108+ } ;
109+
73110 watchInboxNotifications = async (
74111 watchTemplates : string [ ] ,
75112 watchTargets : string [ ] ,
@@ -83,6 +120,7 @@ export class CoderApi extends Api {
83120 targets : watchTargets . join ( "," ) ,
84121 } ,
85122 options,
123+ enableRetry : true ,
86124 } ) ;
87125 } ;
88126
@@ -91,6 +129,7 @@ export class CoderApi extends Api {
91129 apiRoute : `/api/v2/workspaces/${ workspace . id } /watch-ws` ,
92130 fallbackApiRoute : `/api/v2/workspaces/${ workspace . id } /watch` ,
93131 options,
132+ enableRetry : true ,
94133 } ) ;
95134 } ;
96135
@@ -102,6 +141,7 @@ export class CoderApi extends Api {
102141 apiRoute : `/api/v2/workspaceagents/${ agentId } /watch-metadata-ws` ,
103142 fallbackApiRoute : `/api/v2/workspaceagents/${ agentId } /watch-metadata` ,
104143 options,
144+ enableRetry : true ,
105145 } ) ;
106146 } ;
107147
@@ -148,53 +188,78 @@ export class CoderApi extends Api {
148188 }
149189
150190 private async createWebSocket < TData = unknown > (
151- configs : Omit < OneWayWebSocketInit , "location" > ,
152- ) {
153- const baseUrlRaw = this . getAxiosInstance ( ) . defaults . baseURL ;
154- if ( ! baseUrlRaw ) {
155- throw new Error ( "No base URL set on REST client" ) ;
156- }
191+ configs : Omit < OneWayWebSocketInit , "location" > & { enableRetry ?: boolean } ,
192+ ) : Promise < UnidirectionalStream < TData > > {
193+ const { enableRetry, ...socketConfigs } = configs ;
194+
195+ const socketFactory : SocketFactory < TData > = async ( ) => {
196+ const baseUrlRaw = this . getAxiosInstance ( ) . defaults . baseURL ;
197+ if ( ! baseUrlRaw ) {
198+ throw new Error ( "No base URL set on REST client" ) ;
199+ }
200+
201+ const baseUrl = new URL ( baseUrlRaw ) ;
202+ const token = this . getAxiosInstance ( ) . defaults . headers . common [
203+ coderSessionTokenHeader
204+ ] as string | undefined ;
205+
206+ const headersFromCommand = await getHeaders (
207+ baseUrlRaw ,
208+ getHeaderCommand ( vscode . workspace . getConfiguration ( ) ) ,
209+ this . output ,
210+ ) ;
157211
158- const baseUrl = new URL ( baseUrlRaw ) ;
159- const token = this . getAxiosInstance ( ) . defaults . headers . common [
160- coderSessionTokenHeader
161- ] as string | undefined ;
212+ const httpAgent = await createHttpAgent (
213+ vscode . workspace . getConfiguration ( ) ,
214+ ) ;
162215
163- const headersFromCommand = await getHeaders (
164- baseUrlRaw ,
165- getHeaderCommand ( vscode . workspace . getConfiguration ( ) ) ,
166- this . output ,
167- ) ;
216+ /**
217+ * Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
218+ * 1. Headers from the header command
219+ * 2. Any headers passed directly to this function
220+ * 3. Coder session token from the Api client (if set)
221+ */
222+ const headers = {
223+ ...( token ? { [ coderSessionTokenHeader ] : token } : { } ) ,
224+ ...configs . options ?. headers ,
225+ ...headersFromCommand ,
226+ } ;
168227
169- const httpAgent = await createHttpAgent (
170- vscode . workspace . getConfiguration ( ) ,
171- ) ;
228+ const webSocket = new OneWayWebSocket < TData > ( {
229+ location : baseUrl ,
230+ ...socketConfigs ,
231+ options : {
232+ ...configs . options ,
233+ agent : httpAgent ,
234+ followRedirects : true ,
235+ headers,
236+ } ,
237+ } ) ;
172238
173- /**
174- * Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
175- * 1. Headers from the header command
176- * 2. Any headers passed directly to this function
177- * 3. Coder session token from the Api client (if set)
178- */
179- const headers = {
180- ...( token ? { [ coderSessionTokenHeader ] : token } : { } ) ,
181- ...configs . options ?. headers ,
182- ...headersFromCommand ,
239+ this . attachStreamLogger ( webSocket ) ;
240+ return webSocket ;
183241 } ;
184242
185- const webSocket = new OneWayWebSocket < TData > ( {
186- location : baseUrl ,
187- ...configs ,
188- options : {
189- ...configs . options ,
190- agent : httpAgent ,
191- followRedirects : true ,
192- headers,
193- } ,
194- } ) ;
243+ if ( enableRetry ) {
244+ const reconnectingSocket = await ReconnectingWebSocket . create < TData > (
245+ socketFactory ,
246+ this . output ,
247+ configs . apiRoute ,
248+ undefined ,
249+ ( ) =>
250+ this . reconnectingSockets . delete (
251+ reconnectingSocket as ReconnectingWebSocket < unknown > ,
252+ ) ,
253+ ) ;
254+
255+ this . reconnectingSockets . add (
256+ reconnectingSocket as ReconnectingWebSocket < unknown > ,
257+ ) ;
195258
196- this . attachStreamLogger ( webSocket ) ;
197- return webSocket ;
259+ return reconnectingSocket ;
260+ } else {
261+ return socketFactory ( ) ;
262+ }
198263 }
199264
200265 private attachStreamLogger < TData > (
@@ -230,13 +295,15 @@ export class CoderApi extends Api {
230295 fallbackApiRoute : string ;
231296 searchParams ?: Record < string , string > | URLSearchParams ;
232297 options ?: ClientOptions ;
298+ enableRetry ?: boolean ;
233299 } ) : Promise < UnidirectionalStream < TData > > {
234- let webSocket : OneWayWebSocket < TData > ;
300+ let webSocket : UnidirectionalStream < TData > ;
235301 try {
236302 webSocket = await this . createWebSocket < TData > ( {
237303 apiRoute : configs . apiRoute ,
238304 searchParams : configs . searchParams ,
239305 options : configs . options ,
306+ enableRetry : configs . enableRetry ,
240307 } ) ;
241308 } catch {
242309 // Failed to create WebSocket, use SSE fallback
@@ -274,8 +341,8 @@ export class CoderApi extends Api {
274341 const handleError = ( event : ErrorEvent ) => {
275342 cleanup ( ) ;
276343 const is404 =
277- event . message ?. includes ( "404" ) ||
278- event . error ?. message ?. includes ( "404" ) ;
344+ event . message ?. includes ( String ( HttpStatusCode . NOT_FOUND ) ) ||
345+ event . error ?. message ?. includes ( String ( HttpStatusCode . NOT_FOUND ) ) ;
279346
280347 if ( is404 && onNotFound ) {
281348 connection . close ( ) ;
@@ -323,14 +390,11 @@ export class CoderApi extends Api {
323390/**
324391 * Set up logging and request interceptors for the CoderApi instance.
325392 */
326- function setupInterceptors (
327- client : CoderApi ,
328- baseUrl : string ,
329- output : Logger ,
330- ) : void {
393+ function setupInterceptors ( client : CoderApi , output : Logger ) : void {
331394 addLoggingInterceptors ( client . getAxiosInstance ( ) , output ) ;
332395
333396 client . getAxiosInstance ( ) . interceptors . request . use ( async ( config ) => {
397+ const baseUrl = client . getAxiosInstance ( ) . defaults . baseURL ;
334398 const headers = await getHeaders (
335399 baseUrl ,
336400 getHeaderCommand ( vscode . workspace . getConfiguration ( ) ) ,
@@ -356,7 +420,12 @@ function setupInterceptors(
356420 client . getAxiosInstance ( ) . interceptors . response . use (
357421 ( r ) => r ,
358422 async ( err ) => {
359- throw await CertificateError . maybeWrap ( err , baseUrl , output ) ;
423+ const baseUrl = client . getAxiosInstance ( ) . defaults . baseURL ;
424+ if ( baseUrl ) {
425+ throw await CertificateError . maybeWrap ( err , baseUrl , output ) ;
426+ } else {
427+ throw err ;
428+ }
360429 } ,
361430 ) ;
362431}
0 commit comments