Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 83 additions & 20 deletions pkg/vmcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,15 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli
return nil, fmt.Errorf("failed to start client connection: %w", err)
}

// Initialize the MCP connection
if err := initializeClient(ctx, c); err != nil {
_ = c.Close()
return nil, fmt.Errorf("failed to initialize MCP connection: %w", err)
}

// Note: Initialization is deferred to the caller (e.g., ListCapabilities)
// so that ServerCapabilities can be captured and used for conditional querying
return c, nil
}

// initializeClient performs MCP protocol initialization handshake.
func initializeClient(ctx context.Context, c *client.Client) error {
_, err := c.Initialize(ctx, mcp.InitializeRequest{
// initializeClient performs MCP protocol initialization handshake and returns server capabilities.
// This allows the caller to determine which optional features the server supports.
func initializeClient(ctx context.Context, c *client.Client) (*mcp.ServerCapabilities, error) {
result, err := c.Initialize(ctx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Expand All @@ -146,37 +143,88 @@ func initializeClient(ctx context.Context, c *client.Client) error {
},
},
})
return err
if err != nil {
return nil, err
}
return &result.Capabilities, nil
}

// queryTools queries tools from a backend if the server advertises tool support.
func queryTools(ctx context.Context, c *client.Client, supported bool, backendID string) (*mcp.ListToolsResult, error) {
if supported {
result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list tools from backend %s: %w", backendID, err)
}
return result, nil
}
logger.Debugf("Backend %s does not advertise tools capability, skipping tools query", backendID)
return &mcp.ListToolsResult{Tools: []mcp.Tool{}}, nil
}

// queryResources queries resources from a backend if the server advertises resource support.
func queryResources(ctx context.Context, c *client.Client, supported bool, backendID string) (*mcp.ListResourcesResult, error) {
if supported {
result, err := c.ListResources(ctx, mcp.ListResourcesRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list resources from backend %s: %w", backendID, err)
}
return result, nil
}
logger.Debugf("Backend %s does not advertise resources capability, skipping resources query", backendID)
return &mcp.ListResourcesResult{Resources: []mcp.Resource{}}, nil
}

// queryPrompts queries prompts from a backend if the server advertises prompt support.
func queryPrompts(ctx context.Context, c *client.Client, supported bool, backendID string) (*mcp.ListPromptsResult, error) {
if supported {
result, err := c.ListPrompts(ctx, mcp.ListPromptsRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list prompts from backend %s: %w", backendID, err)
}
return result, nil
}
logger.Debugf("Backend %s does not advertise prompts capability, skipping prompts query", backendID)
return &mcp.ListPromptsResult{Prompts: []mcp.Prompt{}}, nil
}

// ListCapabilities queries a backend for its MCP capabilities.
// Returns tools, resources, and prompts exposed by the backend.
// Only queries capabilities that the server advertises during initialization.
func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
logger.Debugf("Querying capabilities from backend %s (%s)", target.WorkloadName, target.BaseURL)

// Create a client for this backend
// Create a client for this backend (not yet initialized)
c, err := h.clientFactory(ctx, target)
if err != nil {
return nil, fmt.Errorf("failed to create client for backend %s: %w", target.WorkloadID, err)
}
defer c.Close()

// Query tools
toolsResp, err := c.ListTools(ctx, mcp.ListToolsRequest{})
// Initialize the client and get server capabilities
serverCaps, err := initializeClient(ctx, c)
if err != nil {
return nil, fmt.Errorf("failed to list tools from backend %s: %w", target.WorkloadID, err)
return nil, fmt.Errorf("failed to initialize client for backend %s: %w", target.WorkloadID, err)
}

// Query resources
resourcesResp, err := c.ListResources(ctx, mcp.ListResourcesRequest{})
logger.Debugf("Backend %s capabilities: tools=%v, resources=%v, prompts=%v",
target.WorkloadID, serverCaps.Tools != nil, serverCaps.Resources != nil, serverCaps.Prompts != nil)

// Query each capability type based on server advertisement
// Check for nil BEFORE passing to functions to avoid interface{} nil pointer issues
toolsResp, err := queryTools(ctx, c, serverCaps.Tools != nil, target.WorkloadID)
if err != nil {
return nil, fmt.Errorf("failed to list resources from backend %s: %w", target.WorkloadID, err)
return nil, err
}

// Query prompts
promptsResp, err := c.ListPrompts(ctx, mcp.ListPromptsRequest{})
resourcesResp, err := queryResources(ctx, c, serverCaps.Resources != nil, target.WorkloadID)
if err != nil {
return nil, fmt.Errorf("failed to list prompts from backend %s: %w", target.WorkloadID, err)
return nil, err
}

promptsResp, err := queryPrompts(ctx, c, serverCaps.Prompts != nil, target.WorkloadID)
if err != nil {
return nil, err
}

// Convert MCP types to vmcp types
Expand Down Expand Up @@ -266,6 +314,11 @@ func (h *httpBackendClient) CallTool(
}
defer c.Close()

// Initialize the client
if _, err := initializeClient(ctx, c); err != nil {
return nil, fmt.Errorf("failed to initialize client for backend %s: %w", target.WorkloadID, err)
}

// Call the tool
result, err := c.CallTool(ctx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
Expand Down Expand Up @@ -337,6 +390,11 @@ func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.Backe
}
defer c.Close()

// Initialize the client
if _, err := initializeClient(ctx, c); err != nil {
return nil, fmt.Errorf("failed to initialize client for backend %s: %w", target.WorkloadID, err)
}

// Read the resource
result, err := c.ReadResource(ctx, mcp.ReadResourceRequest{
Params: mcp.ReadResourceParams{
Expand Down Expand Up @@ -387,6 +445,11 @@ func (h *httpBackendClient) GetPrompt(
}
defer c.Close()

// Initialize the client
if _, err := initializeClient(ctx, c); err != nil {
return "", fmt.Errorf("failed to initialize client for backend %s: %w", target.WorkloadID, err)
}

// Get the prompt
// Convert map[string]any to map[string]string
stringArgs := make(map[string]string)
Expand Down
31 changes: 31 additions & 0 deletions pkg/vmcp/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,37 @@ func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) {
})
}

func TestQueryHelpers_PartialCapabilities(t *testing.T) {
t.Parallel()

t.Run("queryTools with unsupported capability returns empty slice", func(t *testing.T) {
t.Parallel()

result, err := queryTools(context.Background(), nil, false, "test-backend")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result.Tools)
})

t.Run("queryResources with unsupported capability returns empty slice", func(t *testing.T) {
t.Parallel()

result, err := queryResources(context.Background(), nil, false, "test-backend")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result.Resources)
})

t.Run("queryPrompts with unsupported capability returns empty slice", func(t *testing.T) {
t.Parallel()

result, err := queryPrompts(context.Background(), nil, false, "test-backend")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result.Prompts)
})
}

func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) {
t.Parallel()

Expand Down
Loading