diff --git a/extensions/cli/src/__mocks__/auth/workos.ts b/extensions/cli/src/__mocks__/auth/workos.ts index f77d874629a..1c9be83f948 100644 --- a/extensions/cli/src/__mocks__/auth/workos.ts +++ b/extensions/cli/src/__mocks__/auth/workos.ts @@ -1,7 +1,6 @@ import { vi } from "vitest"; -import type { AuthConfig } from "../../auth/workos.js"; -export const isAuthenticated = vi.fn(() => false); +export const isAuthenticated = vi.fn(() => Promise.resolve(false)); export const isAuthenticatedConfig = vi.fn(() => false); export const isEnvironmentAuthConfig = vi.fn(() => false); export const loadAuthConfig = vi.fn(() => null); diff --git a/extensions/cli/src/auth/ensureAuth.ts b/extensions/cli/src/auth/ensureAuth.ts index f3f6636bc92..03865fc586e 100644 --- a/extensions/cli/src/auth/ensureAuth.ts +++ b/extensions/cli/src/auth/ensureAuth.ts @@ -9,7 +9,7 @@ import { isAuthenticated, login } from "./workos.js"; export async function ensureAuthenticated( requireAuth: boolean = true, ): Promise { - if (isAuthenticated()) { + if (await isAuthenticated()) { return true; } diff --git a/extensions/cli/src/auth/orgSelection.test.ts b/extensions/cli/src/auth/orgSelection.test.ts index 838a208990a..a07f28ef087 100644 --- a/extensions/cli/src/auth/orgSelection.test.ts +++ b/extensions/cli/src/auth/orgSelection.test.ts @@ -6,7 +6,7 @@ import { autoSelectOrganizationAndConfig, createUpdatedAuthConfig, } from "./orgSelection.js"; -import type { AuthenticatedConfig } from "./workos.js"; +import { AuthenticatedConfig } from "./workos-types.js"; // Mock dependencies vi.mock("fs"); diff --git a/extensions/cli/src/auth/orgSelection.ts b/extensions/cli/src/auth/orgSelection.ts index e82bea8ae98..63e7140304f 100644 --- a/extensions/cli/src/auth/orgSelection.ts +++ b/extensions/cli/src/auth/orgSelection.ts @@ -3,11 +3,13 @@ import * as path from "path"; import chalk from "chalk"; -import type { AuthConfig, AuthenticatedConfig } from "../auth/workos.js"; +import type { AuthConfig } from "../auth/workos.js"; import { saveAuthConfig } from "../auth/workos.js"; import { getApiClient } from "../config.js"; import { env } from "../env.js"; +import { AuthenticatedConfig } from "./workos-types.js"; + /** * Creates an updated AuthenticatedConfig with a new organization ID and optional config URI */ diff --git a/extensions/cli/src/auth/workos-org.test.ts b/extensions/cli/src/auth/workos-org.test.ts index ba0de72b28d..215f1dbf459 100644 --- a/extensions/cli/src/auth/workos-org.test.ts +++ b/extensions/cli/src/auth/workos-org.test.ts @@ -1,10 +1,10 @@ import chalk from "chalk"; -import { describe, expect, test, beforeEach, vi } from "vitest"; +import { beforeEach, describe, expect, test, vi } from "vitest"; import { getApiClient } from "../config.js"; +import { AuthenticatedConfig, EnvironmentAuthConfig } from "./workos-types.js"; import { ensureOrganization } from "./workos.js"; -import type { AuthenticatedConfig, EnvironmentAuthConfig } from "./workos.js"; // Mock dependencies vi.mock("../config.js", () => ({ diff --git a/extensions/cli/src/auth/workos-types.ts b/extensions/cli/src/auth/workos-types.ts new file mode 100644 index 00000000000..bfab63a14bb --- /dev/null +++ b/extensions/cli/src/auth/workos-types.ts @@ -0,0 +1,38 @@ +/** + * Device authorization response from WorkOS + */ +export interface DeviceAuthorizationResponse { + device_code: string; + user_code: string; + verification_uri: string; + verification_uri_complete: string; + expires_in: number; + interval: number; +} + +// Represents an authenticated user's configuration +export interface AuthenticatedConfig { + userId: string; + userEmail: string; + accessToken: string; + refreshToken: string; + expiresAt: number; + organizationId: string | null | undefined; // null means personal organization, undefined triggers auto-selection + configUri?: string; // Optional config URI (file:// or slug://owner/slug) + modelName?: string; // Name of the selected model +} + +// Represents configuration when using environment variable auth +export interface EnvironmentAuthConfig { + /** + * This userId?: undefined; field a trick to help TypeScript differentiate between + * AuthenticatedConfig and EnvironmentAuthConfig. Otherwise AuthenticatedConfig is + * a possible subtype of EnvironmentAuthConfig and TypeScript gets confused where + * type guards are involved. + */ + userId?: undefined; + accessToken: string; + organizationId: string | null; // Can be set via --org flag in headless mode + configUri?: string; // Optional config URI (file:// or slug://owner/slug) + modelName?: string; // Name of the selected model +} diff --git a/extensions/cli/src/auth/workos.helpers.ts b/extensions/cli/src/auth/workos.helpers.ts index cc4690cdfaf..1a0b3630446 100644 --- a/extensions/cli/src/auth/workos.helpers.ts +++ b/extensions/cli/src/auth/workos.helpers.ts @@ -4,11 +4,8 @@ import { getApiClient } from "../config.js"; import { safeStderr } from "../init.js"; import { gracefulExit } from "../util/exit.js"; -import type { - AuthConfig, - AuthenticatedConfig, - EnvironmentAuthConfig, -} from "./workos.js"; +import { AuthenticatedConfig, EnvironmentAuthConfig } from "./workos-types.js"; +import type { AuthConfig } from "./workos.js"; import { saveAuthConfig } from "./workos.js"; /** diff --git a/extensions/cli/src/auth/workos.test.ts b/extensions/cli/src/auth/workos.test.ts index 22e13642bf0..dfccdcaa58e 100644 --- a/extensions/cli/src/auth/workos.test.ts +++ b/extensions/cli/src/auth/workos.test.ts @@ -1,7 +1,6 @@ import { slugToUri } from "./uriUtils.js"; +import { AuthenticatedConfig, EnvironmentAuthConfig } from "./workos-types.js"; import { - AuthenticatedConfig, - EnvironmentAuthConfig, getAccessToken, getAssistantSlug, getOrganizationId, diff --git a/extensions/cli/src/auth/workos.ts b/extensions/cli/src/auth/workos.ts index a005eb56aa0..b31d3893490 100644 --- a/extensions/cli/src/auth/workos.ts +++ b/extensions/cli/src/auth/workos.ts @@ -3,13 +3,15 @@ import * as os from "os"; import * as path from "path"; import chalk from "chalk"; -// Polyfill fetch for Node < 18 import nodeFetch from "node-fetch"; import open from "open"; +import { logger } from "src/util/logger.js"; + import { getApiClient } from "../config.js"; // eslint-disable-next-line import/order import { env } from "../env.js"; + if (!globalThis.fetch) { globalThis.fetch = nodeFetch as unknown as typeof globalThis.fetch; } @@ -21,33 +23,6 @@ function getAuthConfigPath() { return path.join(continueHome, "auth.json"); } -// Represents an authenticated user's configuration -export interface AuthenticatedConfig { - userId: string; - userEmail: string; - accessToken: string; - refreshToken: string; - expiresAt: number; - organizationId: string | null | undefined; // null means personal organization, undefined triggers auto-selection - configUri?: string; // Optional config URI (file:// or slug://owner/slug) - modelName?: string; // Name of the selected model -} - -// Represents configuration when using environment variable auth -export interface EnvironmentAuthConfig { - /** - * This userId?: undefined; field a trick to help TypeScript differentiate between - * AuthenticatedConfig and EnvironmentAuthConfig. Otherwise AuthenticatedConfig is - * a possible subtype of EnvironmentAuthConfig and TypeScript gets confused where - * type guards are involved. - */ - userId?: undefined; - accessToken: string; - organizationId: string | null; // Can be set via --org flag in headless mode - configUri?: string; // Optional config URI (file:// or slug://owner/slug) - modelName?: string; // Name of the selected model -} - // Union type representing the possible authentication states export type AuthConfig = AuthenticatedConfig | EnvironmentAuthConfig | null; @@ -117,6 +92,11 @@ import { import { autoSelectOrganizationAndConfig } from "./orgSelection.js"; import { pathToUri, slugToUri, uriToPath, uriToSlug } from "./uriUtils.js"; +import { + AuthenticatedConfig, + DeviceAuthorizationResponse, + EnvironmentAuthConfig, +} from "./workos-types.js"; import { handleCliOrgForAuthenticatedConfig, handleCliOrgForEnvironmentAuth, @@ -266,46 +246,30 @@ export function updateLocalConfigPath(localConfigPath: string | null): void { /** * Checks if the user is authenticated and the token is valid */ -export function isAuthenticated(): boolean { +export async function isAuthenticated(): Promise { const config = loadAuthConfig(); if (config === null) { return false; } - // Environment auth is always valid if (isEnvironmentAuthConfig(config)) { return true; } - /** - * THIS CODE DOESN'T WORK. - * .catch() will never return in a non-async function. - * It's a hallucination. - **/ if (Date.now() > config.expiresAt) { - // Try refreshing the token - refreshToken(config.refreshToken).catch(() => { - // If refresh fails, we're not authenticated + try { + const refreshed = await refreshToken(config.refreshToken); + return isAuthenticatedConfig(refreshed); + } catch (e) { + logger.error("Failed to refresh auto token", e); return false; - }); + } } return true; } -/** - * Device authorization response from WorkOS - */ -interface DeviceAuthorizationResponse { - device_code: string; - user_code: string; - verification_uri: string; - verification_uri_complete: string; - expires_in: number; - interval: number; -} - /** * Request device authorization from WorkOS */ diff --git a/extensions/cli/src/infoScreen.ts b/extensions/cli/src/infoScreen.ts index fc495c94cef..4a62bfb8a42 100644 --- a/extensions/cli/src/infoScreen.ts +++ b/extensions/cli/src/infoScreen.ts @@ -26,7 +26,7 @@ export async function handleInfoSlashCommand() { ); // Auth info - if (isAuthenticated()) { + if (await isAuthenticated()) { const config = loadAuthConfig(); if (config && isAuthenticatedConfig(config)) { const email = config.userEmail || config.userId; diff --git a/extensions/cli/src/integration/model-persistence-e2e.test.ts b/extensions/cli/src/integration/model-persistence-e2e.test.ts index 22354c98987..4bff856ccbd 100644 --- a/extensions/cli/src/integration/model-persistence-e2e.test.ts +++ b/extensions/cli/src/integration/model-persistence-e2e.test.ts @@ -5,8 +5,9 @@ import * as path from "path"; import { AssistantUnrolled, ModelConfig } from "@continuedev/config-yaml"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AuthenticatedConfig } from "src/auth/workos-types.js"; + import { - AuthenticatedConfig, getModelName, loadAuthConfig, saveAuthConfig, diff --git a/extensions/cli/src/integration/model-persistence-user-flow.test.ts b/extensions/cli/src/integration/model-persistence-user-flow.test.ts index 3aef05071ac..811234e0ab4 100644 --- a/extensions/cli/src/integration/model-persistence-user-flow.test.ts +++ b/extensions/cli/src/integration/model-persistence-user-flow.test.ts @@ -5,10 +5,10 @@ import * as path from "path"; import { AssistantUnrolled, ModelConfig } from "@continuedev/config-yaml"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AuthenticatedConfig } from "src/auth/workos-types.js"; import { AuthService } from "src/services/AuthService.js"; import { - AuthenticatedConfig, getModelName, loadAuthConfig, saveAuthConfig, diff --git a/extensions/cli/src/integration/model-persistence.test.ts b/extensions/cli/src/integration/model-persistence.test.ts index 9e352143d17..f2f86f0a72d 100644 --- a/extensions/cli/src/integration/model-persistence.test.ts +++ b/extensions/cli/src/integration/model-persistence.test.ts @@ -4,8 +4,9 @@ import * as path from "path"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; +import { AuthenticatedConfig } from "src/auth/workos-types.js"; + import { - AuthenticatedConfig, getModelName, loadAuthConfig, saveAuthConfig, diff --git a/extensions/cli/src/services/AuthService.test.ts b/extensions/cli/src/services/AuthService.test.ts index e72c17e74ea..c119a5a1414 100644 --- a/extensions/cli/src/services/AuthService.test.ts +++ b/extensions/cli/src/services/AuthService.test.ts @@ -1,17 +1,17 @@ -import { describe, expect, test, beforeEach, vi } from "vitest"; +import { beforeEach, describe, expect, test, vi } from "vitest"; // Mock the workos module vi.mock("../auth/workos.js"); // Import the workos functions we need to mock import { - loadAuthConfig, + ensureOrganization, isAuthenticated, + listUserOrganizations, + loadAuthConfig, login, logout, - ensureOrganization, saveAuthConfig, - listUserOrganizations, } from "../auth/workos.js"; import { AuthService } from "./AuthService.js"; @@ -35,7 +35,7 @@ describe("AuthService", () => { describe("State Management", () => { test("should initialize with unauthenticated state", async () => { vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); const state = await service.initialize(); @@ -48,7 +48,7 @@ describe("AuthService", () => { test("should initialize with authenticated state", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); const state = await service.initialize(); @@ -64,7 +64,7 @@ describe("AuthService", () => { test("should update state after successful login", async () => { // Initialize first vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); await service.initialize(); // Mock successful login @@ -93,7 +93,7 @@ describe("AuthService", () => { test("should clear state after logout", async () => { // Initialize with authenticated state vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const state = await service.logout(); @@ -112,7 +112,7 @@ describe("AuthService", () => { // Initialize with auth but no org const authWithoutOrg = { ...mockAuthConfig, organizationId: null }; vi.mocked(loadAuthConfig).mockReturnValue(authWithoutOrg); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); // Mock ensure organization @@ -129,7 +129,7 @@ describe("AuthService", () => { test("should throw error if not authenticated", async () => { vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); await service.initialize(); await expect(service.ensureOrganization()).rejects.toThrow( @@ -139,7 +139,7 @@ describe("AuthService", () => { test("should pass organization slug to ensureOrganization in headless mode", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const updatedConfig = { ...mockAuthConfig, organizationId: "org-456" }; @@ -161,7 +161,7 @@ describe("AuthService", () => { test("should handle personal organization slug", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const personalConfig = { ...mockAuthConfig, organizationId: null }; @@ -183,7 +183,7 @@ describe("AuthService", () => { test("should work without organization slug parameter", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); vi.mocked(ensureOrganization).mockResolvedValue(mockAuthConfig); @@ -206,7 +206,7 @@ describe("AuthService", () => { describe("switchOrganization()", () => { test("should update state with new organization", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const state = await service.switchOrganization("org-456"); @@ -227,7 +227,7 @@ describe("AuthService", () => { test("should handle null organization", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const state = await service.switchOrganization(null); @@ -238,7 +238,7 @@ describe("AuthService", () => { test("should throw error if not file-based auth", async () => { const tokenOnlyAuth = { accessToken: "token" } as any; vi.mocked(loadAuthConfig).mockReturnValue(tokenOnlyAuth); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); await expect(service.switchOrganization("org-456")).rejects.toThrow( @@ -250,7 +250,7 @@ describe("AuthService", () => { describe("getAvailableOrganizations()", () => { test("should return organizations when authenticated", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); const mockOrgs = [ @@ -265,7 +265,7 @@ describe("AuthService", () => { test("should return null when not authenticated", async () => { vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); await service.initialize(); const orgs = await service.getAvailableOrganizations(); @@ -274,7 +274,7 @@ describe("AuthService", () => { test("should handle errors gracefully", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); vi.mocked(listUserOrganizations).mockRejectedValue( @@ -297,7 +297,7 @@ describe("AuthService", () => { describe("hasMultipleOrganizations()", () => { test("should return true when multiple orgs available", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); vi.mocked(listUserOrganizations).mockResolvedValue([ @@ -311,7 +311,7 @@ describe("AuthService", () => { test("should return false when no orgs available", async () => { vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); await service.initialize(); vi.mocked(listUserOrganizations).mockResolvedValue([]); @@ -324,12 +324,12 @@ describe("AuthService", () => { describe("refresh()", () => { test("should reload auth state from disk", async () => { vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); await service.initialize(); // Update mock to return authenticated state vi.mocked(loadAuthConfig).mockReturnValue(mockAuthConfig); - vi.mocked(isAuthenticated).mockReturnValue(true); + vi.mocked(isAuthenticated).mockResolvedValue(true); const state = await service.refresh(); @@ -345,7 +345,7 @@ describe("AuthService", () => { test("should emit stateChanged on login", async () => { // Initialize with no auth config to simulate logged out state vi.mocked(loadAuthConfig).mockReturnValue(null); - vi.mocked(isAuthenticated).mockReturnValue(false); + vi.mocked(isAuthenticated).mockResolvedValue(false); await service.initialize(); const listener = vi.fn(); diff --git a/extensions/cli/src/services/AuthService.ts b/extensions/cli/src/services/AuthService.ts index 90a2f734791..f2b9cd47686 100644 --- a/extensions/cli/src/services/AuthService.ts +++ b/extensions/cli/src/services/AuthService.ts @@ -1,5 +1,6 @@ +import { AuthenticatedConfig } from "src/auth/workos-types.js"; + import { - AuthenticatedConfig, login as doLogin, logout as doLogout, ensureOrganization, @@ -30,7 +31,7 @@ export class AuthService extends BaseService { */ async doInitialize(): Promise { const authConfig = loadAuthConfig(); - const authenticated = isAuthenticated(); + const authenticated = await isAuthenticated(); const state: AuthServiceState = { authConfig, diff --git a/extensions/cli/src/services/MCPService.ts b/extensions/cli/src/services/MCPService.ts index 231312e1ccb..0f8a4b5bc89 100644 --- a/extensions/cli/src/services/MCPService.ts +++ b/extensions/cli/src/services/MCPService.ts @@ -1,9 +1,6 @@ import { type AssistantConfig } from "@continuedev/sdk"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { - SSEClientTransport, - SseError, -} from "@modelcontextprotocol/sdk/client/sse.js"; +import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { @@ -12,10 +9,14 @@ import { StdioMcpServer, } from "node_modules/@continuedev/config-yaml/dist/schemas/mcp/index.js"; +import { isAuthenticated, loadAuthConfig } from "src/auth/workos.js"; + +import { get } from "../util/apiClient.js"; import { getErrorString } from "../util/error.js"; import { logger } from "../util/logger.js"; import { BaseService, ServiceWithDependencies } from "./BaseService.js"; +import { isAuthError } from "./mcpUtils.js"; import { serviceContainer } from "./ServiceContainer.js"; import { MCPConnectionInfo, @@ -24,14 +25,6 @@ import { SERVICE_NAMES, } from "./types.js"; -function is401Error(error: unknown) { - return ( - (error instanceof SseError && error.code === 401) || - (error instanceof Error && error.message.includes("401")) || - (error instanceof Error && error.message.includes("Unauthorized")) - ); -} - interface ServerConnection extends MCPConnectionInfo { client: Client | null; } @@ -50,9 +43,11 @@ export class MCPService private connections: Map = new Map(); private assistant: AssistantConfig | null = null; private isShuttingDown = false; + private isHeadless: boolean | undefined; + private apiKeyCache: Map = new Map(); getDependencies(): string[] { - return [SERVICE_NAMES.CONFIG]; + return [SERVICE_NAMES.CONFIG, SERVICE_NAMES.AUTH]; } constructor() { super("MCPService", { @@ -73,34 +68,20 @@ export class MCPService hasAgentFile: boolean, isHeadless: boolean | undefined, ): Promise { - logger.debug("Initializing MCPService", { - configName: assistant.name, - serverCount: assistant.mcpServers?.length || 0, - }); + this.isHeadless = isHeadless; await this.shutdownConnections(); this.assistant = assistant; this.connections.clear(); - if (assistant.mcpServers?.length) { - logger.debug("Starting MCP server connections", { - serverCount: assistant.mcpServers.length, - }); - } const connectionPromises = assistant.mcpServers?.map(async (config) => { if (config) { return await this.connectServer(config); } }); - const connectionInit = Promise.all(connectionPromises ?? []).then( - (connections) => { - logger.debug("MCP connections established", { - connectionCount: connections.length, - }); - }, - ); + const connectionInit = Promise.all(connectionPromises ?? []); if (isHeadless || hasAgentFile) { await connectionInit; @@ -178,6 +159,86 @@ export class MCPService return { status: "idle", hasWarnings }; } + /** + * Generic wrapper for client operations that handles 401 errors with token refresh + * Only applies to SSE/HTTP connections, not stdio + */ + private async withTokenRefresh( + serverName: string, + operation: () => Promise, + ): Promise { + const connection = this.connections.get(serverName); + if (!connection) { + throw new Error(`Connection ${serverName} not found`); + } + + const serverConfig = connection.config; + if (!serverConfig || "command" in serverConfig) { + // For stdio connections, just execute normally (no token refresh possible) + return await operation(); + } + + try { + // Try the operation first + return await operation(); + } catch (error: unknown) { + // If not a 401 error, rethrow + if (!isAuthError(error)) { + throw error; + } + + // Check if user is signed in + const isAuthed = await isAuthenticated(); + if (!isAuthed) { + throw error; + } + + const authConfig = loadAuthConfig(); + + // Clear cached token since it's invalid + this.apiKeyCache.delete(serverName); + + // Fetch OAuth token from backend + const organizationSlug = authConfig?.organizationId; + + let token: string | null = null; + try { + const params = new URLSearchParams({ + url: serverConfig.url, + }); + if (organizationSlug) { + params.set("organizationSlug", organizationSlug); + } + if (serverConfig.sourceSlug) { + params.set("slug", serverConfig.sourceSlug); + } + + const response = await get<{ + configured: boolean; + hasCredentials: boolean; + accessToken?: string; + expiresAt?: string; + expired?: boolean; + }>(`/ide/mcp-auth?${params.toString()}`); + + if (response.data.hasCredentials && response.data.accessToken) { + token = response.data.accessToken; + this.apiKeyCache.set(serverName, token); + } + } catch { + logger.debug("Failed to fetch mcp oauth credentials"); + } + + if (!token) { + throw error; + } + + this.apiKeyCache.set(serverConfig.name, token); + + return await operation(); + } + } + /** * Run a tool by name */ @@ -186,9 +247,16 @@ export class MCPService if (connection.status === "connected" && connection.client) { const tool = connection.tools.find((t) => t.name === name); if (tool) { - return await connection.client.callTool({ - name, - arguments: args, + const serverName = connection.config!.name; + return await this.withTokenRefresh(serverName, async () => { + const conn = this.connections.get(serverName); + if (!conn?.client) { + throw new Error(`Client for ${serverName} not available`); + } + return await conn.client.callTool({ + name, + arguments: args, + }); }); } } @@ -202,7 +270,6 @@ export class MCPService public async restartAllServers(): Promise { if (!this.assistant) return; - logger.debug("Restarting all MCP servers"); await this.shutdownConnections(); await this.initialize(this.assistant); } @@ -221,8 +288,6 @@ export class MCPService throw new Error(`Server ${serverName} not found in configuration`); } - logger.debug("Restarting MCP server", { name: serverName }); - const existingConnection = this.connections.get(serverName); if (existingConnection) { if (existingConnection.status === "connected") { @@ -254,19 +319,19 @@ export class MCPService this.updateState(); const capabilities = client.getServerCapabilities(); - logger.debug("MCP server capabilities", { - name: serverName, - hasPrompts: !!capabilities?.prompts, - hasTools: !!capabilities?.tools, - }); if (capabilities?.prompts) { try { - connection.prompts = (await client.listPrompts()).prompts; - logger.debug("Loaded MCP prompts", { - name: serverName, - count: connection.prompts.length, - }); + connection.prompts = await this.withTokenRefresh( + serverName, + async () => { + const conn = this.connections.get(serverName); + if (!conn?.client) { + throw new Error(`Client for ${serverName} not available`); + } + return (await conn.client.listPrompts()).prompts; + }, + ); } catch (error) { const errorMessage = getErrorString(error); connection.warnings.push(`Failed to load prompts: ${errorMessage}`); @@ -279,11 +344,16 @@ export class MCPService if (capabilities?.tools) { try { - connection.tools = (await client.listTools()).tools; - logger.debug("Loaded MCP tools", { - name: serverName, - count: connection.tools.length, - }); + connection.tools = await this.withTokenRefresh( + serverName, + async () => { + const conn = this.connections.get(serverName); + if (!conn?.client) { + throw new Error(`Client for ${serverName} not available`); + } + return (await conn.client.listTools()).tools; + }, + ); } catch (error) { const errorMessage = getErrorString(error); connection.warnings.push(`Failed to load tools: ${errorMessage}`); @@ -293,8 +363,6 @@ export class MCPService }); } } - - logger.debug("MCP server connected successfully", { name: serverName }); } catch (error) { const errorMessage = getErrorString(error); connection.status = "error"; @@ -321,8 +389,6 @@ export class MCPService * Stop a specific server */ public async stopServer(serverName: string): Promise { - logger.debug("Stopping MCP server", { name: serverName }); - const connection = this.connections.get(serverName); if (connection) { await this.shutdownConnection(connection); @@ -371,7 +437,6 @@ export class MCPService this.isShuttingDown = true; this.removeAllListeners(); - logger.debug("Shutting down MCPService"); await this.shutdownConnections(); } @@ -389,30 +454,37 @@ export class MCPService if ("command" in serverConfig) { // STDIO: no need to check type, just if command is present - logger.debug("Connecting to MCP server", { - name: serverConfig.name, - command: serverConfig.command, - }); const transport = this.constructStdioTransport(serverConfig, connection); await client.connect(transport, {}); } else { // SSE/HTTP: if type isn't explicit: try http and fall back to sse - logger.debug("Connecting to MCP server", { - name: serverConfig.name, - url: serverConfig.url, - }); - try { - if (serverConfig.type === "sse") { - const transport = this.constructSseTransport(serverConfig); - await client.connect(transport, {}); - } else if (serverConfig.type === "streamable-http") { - const transport = this.constructHttpTransport(serverConfig); - await client.connect(transport, {}); - } + await this.withTokenRefresh(serverConfig.name, async () => { + if (serverConfig.apiKey && !this.apiKeyCache.get(serverConfig.name)) { + this.apiKeyCache.set(serverConfig.name, serverConfig.apiKey); + } + if (serverConfig.type === "sse") { + const transport = this.constructSseTransport(serverConfig); + await client.connect(transport, {}); + } else if (serverConfig.type === "streamable-http") { + const transport = this.constructHttpTransport(serverConfig); + await client.connect(transport, {}); + } else { + try { + const transport = this.constructHttpTransport(serverConfig); + await client.connect(transport, {}); + } catch (e) { + if (isAuthError(e)) { + throw e; + } + const transport = this.constructSseTransport(serverConfig); + await client.connect(transport, {}); + } + } + }); } catch (error: unknown) { - // on authorization error, use "mcp-remote" with stdio transport to connect - if (is401Error(error)) { + // If token refresh didn't work and it's a 401, fall back to mcp-remote + if (isAuthError(error) && !this.isHeadless) { const transport = this.constructStdioTransport( { name: serverConfig.name, @@ -426,32 +498,6 @@ export class MCPService throw error; } } - - if (typeof serverConfig.type === "undefined") { - try { - const transport = this.constructHttpTransport(serverConfig); - await client.connect(transport, {}); - } catch { - logger.debug( - "MCP Connection: http connection failed, falling back to sse connection", - { - name: serverConfig.name, - }, - ); - try { - const transport = this.constructSseTransport(serverConfig); - await client.connect(transport, {}); - } catch (e) { - throw new Error( - `MCP config with URL and no type specified failed both SSE and HTTP connection: ${e instanceof Error ? e.message : String(e)}`, - ); - } - } - } else if ( - !["streamable-http", "sse", "stdio"].includes(serverConfig.type) - ) { - throw new Error(`Unsupported transport type: ${serverConfig.type}`); - } } return client; @@ -460,11 +506,12 @@ export class MCPService private constructSseTransport( serverConfig: SseMcpServer, ): SSEClientTransport { + const apiKey = this.apiKeyCache.get(serverConfig.name); // Merge apiKey into headers if provided const headers = { ...serverConfig.requestOptions?.headers, - ...(serverConfig.apiKey && { - Authorization: `Bearer ${serverConfig.apiKey}`, + ...(apiKey && { + Authorization: `Bearer ${apiKey}`, }), }; @@ -486,10 +533,12 @@ export class MCPService serverConfig: HttpMcpServer, ): StreamableHTTPClientTransport { // Merge apiKey into headers if provided + const apiKey = this.apiKeyCache.get(serverConfig.name); + const headers = { ...serverConfig.requestOptions?.headers, - ...(serverConfig.apiKey && { - Authorization: `Bearer ${serverConfig.apiKey}`, + ...(apiKey && { + Authorization: `Bearer ${apiKey}`, }), }; diff --git a/extensions/cli/src/services/mcpUtils.ts b/extensions/cli/src/services/mcpUtils.ts new file mode 100644 index 00000000000..c5058df1323 --- /dev/null +++ b/extensions/cli/src/services/mcpUtils.ts @@ -0,0 +1,19 @@ +import { SseError } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StreamableHTTPError } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; + +/** + * Check if an error is an authentication error that should trigger token refresh + * 405 is technically "not allowed" but some servers like Sanity don't return the correct error + * Since we're using mcp library there's a good chance 405 means auth issue and doesn't hurt to retry with auth + */ +export function isAuthError(error: unknown): boolean { + return ( + (error instanceof SseError && error.code === 401) || + (error instanceof SseError && error.code === 405) || + (error instanceof StreamableHTTPError && error.code === 401) || + (error instanceof StreamableHTTPError && error.code === 405) || + (error instanceof Error && error.message.includes("401")) || + (error instanceof Error && error.message.includes("405")) || + (error instanceof Error && error.message.includes("Unauthorized")) + ); +} diff --git a/extensions/cli/src/slashCommands.info.test.ts b/extensions/cli/src/slashCommands.info.test.ts index e9878803808..ad4cedcafaa 100644 --- a/extensions/cli/src/slashCommands.info.test.ts +++ b/extensions/cli/src/slashCommands.info.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import * as workosModule from "./auth/workos.js"; import { services } from "./services/index.js"; @@ -57,7 +57,7 @@ describe("handleSlashCommands - /info", () => { it("should include version and working directory in output", async () => { // Mock auth as not authenticated - vi.mocked(workosModule.isAuthenticated).mockReturnValue(false); + vi.mocked(workosModule.isAuthenticated).mockResolvedValue(false); // Mock config service const mockConfigState = { @@ -85,7 +85,7 @@ describe("handleSlashCommands - /info", () => { it("should handle authenticated user info", async () => { // Mock auth as authenticated - vi.mocked(workosModule.isAuthenticated).mockReturnValue(true); + vi.mocked(workosModule.isAuthenticated).mockResolvedValue(true); vi.mocked(workosModule.loadAuthConfig).mockReturnValue({ userEmail: "test@example.com", userId: "test-user", @@ -112,7 +112,7 @@ describe("handleSlashCommands - /info", () => { it("should handle missing model info gracefully", async () => { // Mock auth as not authenticated - vi.mocked(workosModule.isAuthenticated).mockReturnValue(false); + vi.mocked(workosModule.isAuthenticated).mockResolvedValue(false); // Mock config service with no model info const mockConfigState = { @@ -129,7 +129,7 @@ describe("handleSlashCommands - /info", () => { it("should handle config service error", async () => { // Mock auth as not authenticated - vi.mocked(workosModule.isAuthenticated).mockReturnValue(false); + vi.mocked(workosModule.isAuthenticated).mockResolvedValue(false); // Mock config service throwing error vi.mocked(services.config.getState).mockImplementation(() => { diff --git a/extensions/cli/src/slashCommands.test.ts b/extensions/cli/src/slashCommands.test.ts index 30df8e832ba..fee0939dce3 100644 --- a/extensions/cli/src/slashCommands.test.ts +++ b/extensions/cli/src/slashCommands.test.ts @@ -1,14 +1,15 @@ import type { AssistantUnrolled } from "@continuedev/config-yaml"; import { + beforeEach, describe, - it, expect, + it, vi, - beforeEach, type MockedFunction, } from "vitest"; -import type { AuthConfig, AuthenticatedConfig } from "./auth/workos.js"; +import { AuthenticatedConfig } from "./auth/workos-types.js"; +import type { AuthConfig } from "./auth/workos.js"; import type { ConfigServiceState } from "./services/types.js"; import { handleSlashCommands } from "./slashCommands.js"; @@ -128,7 +129,7 @@ describe("slashCommands", () => { ( isAuthenticated as MockedFunction - ).mockReturnValue(false); + ).mockResolvedValue(false); ( services.config.getState as MockedFunction< typeof services.config.getState @@ -157,7 +158,7 @@ describe("slashCommands", () => { ( isAuthenticated as MockedFunction - ).mockReturnValue(true); + ).mockResolvedValue(true); (loadAuthConfig as MockedFunction).mockReturnValue( {} as AuthConfig, ); @@ -199,7 +200,7 @@ describe("slashCommands", () => { ( isAuthenticated as MockedFunction - ).mockReturnValue(true); + ).mockResolvedValue(true); (loadAuthConfig as MockedFunction).mockReturnValue( mockAuthConfig, ); @@ -230,7 +231,7 @@ describe("slashCommands", () => { ( isAuthenticated as MockedFunction - ).mockReturnValue(false); + ).mockResolvedValue(false); ( services.config.getState as MockedFunction< typeof services.config.getState @@ -263,7 +264,7 @@ describe("slashCommands", () => { ( isAuthenticated as MockedFunction - ).mockReturnValue(false); + ).mockResolvedValue(false); ( services.config.getState as MockedFunction< typeof services.config.getState diff --git a/extensions/cli/src/slashCommands.ts b/extensions/cli/src/slashCommands.ts index 503d139014c..c28390d0cec 100644 --- a/extensions/cli/src/slashCommands.ts +++ b/extensions/cli/src/slashCommands.ts @@ -93,9 +93,10 @@ async function handleLogout() { } } -function handleWhoami() { - if (isAuthenticated()) { - const config = loadAuthConfig(); +async function handleWhoami() { + const authed = await isAuthenticated(); + if (authed) { + const config = loadAuthConfig(); // TODO duplicate auth config loading if (config && isAuthenticatedConfig(config)) { return { exit: false, diff --git a/packages/config-yaml/src/markdown/agentFiles.test.ts b/packages/config-yaml/src/markdown/agentFiles.test.ts index 5f1a951af1c..20763ff0e49 100644 --- a/packages/config-yaml/src/markdown/agentFiles.test.ts +++ b/packages/config-yaml/src/markdown/agentFiles.test.ts @@ -486,6 +486,258 @@ describe("parseAgentFileTools", () => { }); }); + describe("URL-based MCP references", () => { + it("should parse HTTPS URL without tool name", () => { + const result = parseAgentFileTools("https://mcp.example.com"); + expect(result).toEqual({ + tools: [{ mcpServer: "https://mcp.example.com" }], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse HTTP URL without tool name", () => { + const result = parseAgentFileTools("http://mcp.example.com"); + expect(result).toEqual({ + tools: [{ mcpServer: "http://mcp.example.com" }], + mcpServers: ["http://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse HTTPS URL with port", () => { + const result = parseAgentFileTools("https://mcp.example.com:8080"); + expect(result).toEqual({ + tools: [{ mcpServer: "https://mcp.example.com:8080" }], + mcpServers: ["https://mcp.example.com:8080"], + allBuiltIn: false, + }); + }); + + it("should parse HTTPS URL with tool name", () => { + const result = parseAgentFileTools("https://mcp.example.com:tool_name"); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com", toolName: "tool_name" }, + ], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse HTTP URL with tool name", () => { + const result = parseAgentFileTools("http://mcp.example.com:my_tool"); + expect(result).toEqual({ + tools: [{ mcpServer: "http://mcp.example.com", toolName: "my_tool" }], + mcpServers: ["http://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse HTTPS URL with port and tool name", () => { + const result = parseAgentFileTools( + "https://mcp.example.com:8080:tool_name", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com:8080", toolName: "tool_name" }, + ], + mcpServers: ["https://mcp.example.com:8080"], + allBuiltIn: false, + }); + }); + + it("should parse URL with path", () => { + const result = parseAgentFileTools("https://api.example.com/mcp"); + expect(result).toEqual({ + tools: [{ mcpServer: "https://api.example.com/mcp" }], + mcpServers: ["https://api.example.com/mcp"], + allBuiltIn: false, + }); + }); + + it("should parse URL with path and tool name", () => { + const result = parseAgentFileTools( + "https://api.example.com/mcp:tool_name", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://api.example.com/mcp", toolName: "tool_name" }, + ], + mcpServers: ["https://api.example.com/mcp"], + allBuiltIn: false, + }); + }); + + it("should parse URL with query parameters", () => { + const result = parseAgentFileTools( + "https://api.example.com/mcp?key=value", + ); + expect(result).toEqual({ + tools: [{ mcpServer: "https://api.example.com/mcp?key=value" }], + mcpServers: ["https://api.example.com/mcp?key=value"], + allBuiltIn: false, + }); + }); + + it("should parse URL with query parameters and tool name", () => { + const result = parseAgentFileTools( + "https://api.example.com/mcp?key=value:tool_name", + ); + expect(result).toEqual({ + tools: [ + { + mcpServer: "https://api.example.com/mcp?key=value", + toolName: "tool_name", + }, + ], + mcpServers: ["https://api.example.com/mcp?key=value"], + allBuiltIn: false, + }); + }); + + it("should parse tool names with hyphens", () => { + const result = parseAgentFileTools("https://mcp.example.com:read-alerts"); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com", toolName: "read-alerts" }, + ], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse tool names with mixed alphanumeric and special chars", () => { + const result = parseAgentFileTools( + "https://mcp.example.com:tool_name-123", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com", toolName: "tool_name-123" }, + ], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse multiple URL-based MCP references", () => { + const result = parseAgentFileTools( + "https://mcp1.example.com:tool1, https://mcp2.example.com:tool2", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp1.example.com", toolName: "tool1" }, + { mcpServer: "https://mcp2.example.com", toolName: "tool2" }, + ], + mcpServers: ["https://mcp1.example.com", "https://mcp2.example.com"], + allBuiltIn: false, + }); + }); + + it("should parse mixed URL and slug-based MCP references", () => { + const result = parseAgentFileTools( + "https://mcp.example.com:tool1, owner/package:tool2, bash", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com", toolName: "tool1" }, + { mcpServer: "owner/package", toolName: "tool2" }, + { toolName: "bash" }, + ], + mcpServers: ["https://mcp.example.com", "owner/package"], + allBuiltIn: false, + }); + }); + + it("should deduplicate URL-based MCP servers", () => { + const result = parseAgentFileTools( + "https://mcp.example.com:tool1, https://mcp.example.com:tool2, https://mcp.example.com", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://mcp.example.com", toolName: "tool1" }, + { mcpServer: "https://mcp.example.com", toolName: "tool2" }, + { mcpServer: "https://mcp.example.com" }, + ], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: false, + }); + }); + + it("should reject URL with whitespace in tool name", () => { + expect(() => + parseAgentFileTools("https://mcp.example.com:tool name"), + ).toThrow( + 'Invalid URL-based MCP tool reference "https://mcp.example.com:tool name": the part after the last colon must be either a port number or a valid tool name', + ); + }); + + it("should reject URL with invalid characters after colon", () => { + expect(() => + parseAgentFileTools("https://mcp.example.com:invalid@tool"), + ).toThrow( + 'Invalid URL-based MCP tool reference "https://mcp.example.com:invalid@tool": the part after the last colon must be either a port number or a valid tool name', + ); + }); + + it("should reject URL with whitespace before colon", () => { + expect(() => + parseAgentFileTools("https://mcp.example.com :tool"), + ).toThrow( + 'Invalid MCP tool reference "https://mcp.example.com :tool": colon-separated tool references cannot contain whitespace', + ); + }); + + it("should handle URL-based MCP with built_in keyword", () => { + const result = parseAgentFileTools( + "built_in, https://mcp.example.com:tool1", + ); + expect(result).toEqual({ + tools: [{ mcpServer: "https://mcp.example.com", toolName: "tool1" }], + mcpServers: ["https://mcp.example.com"], + allBuiltIn: true, + }); + }); + + it("should parse localhost URLs", () => { + const result = parseAgentFileTools("http://localhost:3000"); + expect(result).toEqual({ + tools: [{ mcpServer: "http://localhost:3000" }], + mcpServers: ["http://localhost:3000"], + allBuiltIn: false, + }); + }); + + it("should parse localhost URLs with tool name", () => { + const result = parseAgentFileTools("http://localhost:3000:my_tool"); + expect(result).toEqual({ + tools: [{ mcpServer: "http://localhost:3000", toolName: "my_tool" }], + mcpServers: ["http://localhost:3000"], + allBuiltIn: false, + }); + }); + + it("should parse IP address URLs", () => { + const result = parseAgentFileTools("https://192.168.1.1:8080"); + expect(result).toEqual({ + tools: [{ mcpServer: "https://192.168.1.1:8080" }], + mcpServers: ["https://192.168.1.1:8080"], + allBuiltIn: false, + }); + }); + + it("should parse IP address URLs with tool name", () => { + const result = parseAgentFileTools("https://192.168.1.1:8080:tool_name"); + expect(result).toEqual({ + tools: [ + { mcpServer: "https://192.168.1.1:8080", toolName: "tool_name" }, + ], + mcpServers: ["https://192.168.1.1:8080"], + allBuiltIn: false, + }); + }); + }); + describe("whitespace validation in colon-separated MCP tool references", () => { it("should reject MCP tool reference with space after colon", () => { expect(() => parseAgentFileTools("owner/slug: tool")).toThrow( @@ -601,4 +853,53 @@ describe("parseAgentFileTools", () => { expect(() => parseAgentFileTools("my tool")).not.toThrow(); }); }); + + describe("edge cases", () => { + it("should handle empty tool names correctly", () => { + const result = parseAgentFileTools("owner/package:,https://example.com:"); + expect(result.tools).toEqual([ + { mcpServer: "owner/package", toolName: "" }, + { mcpServer: "https://example.com", toolName: "" }, + ]); + }); + + it("should handle trailing commas", () => { + const result = parseAgentFileTools("bash, owner/package,"); + expect(result.tools).toEqual([ + { toolName: "bash" }, + { mcpServer: "owner/package" }, + ]); + }); + + it("should handle multiple consecutive commas", () => { + const result = parseAgentFileTools("bash,,owner/package"); + expect(result.tools).toEqual([ + { toolName: "bash" }, + { mcpServer: "owner/package" }, + ]); + }); + + it("should parse complex real-world example", () => { + const result = parseAgentFileTools( + "built_in, linear/mcp, sentry/mcp:read-alerts, https://custom-mcp.internal.com:8443:custom_tool, bash", + ); + expect(result).toEqual({ + tools: [ + { mcpServer: "linear/mcp" }, + { mcpServer: "sentry/mcp", toolName: "read-alerts" }, + { + mcpServer: "https://custom-mcp.internal.com:8443", + toolName: "custom_tool", + }, + { toolName: "bash" }, + ], + mcpServers: [ + "linear/mcp", + "sentry/mcp", + "https://custom-mcp.internal.com:8443", + ], + allBuiltIn: true, + }); + }); + }); }); diff --git a/packages/config-yaml/src/markdown/agentFiles.ts b/packages/config-yaml/src/markdown/agentFiles.ts index 4450cda0c7b..e5e0cf5041f 100644 --- a/packages/config-yaml/src/markdown/agentFiles.ts +++ b/packages/config-yaml/src/markdown/agentFiles.ts @@ -23,7 +23,7 @@ export type AgentFile = z.infer; * Parsed agent tool reference */ export interface AgentToolReference { - /** MCP server slug (owner/package) if this is an MCP tool */ + /** MCP server slug (owner/package) or URL (https://...) if this is an MCP tool */ mcpServer?: string; /** Specific tool name - either MCP tool name or built-in tool name */ toolName?: string; @@ -91,6 +91,8 @@ export function serializeAgentFile(agentFile: AgentFile): string { * Supports formats: * - owner/package - all tools from MCP server * - owner/package:tool_name - specific tool from MCP server + * - https://mcp.url.com or http://mcp.url.com - all tools from URL-based MCP server + * - https://mcp.url.com:tool_name - specific tool from URL-based MCP server * - ToolName or tool_name - built-in tool * - built_in - all built-in tools * @@ -115,6 +117,53 @@ export function parseAgentFileTools(toolsString?: string): ParsedAgentTools { if (toolRef === "built_in") { // Special keyword for all built-in tools allBuiltIn = true; + } else if ( + toolRef.startsWith("http://") || + toolRef.startsWith("https://") + ) { + // URL-based MCP tool reference: "https://mcp.url.com" or "https://mcp.url.com:tool_name" + const protocolEndIndex = toolRef.indexOf("://") + 3; + const lastColonIndex = toolRef.lastIndexOf(":"); + + // Check if there's a colon after the protocol + if (lastColonIndex > protocolEndIndex) { + const afterLastColon = toolRef.substring(lastColonIndex + 1); + + // Check if it's a port number (only digits), empty string, or a tool name + if (/^\d+(?:$|[/?#])/.test(afterLastColon)) { + // It's a port number, treat the whole thing as the server + const mcpServer = toolRef; + tools.push({ mcpServer }); + mcpServerSet.add(mcpServer); + } else if ( + afterLastColon === "" || + /^[a-zA-Z0-9_-]+$/.test(afterLastColon) + ) { + // It's a tool name (or empty string) + // Reject references with whitespace to prevent silent misconfigurations + if (/\s/.test(toolRef)) { + throw new Error( + `Invalid MCP tool reference "${toolRef}": colon-separated tool references cannot contain whitespace. ` + + `Use format "https://server:tool_name" without spaces.`, + ); + } + + const mcpServer = toolRef.substring(0, lastColonIndex); + const toolName = afterLastColon; + + tools.push({ mcpServer, toolName }); + mcpServerSet.add(mcpServer); + } else { + throw new Error( + `Invalid URL-based MCP tool reference "${toolRef}": the part after the last colon must be either a port number or a valid tool name (alphanumeric, underscores, hyphens).`, + ); + } + } else { + // No colon after the protocol, treat the whole thing as the server + const mcpServer = toolRef; + tools.push({ mcpServer }); + mcpServerSet.add(mcpServer); + } } else if (toolRef.includes("/")) { // MCP tool reference: "owner/package" or "owner/package:tool_name" const colonIndex = toolRef.indexOf(":");