Skip to content

Commit e75c19e

Browse files
committed
feat: continue oauth for cli mcps
1 parent 9bb396b commit e75c19e

File tree

6 files changed

+188
-38
lines changed

6 files changed

+188
-38
lines changed

extensions/cli/src/auth/ensureAuth.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { isAuthenticated, login } from "./workos.js";
99
export async function ensureAuthenticated(
1010
requireAuth: boolean = true,
1111
): Promise<boolean> {
12-
if (isAuthenticated()) {
12+
if (await isAuthenticated()) {
1313
return true;
1414
}
1515

extensions/cli/src/auth/workos.ts

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ import {
115115
persistModelName,
116116
} from "../util/modelPersistence.js";
117117

118+
import { logger } from "src/util/logger.js";
118119
import { autoSelectOrganizationAndConfig } from "./orgSelection.js";
119120
import { pathToUri, slugToUri, uriToPath, uriToSlug } from "./uriUtils.js";
120121
import {
@@ -266,7 +267,7 @@ export function updateLocalConfigPath(localConfigPath: string | null): void {
266267
/**
267268
* Checks if the user is authenticated and the token is valid
268269
*/
269-
export function isAuthenticated(): boolean {
270+
export async function isAuthenticated(): Promise<boolean> {
270271
const config = loadAuthConfig();
271272

272273
if (config === null) {
@@ -278,17 +279,14 @@ export function isAuthenticated(): boolean {
278279
return true;
279280
}
280281

281-
/**
282-
* THIS CODE DOESN'T WORK.
283-
* .catch() will never return in a non-async function.
284-
* It's a hallucination.
285-
**/
286282
if (Date.now() > config.expiresAt) {
287-
// Try refreshing the token
288-
refreshToken(config.refreshToken).catch(() => {
289-
// If refresh fails, we're not authenticated
283+
try {
284+
const refreshed = await refreshToken(config.refreshToken);
285+
return isAuthenticatedConfig(refreshed);
286+
} catch (e) {
287+
logger.error("Failed to refresh auto token");
290288
return false;
291-
});
289+
}
292290
}
293291

294292
return true;

extensions/cli/src/infoScreen.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export async function handleInfoSlashCommand() {
2626
);
2727

2828
// Auth info
29-
if (isAuthenticated()) {
29+
if (await isAuthenticated()) {
3030
const config = loadAuthConfig();
3131
if (config && isAuthenticatedConfig(config)) {
3232
const email = config.userEmail || config.userId;

extensions/cli/src/services/AuthService.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export class AuthService extends BaseService<AuthServiceState> {
3030
*/
3131
async doInitialize(): Promise<AuthServiceState> {
3232
const authConfig = loadAuthConfig();
33-
const authenticated = isAuthenticated();
33+
const authenticated = await isAuthenticated();
3434

3535
const state: AuthServiceState = {
3636
authConfig,

extensions/cli/src/services/MCPService.ts

Lines changed: 173 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ import {
1212
StdioMcpServer,
1313
} from "node_modules/@continuedev/config-yaml/dist/schemas/mcp/index.js";
1414

15+
import { get } from "../util/apiClient.js";
1516
import { getErrorString } from "../util/error.js";
1617
import { logger } from "../util/logger.js";
1718

19+
import { isAuthenticated, loadAuthConfig } from "src/auth/workos.js";
1820
import { BaseService, ServiceWithDependencies } from "./BaseService.js";
1921
import { serviceContainer } from "./ServiceContainer.js";
2022
import {
@@ -50,9 +52,11 @@ export class MCPService
5052
private connections: Map<string, ServerConnection> = new Map();
5153
private assistant: AssistantConfig | null = null;
5254
private isShuttingDown = false;
55+
private isHeadless: boolean | undefined;
56+
private mcpTokenCache: Map<string, string> = new Map();
5357

5458
getDependencies(): string[] {
55-
return [SERVICE_NAMES.CONFIG];
59+
return [SERVICE_NAMES.CONFIG, SERVICE_NAMES.AUTH];
5660
}
5761
constructor() {
5862
super("MCPService", {
@@ -73,6 +77,8 @@ export class MCPService
7377
hasAgentFile: boolean,
7478
isHeadless: boolean | undefined,
7579
): Promise<MCPServiceState> {
80+
this.isHeadless = isHeadless;
81+
7682
logger.debug("Initializing MCPService", {
7783
configName: assistant.name,
7884
serverCount: assistant.mcpServers?.length || 0,
@@ -178,6 +184,115 @@ export class MCPService
178184
return { status: "idle", hasWarnings };
179185
}
180186

187+
/**
188+
* Generic wrapper for client operations that handles 401 errors with token refresh
189+
* Only applies to SSE/HTTP connections, not stdio
190+
*/
191+
private async withTokenRefresh<T>(
192+
serverName: string,
193+
operation: () => Promise<T>,
194+
): Promise<T> {
195+
const connection = this.connections.get(serverName);
196+
if (!connection) {
197+
throw new Error(`Connection ${serverName} not found`);
198+
}
199+
200+
const serverConfig = connection.config;
201+
if (!serverConfig || "command" in serverConfig) {
202+
// For stdio connections, just execute normally (no token refresh possible)
203+
return await operation();
204+
}
205+
206+
try {
207+
// Try the operation first
208+
return await operation();
209+
} catch (error: unknown) {
210+
// If not a 401 error, rethrow
211+
if (!is401Error(error)) {
212+
throw error;
213+
}
214+
215+
logger.debug("Got 401 error on MCP operation, attempting token refresh", {
216+
name: serverName,
217+
});
218+
219+
// Check if user is signed in
220+
const isAuthed = await isAuthenticated();
221+
if (!isAuthed) {
222+
logger.debug("User not signed in, cannot refresh OAuth token", {
223+
name: serverName,
224+
});
225+
throw error;
226+
}
227+
228+
const authConfig = loadAuthConfig();
229+
230+
// Clear cached token since it's invalid
231+
this.mcpTokenCache.delete(serverName);
232+
233+
// Fetch OAuth token from backend
234+
const identifier = serverConfig.name;
235+
const organizationSlug = authConfig?.organizationId;
236+
237+
let token: string | null = null;
238+
try {
239+
const params = new URLSearchParams({ identifier });
240+
if (organizationSlug) {
241+
params.set("organizationSlug", organizationSlug);
242+
}
243+
244+
logger.debug("Fetching OAuth token for MCP server on 401", {
245+
name: serverName,
246+
identifier,
247+
organizationSlug,
248+
});
249+
250+
const response = await get<{
251+
configured: boolean;
252+
hasCredentials: boolean;
253+
accessToken?: string;
254+
expiresAt?: string;
255+
expired?: boolean;
256+
}>(`/ide/mcp-auth?${params.toString()}`);
257+
258+
if (response.data.hasCredentials && response.data.accessToken) {
259+
token = response.data.accessToken;
260+
this.mcpTokenCache.set(serverName, token);
261+
logger.debug("Successfully retrieved OAuth token for MCP server", {
262+
name: serverName,
263+
});
264+
} else {
265+
logger.debug("No OAuth token available for MCP server", {
266+
name: serverName,
267+
configured: response.data.configured,
268+
hasCredentials: response.data.hasCredentials,
269+
expired: response.data.expired,
270+
});
271+
}
272+
} catch (fetchError) {
273+
logger.debug("Error fetching OAuth token for MCP server", {
274+
name: serverName,
275+
error: getErrorString(fetchError),
276+
});
277+
}
278+
279+
if (!token) {
280+
logger.debug("No OAuth token available for refresh", {
281+
name: serverName,
282+
});
283+
throw error;
284+
}
285+
286+
// Update the server config with new token and retry
287+
serverConfig.apiKey = token;
288+
logger.debug("Retrying operation with refreshed OAuth token", {
289+
name: serverName,
290+
});
291+
292+
return await operation();
293+
}
294+
}
295+
181296
/**
182297
* Run a tool by name
183298
*/
@@ -186,9 +301,16 @@ export class MCPService
186301
if (connection.status === "connected" && connection.client) {
187302
const tool = connection.tools.find((t) => t.name === name);
188303
if (tool) {
189-
return await connection.client.callTool({
190-
name,
191-
arguments: args,
304+
const serverName = connection.config!.name;
305+
return await this.withTokenRefresh(serverName, async () => {
306+
const conn = this.connections.get(serverName);
307+
if (!conn?.client) {
308+
throw new Error(`Client for ${serverName} not available`);
309+
}
310+
return await conn.client.callTool({
311+
name,
312+
arguments: args,
313+
});
192314
});
193315
}
194316
}
@@ -262,7 +384,16 @@ export class MCPService
262384

263385
if (capabilities?.prompts) {
264386
try {
265-
connection.prompts = (await client.listPrompts()).prompts;
387+
connection.prompts = await this.withTokenRefresh(
388+
serverName,
389+
async () => {
390+
const conn = this.connections.get(serverName);
391+
if (!conn?.client) {
392+
throw new Error(`Client for ${serverName} not available`);
393+
}
394+
return (await conn.client.listPrompts()).prompts;
395+
},
396+
);
266397
logger.debug("Loaded MCP prompts", {
267398
name: serverName,
268399
count: connection.prompts.length,
@@ -279,7 +410,16 @@ export class MCPService
279410

280411
if (capabilities?.tools) {
281412
try {
282-
connection.tools = (await client.listTools()).tools;
413+
connection.tools = await this.withTokenRefresh(
414+
serverName,
415+
async () => {
416+
const conn = this.connections.get(serverName);
417+
if (!conn?.client) {
418+
throw new Error(`Client for ${serverName} not available`);
419+
}
420+
return (await conn.client.listTools()).tools;
421+
},
422+
);
283423
logger.debug("Loaded MCP tools", {
284424
name: serverName,
285425
count: connection.tools.length,
@@ -403,16 +543,21 @@ export class MCPService
403543
});
404544

405545
try {
406-
if (serverConfig.type === "sse") {
407-
const transport = this.constructSseTransport(serverConfig);
408-
await client.connect(transport, {});
409-
} else if (serverConfig.type === "streamable-http") {
410-
const transport = this.constructHttpTransport(serverConfig);
411-
await client.connect(transport, {});
412-
}
546+
await this.withTokenRefresh(serverConfig.name, async () => {
547+
if (serverConfig.type === "sse") {
548+
const transport = this.constructSseTransport(serverConfig);
549+
await client.connect(transport, {});
550+
} else if (serverConfig.type === "streamable-http") {
551+
const transport = this.constructHttpTransport(serverConfig);
552+
await client.connect(transport, {});
553+
}
554+
});
413555
} catch (error: unknown) {
414-
// on authorization error, use "mcp-remote" with stdio transport to connect
415-
if (is401Error(error)) {
556+
// If token refresh didn't work and it's a 401, fall back to mcp-remote
557+
if (is401Error(error) && !this.isHeadless) {
558+
logger.debug("Falling back to mcp-remote after 401 error", {
559+
name: serverConfig.name,
560+
});
416561
const transport = this.constructStdioTransport(
417562
{
418563
name: serverConfig.name,
@@ -428,22 +573,28 @@ export class MCPService
428573
}
429574

430575
if (typeof serverConfig.type === "undefined") {
576+
// Try HTTP first, then SSE
431577
try {
432-
const transport = this.constructHttpTransport(serverConfig);
433-
await client.connect(transport, {});
434-
} catch {
578+
await this.withTokenRefresh(serverConfig.name, async () => {
579+
const transport = this.constructHttpTransport(serverConfig);
580+
await client.connect(transport, {});
581+
});
582+
} catch (httpError) {
435583
logger.debug(
436584
"MCP Connection: http connection failed, falling back to sse connection",
437585
{
438586
name: serverConfig.name,
587+
error: getErrorString(httpError),
439588
},
440589
);
441590
try {
442-
const transport = this.constructSseTransport(serverConfig);
443-
await client.connect(transport, {});
444-
} catch (e) {
591+
await this.withTokenRefresh(serverConfig.name, async () => {
592+
const transport = this.constructSseTransport(serverConfig);
593+
await client.connect(transport, {});
594+
});
595+
} catch (sseError) {
445596
throw new Error(
446-
`MCP config with URL and no type specified failed both SSE and HTTP connection: ${e instanceof Error ? e.message : String(e)}`,
597+
`MCP config with URL and no type specified failed both SSE and HTTP connection: ${sseError instanceof Error ? sseError.message : String(sseError)}`,
447598
);
448599
}
449600
}

extensions/cli/src/slashCommands.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ async function handleLogout() {
9393
}
9494
}
9595

96-
function handleWhoami() {
97-
if (isAuthenticated()) {
98-
const config = loadAuthConfig();
96+
async function handleWhoami() {
97+
const authed = await isAuthenticated();
98+
if (authed) {
99+
const config = loadAuthConfig(); // TODO duplicate auth config loading
99100
if (config && isAuthenticatedConfig(config)) {
100101
return {
101102
exit: false,

0 commit comments

Comments
 (0)