diff --git a/.codespellrc b/.codespellrc index e83793750..26b1e785d 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = NotIn,notin,AfterAll,ND,aks,deriver +ignore-words-list = NotIn,notin,AfterAll,ND,aks,deriver,te skip = *.svg,*.mod,*.sum diff --git a/pkg/vmcp/composer/security_test.go b/pkg/vmcp/composer/security_test.go new file mode 100644 index 000000000..69c550b76 --- /dev/null +++ b/pkg/vmcp/composer/security_test.go @@ -0,0 +1,226 @@ +package composer + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// TestTemplateExpander_DepthLimit tests protection against deeply nested structures. +func TestTemplateExpander_DepthLimit(t *testing.T) { + t.Parallel() + + // Create deeply nested structure exceeding maxTemplateDepth + deeplyNested := make(map[string]any) + current := deeplyNested + for i := 0; i < 150; i++ { + nested := make(map[string]any) + current["nested"] = nested + current = nested + } + current["value"] = "{{.params.test}}" + + expander := NewTemplateExpander() + _, err := expander.Expand(context.Background(), map[string]any{"deep": deeplyNested}, newWorkflowContext(map[string]any{"test": "value"})) + + require.Error(t, err) + assert.Contains(t, err.Error(), "depth limit exceeded") +} + +// TestTemplateExpander_OutputSizeLimit tests protection against large outputs. +func TestTemplateExpander_OutputSizeLimit(t *testing.T) { + t.Parallel() + + largeString := strings.Repeat("A", 11*1024*1024) // 11 MB (exceeds 10 MB limit) + expander := NewTemplateExpander() + + _, err := expander.Expand(context.Background(), + map[string]any{"output": "{{.params.large}}"}, + newWorkflowContext(map[string]any{"large": largeString})) + + require.Error(t, err) + assert.Contains(t, err.Error(), "template output too large") +} + +// TestWorkflowEngine_MaxStepsValidation tests protection against excessive steps. +func TestWorkflowEngine_MaxStepsValidation(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + steps := make([]WorkflowStep, 150) // Exceeds maxWorkflowSteps (100) + for i := range steps { + steps[i] = toolStep(fmt.Sprintf("s%d", i), "test.tool", nil) + } + + err := te.Engine.ValidateWorkflow(context.Background(), &WorkflowDefinition{Name: "test", Steps: steps}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "too many steps") +} + +// TestWorkflowEngine_RetryCountCapping tests that retries are capped at maximum. +func TestWorkflowEngine_RetryCountCapping(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := &WorkflowDefinition{ + Name: "retry-test", + Steps: []WorkflowStep{{ + ID: "flaky", + Type: StepTypeTool, + Tool: "test.tool", + OnError: &ErrorHandler{ + Action: "retry", + RetryCount: 1000, // Should be capped at maxRetryCount (10) + RetryDelay: 1 * time.Millisecond, + }, + }}, + Timeout: 5 * time.Second, + } + + target := &vmcp.BackendTarget{WorkloadID: "test", BaseURL: "http://test:8080"} + te.Router.EXPECT().RouteTool(gomock.Any(), "test.tool").Return(target, nil) + + callCount := 0 + te.Backend.EXPECT().CallTool(gomock.Any(), target, "test.tool", gomock.Any()). + DoAndReturn(func(context.Context, *vmcp.BackendTarget, string, map[string]any) (map[string]any, error) { + callCount++ + return nil, fmt.Errorf("fail") + }).MaxTimes(12) // 1 initial + 10 retries max + + result, err := execute(t, te.Engine, def, nil) + + require.Error(t, err) + assert.Equal(t, maxRetryCount, callCount-1) + assert.LessOrEqual(t, result.Steps["flaky"].RetryCount, maxRetryCount) +} + +// TestTemplateExpander_NoCodeExecution tests that templates cannot execute code. +func TestTemplateExpander_NoCodeExecution(t *testing.T) { + t.Parallel() + + malicious := []string{ + "{{exec \"rm -rf /\"}}", + "{{system \"whoami\"}}", + "{{eval \"code\"}}", + "{{import \"os\"}}", + "{{.Execute \"danger\"}}", + } + + expander := NewTemplateExpander() + ctx := newWorkflowContext(map[string]any{"test": "value"}) + + for _, tmpl := range malicious { + t.Run(tmpl, func(t *testing.T) { + t.Parallel() + _, err := expander.Expand(context.Background(), map[string]any{"attempt": tmpl}, ctx) + require.Error(t, err, "malicious template should fail safely") + }) + } +} + +// TestWorkflowEngine_CircularDependencyDetection verifies cycle detection. +func TestWorkflowEngine_CircularDependencyDetection(t *testing.T) { + t.Parallel() + + cycles := []struct { + name string + steps []WorkflowStep + }{ + {"A->B->A", []WorkflowStep{ + toolStepWithDeps("A", "t1", nil, []string{"B"}), + toolStepWithDeps("B", "t2", nil, []string{"A"})}}, + {"A->B->C->A", []WorkflowStep{ + toolStepWithDeps("A", "t1", nil, []string{"C"}), + toolStepWithDeps("B", "t2", nil, []string{"A"}), + toolStepWithDeps("C", "t3", nil, []string{"B"})}}, + {"A->A", []WorkflowStep{toolStepWithDeps("A", "t1", nil, []string{"A"})}}, + } + + te := newTestEngine(t) + + for _, tc := range cycles { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := te.Engine.ValidateWorkflow(context.Background(), simpleWorkflow("test", tc.steps...)) + require.Error(t, err) + assert.Contains(t, err.Error(), "circular dependency") + }) + } +} + +// TestWorkflowContext_ConcurrentAccess tests thread-safety. +func TestWorkflowContext_ConcurrentAccess(t *testing.T) { + t.Parallel() + + mgr := newWorkflowContextManager() + done := make(chan bool, 10) + + for i := 0; i < 10; i++ { + go func(id int) { + ctx := mgr.CreateContext(map[string]any{"id": id}) + time.Sleep(time.Millisecond) + retrieved, err := mgr.GetContext(ctx.WorkflowID) + assert.NoError(t, err) + assert.Equal(t, ctx.WorkflowID, retrieved.WorkflowID) + mgr.DeleteContext(ctx.WorkflowID) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } +} + +// TestTemplateExpander_SafeFunctions verifies only safe functions are available. +func TestTemplateExpander_SafeFunctions(t *testing.T) { + t.Parallel() + + safe := map[string]string{ + "json": `{{json .params.obj}}`, + "quote": `{{quote .params.str}}`, + } + + expander := NewTemplateExpander() + ctx := newWorkflowContext(map[string]any{"obj": map[string]any{"k": "v"}, "str": "test"}) + + for name, tmpl := range safe { + t.Run(name, func(t *testing.T) { + t.Parallel() + result, err := expander.Expand(context.Background(), map[string]any{"data": tmpl}, ctx) + require.NoError(t, err) + assert.NotNil(t, result) + }) + } +} + +// TestWorkflowEngine_NoSensitiveDataInErrors tests error sanitization. +func TestWorkflowEngine_NoSensitiveDataInErrors(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := simpleWorkflow("auth", toolStep("login", "auth.login", map[string]any{ + "username": "{{.params.username}}", + "password": "{{.params.password}}", + })) + + te.expectToolCallWithAnyArgsAndError("auth.login", fmt.Errorf("auth failed")) + + _, err := execute(t, te.Engine, def, map[string]any{ + "username": "admin", + "password": "supersecret123", + }) + + require.Error(t, err) + assert.NotContains(t, err.Error(), "supersecret123") + assert.NotContains(t, err.Error(), "password") +} diff --git a/pkg/vmcp/composer/template_expander.go b/pkg/vmcp/composer/template_expander.go new file mode 100644 index 000000000..aaefe2f61 --- /dev/null +++ b/pkg/vmcp/composer/template_expander.go @@ -0,0 +1,221 @@ +// Package composer provides composite tool workflow execution for Virtual MCP Server. +package composer + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "text/template" +) + +const ( + // maxTemplateDepth is the maximum recursion depth for template expansion. + // This prevents stack overflow from deeply nested objects. + maxTemplateDepth = 100 + + // maxTemplateOutputSize is the maximum size in bytes for template expansion output. + // This prevents memory exhaustion from maliciously large template outputs. + maxTemplateOutputSize = 10 * 1024 * 1024 // 10 MB +) + +// defaultTemplateExpander implements TemplateExpander using Go's text/template. +type defaultTemplateExpander struct { + // funcMap provides custom template functions. + funcMap template.FuncMap +} + +// NewTemplateExpander creates a new template expander. +func NewTemplateExpander() TemplateExpander { + return &defaultTemplateExpander{ + funcMap: template.FuncMap{ + "json": jsonEncode, + "quote": func(s string) string { + return fmt.Sprintf("%q", s) + }, + }, + } +} + +// Expand evaluates templates in the given data using the workflow context. +// It recursively processes all string values and expands templates. +func (e *defaultTemplateExpander) Expand( + ctx context.Context, + data map[string]any, + workflowCtx *WorkflowContext, +) (map[string]any, error) { + if data == nil { + return nil, nil + } + + result := make(map[string]any, len(data)) + for key, value := range data { + expanded, err := e.expandValue(ctx, value, workflowCtx) + if err != nil { + return nil, fmt.Errorf("failed to expand value for key %q: %w", key, err) + } + result[key] = expanded + } + + return result, nil +} + +// expandValue recursively expands templates in a value. +func (e *defaultTemplateExpander) expandValue( + ctx context.Context, + value any, + workflowCtx *WorkflowContext, +) (any, error) { + return e.expandValueWithDepth(ctx, value, workflowCtx, 0) +} + +// expandValueWithDepth recursively expands templates with depth tracking. +func (e *defaultTemplateExpander) expandValueWithDepth( + ctx context.Context, + value any, + workflowCtx *WorkflowContext, + depth int, +) (any, error) { + // Check context cancellation before proceeding + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("context cancelled during template expansion: %w", err) + } + + // Prevent stack overflow from deeply nested templates + if depth > maxTemplateDepth { + return nil, fmt.Errorf("template expansion depth limit exceeded: %d", maxTemplateDepth) + } + switch v := value.(type) { + case string: + // Expand template string + return e.expandString(ctx, v, workflowCtx) + + case map[string]any: + // Recursively expand nested maps + expanded := make(map[string]any, len(v)) + for key, val := range v { + expandedVal, err := e.expandValueWithDepth(ctx, val, workflowCtx, depth+1) + if err != nil { + return nil, fmt.Errorf("failed to expand nested key %q: %w", key, err) + } + expanded[key] = expandedVal + } + return expanded, nil + + case []any: + // Recursively expand arrays + expanded := make([]any, len(v)) + for i, val := range v { + expandedVal, err := e.expandValueWithDepth(ctx, val, workflowCtx, depth+1) + if err != nil { + return nil, fmt.Errorf("failed to expand array element %d: %w", i, err) + } + expanded[i] = expandedVal + } + return expanded, nil + + default: + // Return other types unchanged (numbers, booleans, nil) + return value, nil + } +} + +// expandString expands a single template string. +func (e *defaultTemplateExpander) expandString( + ctx context.Context, + tmplStr string, + workflowCtx *WorkflowContext, +) (string, error) { + // Check context cancellation before expensive template operations + if err := ctx.Err(); err != nil { + return "", fmt.Errorf("context cancelled before template expansion: %w", err) + } + + // Create template context with params and steps + tmplCtx := map[string]any{ + "params": workflowCtx.Params, + "steps": e.buildStepsContext(workflowCtx), + "vars": workflowCtx.Variables, + } + + // Parse and execute template + tmpl, err := template.New("expand").Funcs(e.funcMap).Parse(tmplStr) + if err != nil { + return "", fmt.Errorf("failed to parse template: %w", err) + } + + var buf bytes.Buffer + // Pre-allocate reasonable buffer size to reduce allocations + buf.Grow(1024) + + if err := tmpl.Execute(&buf, tmplCtx); err != nil { + return "", fmt.Errorf("failed to execute template: %w", err) + } + + // Enforce output size limit to prevent memory exhaustion + if buf.Len() > maxTemplateOutputSize { + return "", fmt.Errorf("template output too large: %d bytes (max %d)", + buf.Len(), maxTemplateOutputSize) + } + + return buf.String(), nil +} + +// buildStepsContext converts StepResult map to a template-friendly structure. +// This provides access to step outputs via {{.steps.stepid.output.field}}. +func (*defaultTemplateExpander) buildStepsContext(workflowCtx *WorkflowContext) map[string]any { + stepsCtx := make(map[string]any, len(workflowCtx.Steps)) + + for stepID, result := range workflowCtx.Steps { + stepData := map[string]any{ + "status": string(result.Status), + "output": result.Output, + } + + // Add error information if step failed + if result.Error != nil { + stepData["error"] = result.Error.Error() + } + + stepsCtx[stepID] = stepData + } + + return stepsCtx +} + +// EvaluateCondition evaluates a condition template to a boolean. +// The condition string must evaluate to "true" or "false". +func (e *defaultTemplateExpander) EvaluateCondition( + ctx context.Context, + condition string, + workflowCtx *WorkflowContext, +) (bool, error) { + if condition == "" { + return true, nil + } + + // Expand the condition as a template + result, err := e.expandString(ctx, condition, workflowCtx) + if err != nil { + return false, fmt.Errorf("failed to evaluate condition: %w", err) + } + + // Parse as boolean + switch result { + case "true", "True", "TRUE": + return true, nil + case "false", "False", "FALSE": + return false, nil + default: + return false, fmt.Errorf("condition must evaluate to 'true' or 'false', got: %q", result) + } +} + +// jsonEncode is a template function that encodes a value as JSON. +func jsonEncode(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("failed to encode JSON: %w", err) + } + return string(b), nil +} diff --git a/pkg/vmcp/composer/template_expander_test.go b/pkg/vmcp/composer/template_expander_test.go new file mode 100644 index 000000000..ebbd055fb --- /dev/null +++ b/pkg/vmcp/composer/template_expander_test.go @@ -0,0 +1,220 @@ +package composer + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTemplateExpander_Expand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data map[string]any + params map[string]any + steps map[string]*StepResult + expected map[string]any + wantErr bool + }{ + { + name: "basic param substitution", + data: map[string]any{"title": "Issue: {{.params.title}}"}, + params: map[string]any{"title": "Test"}, + expected: map[string]any{"title": "Issue: Test"}, + }, + { + name: "step output substitution", + data: map[string]any{"msg": "Created: {{.steps.create.output.url}}"}, + params: map[string]any{}, + steps: map[string]*StepResult{ + "create": {Status: StepStatusCompleted, Output: map[string]any{"url": "http://example.com"}}, + }, + expected: map[string]any{"msg": "Created: http://example.com"}, + }, + { + name: "nested objects", + data: map[string]any{"cfg": map[string]any{"repo": "{{.params.repo}}"}}, + params: map[string]any{"repo": "myrepo"}, + expected: map[string]any{"cfg": map[string]any{"repo": "myrepo"}}, + }, + { + name: "arrays", + data: map[string]any{"files": []any{"{{.params.f1}}", "{{.params.f2}}"}}, + params: map[string]any{"f1": "a.go", "f2": "b.go"}, + expected: map[string]any{"files": []any{"a.go", "b.go"}}, + }, + { + name: "mixed types", + data: map[string]any{"title": "{{.params.title}}", "num": 42, "flag": true}, + params: map[string]any{"title": "Test"}, + expected: map[string]any{"title": "Test", "num": 42, "flag": true}, + }, + { + name: "json function", + data: map[string]any{"payload": `{"data": {{json .params.obj}}}`}, + params: map[string]any{"obj": map[string]any{"key": "value"}}, + expected: map[string]any{"payload": `{"data": {"key":"value"}}`}, + }, + { + name: "invalid template", + data: map[string]any{"bad": "{{.params.missing"}, + params: map[string]any{}, + wantErr: true, + }, + { + name: "missing param uses zero value", + data: map[string]any{"val": "{{.params.nonexistent}}"}, + params: map[string]any{}, + expected: map[string]any{"val": ""}, + }, + } + + expander := NewTemplateExpander() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := newWorkflowContext(tt.params) + if tt.steps != nil { + ctx.Steps = tt.steps + } + + result, err := expander.Expand(context.Background(), tt.data, ctx) + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTemplateExpander_EvaluateCondition(t *testing.T) { + t.Parallel() + + tests := []struct { + condition string + params map[string]any + steps map[string]*StepResult + expected bool + wantErr bool + }{ + {"", nil, nil, true, false}, // empty = true + {"true", nil, nil, true, false}, + {"false", nil, nil, false, false}, + {"True", nil, nil, true, false}, // case insensitive + {"{{if eq .params.enabled true}}true{{else}}false{{end}}", map[string]any{"enabled": true}, nil, true, false}, + {"{{if eq .params.enabled true}}true{{else}}false{{end}}", map[string]any{"enabled": false}, nil, false, false}, + {"{{if eq .steps.s1.status \"completed\"}}true{{else}}false{{end}}", nil, + map[string]*StepResult{"s1": {Status: StepStatusCompleted}}, true, false}, + {"not_boolean", nil, nil, false, true}, + {"{{.params.missing", nil, nil, false, true}, + } + + expander := NewTemplateExpander() + + for _, tt := range tests { + t.Run(tt.condition, func(t *testing.T) { + t.Parallel() + + ctx := newWorkflowContext(tt.params) + if tt.steps != nil { + ctx.Steps = tt.steps + } + + result, err := expander.EvaluateCondition(context.Background(), tt.condition, ctx) + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestWorkflowContext_Lifecycle(t *testing.T) { + t.Parallel() + + ctx := newWorkflowContext(map[string]any{"key": "value"}) + + // Start -> Success + ctx.RecordStepStart("s1") + assert.Equal(t, StepStatusRunning, ctx.Steps["s1"].Status) + + time.Sleep(10 * time.Millisecond) + ctx.RecordStepSuccess("s1", map[string]any{"result": "ok"}) + assert.Equal(t, StepStatusCompleted, ctx.Steps["s1"].Status) + assert.Greater(t, ctx.Steps["s1"].Duration, time.Duration(0)) + + // Start -> Failure + ctx.RecordStepStart("s2") + ctx.RecordStepFailure("s2", assert.AnError) + assert.Equal(t, StepStatusFailed, ctx.Steps["s2"].Status) + assert.True(t, ctx.HasStepFailed("s2")) + + // Skipped + ctx.RecordStepSkipped("s3") + assert.Equal(t, StepStatusSkipped, ctx.Steps["s3"].Status) + + // Check completion status + assert.True(t, ctx.HasStepCompleted("s1")) + assert.False(t, ctx.HasStepCompleted("s2")) + assert.False(t, ctx.HasStepCompleted("s3")) +} + +func TestWorkflowContext_GetLastStepOutput(t *testing.T) { + t.Parallel() + + ctx := newWorkflowContext(nil) + + // No completed steps + assert.Nil(t, ctx.GetLastStepOutput()) + + // Add steps with different completion times + ctx.RecordStepStart("s1") + time.Sleep(5 * time.Millisecond) + ctx.RecordStepSuccess("s1", map[string]any{"order": 1}) + + time.Sleep(5 * time.Millisecond) + ctx.RecordStepStart("s2") + time.Sleep(5 * time.Millisecond) + ctx.RecordStepSuccess("s2", map[string]any{"order": 2}) + + // Should return latest (s2) + output := ctx.GetLastStepOutput() + require.NotNil(t, output) + assert.Equal(t, 2, output["order"]) +} + +func TestWorkflowContext_Clone(t *testing.T) { + t.Parallel() + + original := &WorkflowContext{ + WorkflowID: "test", + Params: map[string]any{"key": "value"}, + Steps: map[string]*StepResult{"s1": {StepID: "s1", Status: StepStatusCompleted}}, + Variables: map[string]any{"var": "val"}, + } + + clone := original.Clone() + + // Verify deep copy + assert.Equal(t, original.WorkflowID, clone.WorkflowID) + assert.Equal(t, original.Params, clone.Params) + + // Modify clone - shouldn't affect original + clone.Params["new"] = "val" + clone.Steps["s2"] = &StepResult{StepID: "s2"} + + assert.NotEqual(t, original.Params, clone.Params) + assert.NotEqual(t, len(original.Steps), len(clone.Steps)) +} diff --git a/pkg/vmcp/composer/testhelpers_test.go b/pkg/vmcp/composer/testhelpers_test.go new file mode 100644 index 000000000..eb16daf76 --- /dev/null +++ b/pkg/vmcp/composer/testhelpers_test.go @@ -0,0 +1,120 @@ +package composer + +import ( + "context" + "testing" + + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" + routermocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" +) + +// testEngine is a test helper that sets up a workflow engine with mocks. +type testEngine struct { + Engine Composer + Router *routermocks.MockRouter + Backend *mocks.MockBackendClient + Ctrl *gomock.Controller +} + +// newTestEngine creates a test engine with mocks. +func newTestEngine(t *testing.T) *testEngine { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routermocks.NewMockRouter(ctrl) + mockBackend := mocks.NewMockBackendClient(ctrl) + engine := NewWorkflowEngine(mockRouter, mockBackend) + + return &testEngine{ + Engine: engine, + Router: mockRouter, + Backend: mockBackend, + Ctrl: ctrl, + } +} + +// expectToolCall is a helper to set up tool call expectations. +func (te *testEngine) expectToolCall(toolName string, args, output map[string]any) { + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "test", + BaseURL: "http://test:8080", + } + te.Router.EXPECT().RouteTool(gomock.Any(), toolName).Return(target, nil) + te.Backend.EXPECT().CallTool(gomock.Any(), target, toolName, args).Return(output, nil) +} + +// expectToolCallWithError is a helper to set up failing tool call expectations. +func (te *testEngine) expectToolCallWithError(toolName string, args map[string]any, err error) { + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + BaseURL: "http://test:8080", + } + te.Router.EXPECT().RouteTool(gomock.Any(), toolName).Return(target, nil) + te.Backend.EXPECT().CallTool(gomock.Any(), target, toolName, args).Return(nil, err) +} + +// expectToolCallWithAnyArgsAndError is a helper for failing calls with any args. +func (te *testEngine) expectToolCallWithAnyArgsAndError(toolName string, err error) { + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + BaseURL: "http://test:8080", + } + te.Router.EXPECT().RouteTool(gomock.Any(), toolName).Return(target, nil) + te.Backend.EXPECT().CallTool(gomock.Any(), target, toolName, gomock.Any()).Return(nil, err) +} + +// expectToolCallWithAnyArgs is a helper for calls where args are dynamically generated. +func (te *testEngine) expectToolCallWithAnyArgs(toolName string, output map[string]any) { + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + BaseURL: "http://test:8080", + } + te.Router.EXPECT().RouteTool(gomock.Any(), toolName).Return(target, nil) + te.Backend.EXPECT().CallTool(gomock.Any(), target, toolName, gomock.Any()).Return(output, nil) +} + +// newWorkflowContext creates a test workflow context. +func newWorkflowContext(params map[string]any) *WorkflowContext { + return &WorkflowContext{ + WorkflowID: "test-workflow", + Params: params, + Steps: make(map[string]*StepResult), + Variables: make(map[string]any), + } +} + +// toolStep creates a simple tool step for testing. +func toolStep(id, tool string, args map[string]any) WorkflowStep { + return WorkflowStep{ + ID: id, + Type: StepTypeTool, + Tool: tool, + Arguments: args, + } +} + +// toolStepWithDeps creates a tool step with dependencies. +func toolStepWithDeps(id, tool string, args map[string]any, deps []string) WorkflowStep { + step := toolStep(id, tool, args) + step.DependsOn = deps + return step +} + +// simpleWorkflow creates a simple workflow for testing. +func simpleWorkflow(name string, steps ...WorkflowStep) *WorkflowDefinition { + return &WorkflowDefinition{ + Name: name, + Steps: steps, + } +} + +// execute is a helper to execute a workflow. +func execute(t *testing.T, engine Composer, def *WorkflowDefinition, params map[string]any) (*WorkflowResult, error) { + t.Helper() + return engine.ExecuteWorkflow(context.Background(), def, params) +} diff --git a/pkg/vmcp/composer/workflow_context.go b/pkg/vmcp/composer/workflow_context.go new file mode 100644 index 000000000..4ac072d6c --- /dev/null +++ b/pkg/vmcp/composer/workflow_context.go @@ -0,0 +1,174 @@ +// Package composer provides composite tool workflow execution for Virtual MCP Server. +package composer + +import ( + "fmt" + "sync" + "time" + + "github.com/google/uuid" +) + +// workflowContextManager manages workflow execution contexts. +type workflowContextManager struct { + mu sync.RWMutex + contexts map[string]*WorkflowContext +} + +// newWorkflowContextManager creates a new context manager. +func newWorkflowContextManager() *workflowContextManager { + return &workflowContextManager{ + contexts: make(map[string]*WorkflowContext), + } +} + +// CreateContext creates a new workflow context with a unique ID. +func (m *workflowContextManager) CreateContext(params map[string]any) *WorkflowContext { + m.mu.Lock() + defer m.mu.Unlock() + + ctx := &WorkflowContext{ + WorkflowID: uuid.New().String(), + Params: params, + Steps: make(map[string]*StepResult), + Variables: make(map[string]any), + } + + m.contexts[ctx.WorkflowID] = ctx + return ctx +} + +// GetContext retrieves a workflow context by ID. +func (m *workflowContextManager) GetContext(workflowID string) (*WorkflowContext, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + ctx, exists := m.contexts[workflowID] + if !exists { + return nil, fmt.Errorf("workflow context not found: %s", workflowID) + } + + return ctx, nil +} + +// DeleteContext removes a workflow context. +func (m *workflowContextManager) DeleteContext(workflowID string) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.contexts, workflowID) +} + +// RecordStepStart records that a step has started execution. +func (ctx *WorkflowContext) RecordStepStart(stepID string) { + ctx.Steps[stepID] = &StepResult{ + StepID: stepID, + Status: StepStatusRunning, + StartTime: time.Now(), + } +} + +// RecordStepSuccess records a successful step completion. +func (ctx *WorkflowContext) RecordStepSuccess(stepID string, output map[string]any) { + if result, exists := ctx.Steps[stepID]; exists { + result.Status = StepStatusCompleted + result.Output = output + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime) + } +} + +// RecordStepFailure records a step failure. +func (ctx *WorkflowContext) RecordStepFailure(stepID string, err error) { + if result, exists := ctx.Steps[stepID]; exists { + result.Status = StepStatusFailed + result.Error = err + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime) + } +} + +// RecordStepSkipped records that a step was skipped (condition was false). +func (ctx *WorkflowContext) RecordStepSkipped(stepID string) { + ctx.Steps[stepID] = &StepResult{ + StepID: stepID, + Status: StepStatusSkipped, + StartTime: time.Now(), + EndTime: time.Now(), + } +} + +// GetStepResult retrieves a step result by ID. +func (ctx *WorkflowContext) GetStepResult(stepID string) (*StepResult, bool) { + result, exists := ctx.Steps[stepID] + return result, exists +} + +// HasStepCompleted checks if a step has completed successfully. +func (ctx *WorkflowContext) HasStepCompleted(stepID string) bool { + result, exists := ctx.Steps[stepID] + return exists && result.Status == StepStatusCompleted +} + +// HasStepFailed checks if a step has failed. +func (ctx *WorkflowContext) HasStepFailed(stepID string) bool { + result, exists := ctx.Steps[stepID] + return exists && result.Status == StepStatusFailed +} + +// GetLastStepOutput retrieves the output of the most recently completed step. +// This is useful for getting the final workflow output. +func (ctx *WorkflowContext) GetLastStepOutput() map[string]any { + var lastTime time.Time + var lastOutput map[string]any + + for _, result := range ctx.Steps { + if result.Status == StepStatusCompleted && result.EndTime.After(lastTime) { + lastTime = result.EndTime + lastOutput = result.Output + } + } + + return lastOutput +} + +// Clone creates a shallow copy of the workflow context. +// Maps and step results are cloned, but nested values within maps are shared. +// This is useful for testing and validation. +func (ctx *WorkflowContext) Clone() *WorkflowContext { + clone := &WorkflowContext{ + WorkflowID: ctx.WorkflowID, + Params: cloneMap(ctx.Params), + Steps: make(map[string]*StepResult, len(ctx.Steps)), + Variables: cloneMap(ctx.Variables), + } + + // Clone step results + for stepID, result := range ctx.Steps { + clone.Steps[stepID] = &StepResult{ + StepID: result.StepID, + Status: result.Status, + Output: cloneMap(result.Output), + Error: result.Error, + StartTime: result.StartTime, + EndTime: result.EndTime, + Duration: result.Duration, + RetryCount: result.RetryCount, + } + } + + return clone +} + +// cloneMap creates a shallow copy of a map. +func cloneMap(m map[string]any) map[string]any { + if m == nil { + return nil + } + + clone := make(map[string]any, len(m)) + for k, v := range m { + clone[k] = v + } + return clone +} diff --git a/pkg/vmcp/composer/workflow_engine.go b/pkg/vmcp/composer/workflow_engine.go new file mode 100644 index 000000000..7364121e2 --- /dev/null +++ b/pkg/vmcp/composer/workflow_engine.go @@ -0,0 +1,517 @@ +// Package composer provides composite tool workflow execution for Virtual MCP Server. +package composer + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff/v5" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +const ( + // defaultWorkflowTimeout is the default maximum execution time for workflows. + defaultWorkflowTimeout = 30 * time.Minute + + // defaultStepTimeout is the default maximum execution time for individual steps. + defaultStepTimeout = 5 * time.Minute + + // maxWorkflowSteps is the maximum number of steps allowed in a workflow. + // This prevents resource exhaustion from maliciously large workflows. + maxWorkflowSteps = 100 + + // maxRetryCount is the maximum number of retries allowed per step. + // This prevents infinite retry loops from malicious configurations. + maxRetryCount = 10 +) + +// workflowEngine implements Composer interface. +type workflowEngine struct { + // router routes tool calls to backend servers. + router router.Router + + // backendClient makes calls to backend MCP servers. + backendClient vmcp.BackendClient + + // templateExpander handles template expansion. + templateExpander TemplateExpander + + // contextManager manages workflow execution contexts. + contextManager *workflowContextManager +} + +// NewWorkflowEngine creates a new workflow execution engine. +func NewWorkflowEngine( + rtr router.Router, + backendClient vmcp.BackendClient, +) Composer { + return &workflowEngine{ + router: rtr, + backendClient: backendClient, + templateExpander: NewTemplateExpander(), + contextManager: newWorkflowContextManager(), + } +} + +// ExecuteWorkflow executes a composite tool workflow. +func (e *workflowEngine) ExecuteWorkflow( + ctx context.Context, + def *WorkflowDefinition, + params map[string]any, +) (*WorkflowResult, error) { + logger.Infof("Starting workflow execution: %s", def.Name) + + // Create workflow context + workflowCtx := e.contextManager.CreateContext(params) + defer e.contextManager.DeleteContext(workflowCtx.WorkflowID) + + // Apply workflow timeout + timeout := def.Timeout + if timeout == 0 { + timeout = defaultWorkflowTimeout + } + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Create result + result := &WorkflowResult{ + WorkflowID: workflowCtx.WorkflowID, + Status: WorkflowStatusRunning, + Steps: make(map[string]*StepResult), + StartTime: time.Now(), + Metadata: make(map[string]string), + } + + // Execute workflow steps sequentially + for _, step := range def.Steps { + // Check if context was cancelled or timed out + select { + case <-execCtx.Done(): + result.Status = WorkflowStatusTimedOut + result.Error = ErrWorkflowTimeout + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime) + logger.Warnf("Workflow %s timed out after %v", def.Name, result.Duration) + return result, ErrWorkflowTimeout + default: + } + + // Execute step + stepErr := e.executeStep(execCtx, &step, workflowCtx, def.FailureMode) + + // Copy step result to workflow result + if stepResult, exists := workflowCtx.GetStepResult(step.ID); exists { + result.Steps[step.ID] = stepResult + } + + // Handle step failure + if stepErr != nil { + logger.Errorf("Step %s failed in workflow %s: %v", step.ID, def.Name, stepErr) + + // Check failure mode + if def.FailureMode == "" || def.FailureMode == "abort" { + result.Status = WorkflowStatusFailed + result.Error = NewWorkflowError(workflowCtx.WorkflowID, step.ID, "step failed", stepErr) + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime) + return result, result.Error + } + + // For "continue" or "best_effort" modes, log and continue + logger.Warnf("Continuing workflow %s despite step %s failure (mode: %s)", + def.Name, step.ID, def.FailureMode) + } + } + + // Workflow completed successfully + result.Status = WorkflowStatusCompleted + result.Output = workflowCtx.GetLastStepOutput() + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime) + + logger.Infof("Workflow %s completed successfully in %v", def.Name, result.Duration) + return result, nil +} + +// executeStep executes a single workflow step. +func (e *workflowEngine) executeStep( + ctx context.Context, + step *WorkflowStep, + workflowCtx *WorkflowContext, + _ string, // failureMode is handled at workflow level +) error { + logger.Debugf("Executing step: %s (type: %s)", step.ID, step.Type) + + // Record step start + workflowCtx.RecordStepStart(step.ID) + + // Apply step timeout + timeout := step.Timeout + if timeout == 0 { + timeout = defaultStepTimeout + } + stepCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Check dependencies + for _, depID := range step.DependsOn { + if !workflowCtx.HasStepCompleted(depID) { + err := fmt.Errorf("%w: step %s depends on %s which hasn't completed", + ErrDependencyNotMet, step.ID, depID) + workflowCtx.RecordStepFailure(step.ID, err) + return err + } + } + + // Evaluate condition + if step.Condition != "" { + shouldExecute, err := e.templateExpander.EvaluateCondition(ctx, step.Condition, workflowCtx) + if err != nil { + condErr := fmt.Errorf("%w: failed to evaluate condition for step %s: %v", + ErrTemplateExpansion, step.ID, err) + workflowCtx.RecordStepFailure(step.ID, condErr) + return condErr + } + if !shouldExecute { + logger.Debugf("Step %s skipped due to condition", step.ID) + workflowCtx.RecordStepSkipped(step.ID) + return nil + } + } + + // Execute based on step type + switch step.Type { + case StepTypeTool: + return e.executeToolStep(stepCtx, step, workflowCtx) + case StepTypeElicitation: + // Elicitation is not implemented in Phase 2 (basic workflow engine) + err := fmt.Errorf("elicitation steps are not yet supported") + workflowCtx.RecordStepFailure(step.ID, err) + return err + case StepTypeConditional: + // Conditional steps are not implemented in Phase 2 + err := fmt.Errorf("conditional steps are not yet supported") + workflowCtx.RecordStepFailure(step.ID, err) + return err + default: + err := fmt.Errorf("unsupported step type: %s", step.Type) + workflowCtx.RecordStepFailure(step.ID, err) + return err + } +} + +// executeToolStep executes a tool step. +func (e *workflowEngine) executeToolStep( + ctx context.Context, + step *WorkflowStep, + workflowCtx *WorkflowContext, +) error { + logger.Debugf("Executing tool step: %s, tool: %s", step.ID, step.Tool) + + // Expand template arguments + expandedArgs, err := e.templateExpander.Expand(ctx, step.Arguments, workflowCtx) + if err != nil { + expandErr := fmt.Errorf("%w: failed to expand arguments for step %s: %v", + ErrTemplateExpansion, step.ID, err) + workflowCtx.RecordStepFailure(step.ID, expandErr) + return expandErr + } + + // Route tool to backend + target, err := e.router.RouteTool(ctx, step.Tool) + if err != nil { + routeErr := fmt.Errorf("failed to route tool %s in step %s: %w", + step.Tool, step.ID, err) + workflowCtx.RecordStepFailure(step.ID, routeErr) + return routeErr + } + + // Call tool with retry logic + output, retryCount, err := e.callToolWithRetry(ctx, target, step, expandedArgs, workflowCtx) + + // Handle result + if err != nil { + return e.handleToolStepFailure(step, workflowCtx, retryCount, err) + } + + return e.handleToolStepSuccess(step, workflowCtx, output, retryCount) +} + +// callToolWithRetry calls a tool with retry logic using exponential backoff. +func (e *workflowEngine) callToolWithRetry( + ctx context.Context, + target *vmcp.BackendTarget, + step *WorkflowStep, + args map[string]any, + _ *WorkflowContext, +) (map[string]any, int, error) { + maxRetries, initialDelay := e.getRetryConfig(step) + + // Configure exponential backoff + expBackoff := backoff.NewExponentialBackOff() + expBackoff.InitialInterval = initialDelay + expBackoff.MaxInterval = 60 * initialDelay // Cap at 60x the initial delay + expBackoff.Reset() + + attemptCount := 0 + operation := func() (map[string]any, error) { + attemptCount++ + output, err := e.backendClient.CallTool(ctx, target, step.Tool, args) + if err != nil { + logger.Warnf("Tool call failed for step %s (attempt %d/%d): %v", + step.ID, attemptCount, maxRetries+1, err) + return nil, err + } + return output, nil + } + + // Execute with retry + // Safe conversion: maxRetries is capped by maxRetryCount constant (10) + output, err := backoff.Retry(ctx, operation, + backoff.WithBackOff(expBackoff), + backoff.WithMaxTries(uint(maxRetries+1)), // #nosec G115 -- +1 because it includes the initial attempt + backoff.WithNotify(func(_ error, duration time.Duration) { + logger.Debugf("Retrying step %s after %v", step.ID, duration) + }), + ) + + return output, attemptCount - 1, err // Return retry count (attempts - 1) +} + +// getRetryConfig extracts retry configuration from step. +func (*workflowEngine) getRetryConfig(step *WorkflowStep) (int, time.Duration) { + retries := 0 + retryDelay := time.Second + + if step.OnError != nil && step.OnError.Action == "retry" { + retries = step.OnError.RetryCount + + // Cap retry count to prevent infinite retry loops + if retries > maxRetryCount { + logger.Warnf("Step %s retry count %d exceeds maximum %d, capping to %d", + step.ID, retries, maxRetryCount, maxRetryCount) + retries = maxRetryCount + } + + if step.OnError.RetryDelay > 0 { + retryDelay = step.OnError.RetryDelay + } + } + + return retries, retryDelay +} + +// handleToolStepFailure handles a failed tool step. +func (*workflowEngine) handleToolStepFailure( + step *WorkflowStep, + workflowCtx *WorkflowContext, + retryCount int, + err error, +) error { + finalErr := fmt.Errorf("%w: tool %s in step %s: %v", + ErrToolCallFailed, step.Tool, step.ID, err) + workflowCtx.RecordStepFailure(step.ID, finalErr) + + // Update retry count + if result, exists := workflowCtx.GetStepResult(step.ID); exists { + result.RetryCount = retryCount + } + + // Check if we should continue on error + if step.OnError != nil && step.OnError.ContinueOnError { + logger.Warnf("Continuing workflow despite step %s failure (continue_on_error=true)", step.ID) + return nil + } + + return finalErr +} + +// handleToolStepSuccess handles a successful tool step. +func (*workflowEngine) handleToolStepSuccess( + step *WorkflowStep, + workflowCtx *WorkflowContext, + output map[string]any, + retryCount int, +) error { + workflowCtx.RecordStepSuccess(step.ID, output) + + // Update retry count + if result, exists := workflowCtx.GetStepResult(step.ID); exists { + result.RetryCount = retryCount + } + + logger.Debugf("Step %s completed successfully", step.ID) + return nil +} + +// ValidateWorkflow checks if a workflow definition is valid. +func (e *workflowEngine) ValidateWorkflow(_ context.Context, def *WorkflowDefinition) error { + if def == nil { + return NewValidationError("workflow", "workflow definition is nil", nil) + } + + // Validate name + if def.Name == "" { + return NewValidationError("name", "workflow name is required", nil) + } + + // Validate steps + if len(def.Steps) == 0 { + return NewValidationError("steps", "workflow must have at least one step", nil) + } + + // Enforce maximum steps limit to prevent resource exhaustion + if len(def.Steps) > maxWorkflowSteps { + return NewValidationError("steps", + fmt.Sprintf("too many steps: %d (max %d)", len(def.Steps), maxWorkflowSteps), + nil) + } + + // Check for duplicate step IDs + stepIDs := make(map[string]bool) + for _, step := range def.Steps { + if step.ID == "" { + return NewValidationError("step.id", "step ID is required", nil) + } + if stepIDs[step.ID] { + return NewValidationError("step.id", + fmt.Sprintf("duplicate step ID: %s", step.ID), nil) + } + stepIDs[step.ID] = true + } + + // Validate dependencies and detect cycles + if err := e.validateDependencies(def.Steps); err != nil { + return err + } + + // Validate step types and configurations + for _, step := range def.Steps { + if err := e.validateStep(&step, stepIDs); err != nil { + return err + } + } + + return nil +} + +// validateDependencies checks for circular dependencies using DFS. +func (*workflowEngine) validateDependencies(steps []WorkflowStep) error { + // Build adjacency list + graph := make(map[string][]string) + for i := range steps { + graph[steps[i].ID] = steps[i].DependsOn + } + + // Track visited and recursion stack + visited := make(map[string]bool) + recStack := make(map[string]bool) + + // DFS to detect cycles + var hasCycle func(string) bool + hasCycle = func(nodeID string) bool { + visited[nodeID] = true + recStack[nodeID] = true + + for _, depID := range graph[nodeID] { + if !visited[depID] { + if hasCycle(depID) { + return true + } + } else if recStack[depID] { + return true + } + } + + recStack[nodeID] = false + return false + } + + // Check each step + for i := range steps { + if !visited[steps[i].ID] { + if hasCycle(steps[i].ID) { + return NewValidationError("dependencies", + fmt.Sprintf("circular dependency detected involving step %s", steps[i].ID), + ErrCircularDependency) + } + } + } + + // Validate dependency references + for i := range steps { + for _, depID := range steps[i].DependsOn { + if !visited[depID] { + return NewValidationError("dependencies", + fmt.Sprintf("step %s depends on non-existent step %s", steps[i].ID, depID), + nil) + } + } + } + + return nil +} + +// validateStep validates a single step configuration. +func (*workflowEngine) validateStep(step *WorkflowStep, validStepIDs map[string]bool) error { + // Validate step type + switch step.Type { + case StepTypeTool: + if step.Tool == "" { + return NewValidationError("step.tool", + fmt.Sprintf("tool name is required for tool step %s", step.ID), + nil) + } + case StepTypeElicitation: + if step.Elicitation == nil { + return NewValidationError("step.elicitation", + fmt.Sprintf("elicitation config is required for elicitation step %s", step.ID), + nil) + } + if step.Elicitation.Message == "" { + return NewValidationError("step.elicitation.message", + fmt.Sprintf("elicitation message is required for step %s", step.ID), + nil) + } + case StepTypeConditional: + // Future: validate conditional step + return NewValidationError("step.type", + fmt.Sprintf("conditional steps are not yet supported (step %s)", step.ID), + nil) + default: + return NewValidationError("step.type", + fmt.Sprintf("invalid step type %q for step %s", step.Type, step.ID), + nil) + } + + // Validate dependencies exist + for _, depID := range step.DependsOn { + if !validStepIDs[depID] { + return NewValidationError("step.depends_on", + fmt.Sprintf("step %s depends on non-existent step %s", step.ID, depID), + nil) + } + } + + return nil +} + +// GetWorkflowStatus returns the current status of a running workflow. +// For Phase 2 (basic workflow engine), this is a placeholder. +func (*workflowEngine) GetWorkflowStatus(_ context.Context, _ string) (*WorkflowStatus, error) { + // In Phase 2, we don't track long-running workflows + // This will be implemented in Phase 3 with persistent state + return nil, fmt.Errorf("workflow status tracking not yet implemented") +} + +// CancelWorkflow cancels a running workflow. +// For Phase 2 (basic workflow engine), this is a placeholder. +func (*workflowEngine) CancelWorkflow(_ context.Context, _ string) error { + // In Phase 2, workflows run synchronously and blocking + // Cancellation will be implemented in Phase 3 + return fmt.Errorf("workflow cancellation not yet implemented") +} diff --git a/pkg/vmcp/composer/workflow_engine_test.go b/pkg/vmcp/composer/workflow_engine_test.go new file mode 100644 index 000000000..6a2f390a8 --- /dev/null +++ b/pkg/vmcp/composer/workflow_engine_test.go @@ -0,0 +1,190 @@ +package composer + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +func TestWorkflowEngine_ExecuteWorkflow_Success(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + // Two-step workflow: create issue -> add label + def := simpleWorkflow("test-workflow", + toolStep("create_issue", "github.create_issue", map[string]any{ + "title": "{{.params.title}}", + "body": "Test body", + }), + toolStepWithDeps("add_label", "github.add_label", map[string]any{ + "issue": "{{.steps.create_issue.output.number}}", + "label": "bug", + }, []string{"create_issue"}), + ) + + // Expectations + te.expectToolCall("github.create_issue", + map[string]any{"title": "Test Issue", "body": "Test body"}, + map[string]any{"number": 123, "url": "https://github.com/org/repo/issues/123"}) + + te.expectToolCallWithAnyArgs("github.add_label", map[string]any{"success": true}) + + // Execute + result, err := execute(t, te.Engine, def, map[string]any{"title": "Test Issue"}) + + // Verify + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) + assert.Len(t, result.Steps, 2) + assert.Equal(t, StepStatusCompleted, result.Steps["create_issue"].Status) + assert.Equal(t, StepStatusCompleted, result.Steps["add_label"].Status) +} + +func TestWorkflowEngine_ExecuteWorkflow_StepFailure(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := simpleWorkflow("test", toolStep("fail", "test.tool", map[string]any{"p": "v"})) + + te.expectToolCallWithError("test.tool", map[string]any{"p": "v"}, errors.New("tool failed")) + + result, err := execute(t, te.Engine, def, nil) + + require.Error(t, err) + assert.Equal(t, WorkflowStatusFailed, result.Status) + assert.Equal(t, StepStatusFailed, result.Steps["fail"].Status) +} + +func TestWorkflowEngine_ExecuteWorkflow_WithRetry(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := &WorkflowDefinition{ + Name: "retry-test", + Steps: []WorkflowStep{{ + ID: "flaky", + Type: StepTypeTool, + Tool: "test.tool", + OnError: &ErrorHandler{ + Action: "retry", + RetryCount: 2, + RetryDelay: 10 * time.Millisecond, + }, + }}, + } + + target := &vmcp.BackendTarget{WorkloadID: "test", BaseURL: "http://test:8080"} + te.Router.EXPECT().RouteTool(gomock.Any(), "test.tool").Return(target, nil) + + // Fail once, then succeed + gomock.InOrder( + te.Backend.EXPECT().CallTool(gomock.Any(), target, "test.tool", gomock.Any()). + Return(nil, errors.New("temp fail")), + te.Backend.EXPECT().CallTool(gomock.Any(), target, "test.tool", gomock.Any()). + Return(map[string]any{"ok": true}, nil), + ) + + result, err := execute(t, te.Engine, def, nil) + + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) + assert.Equal(t, 1, result.Steps["flaky"].RetryCount) +} + +func TestWorkflowEngine_ExecuteWorkflow_ConditionalSkip(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := &WorkflowDefinition{ + Name: "conditional", + Steps: []WorkflowStep{ + toolStep("always", "test.tool1", nil), + { + ID: "conditional", + Type: StepTypeTool, + Tool: "test.tool2", + Condition: "{{if eq .params.enabled true}}true{{else}}false{{end}}", + }, + }, + } + + te.expectToolCall("test.tool1", nil, map[string]any{"ok": true}) + // tool2 should NOT be called (condition is false) + + result, err := execute(t, te.Engine, def, map[string]any{"enabled": false}) + + require.NoError(t, err) + assert.Equal(t, StepStatusCompleted, result.Steps["always"].Status) + assert.Equal(t, StepStatusSkipped, result.Steps["conditional"].Status) +} + +func TestWorkflowEngine_ValidateWorkflow(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + def *WorkflowDefinition + errMsg string + }{ + {"valid", simpleWorkflow("test", toolStep("s1", "t1", nil)), ""}, + {"nil workflow", nil, "workflow definition is nil"}, + {"missing name", &WorkflowDefinition{Steps: []WorkflowStep{toolStep("s1", "t1", nil)}}, "name is required"}, + {"no steps", &WorkflowDefinition{Name: "test"}, "at least one step"}, + {"duplicate IDs", simpleWorkflow("test", toolStep("s1", "t1", nil), toolStep("s1", "t2", nil)), "duplicate step ID"}, + {"circular deps", simpleWorkflow("test", + toolStepWithDeps("s1", "t1", nil, []string{"s2"}), + toolStepWithDeps("s2", "t2", nil, []string{"s1"})), "circular dependency"}, + {"invalid dep", simpleWorkflow("test", toolStepWithDeps("s1", "t1", nil, []string{"unknown"})), "non-existent"}, + {"too many steps", &WorkflowDefinition{Name: "test", Steps: make([]WorkflowStep, 101)}, "too many steps"}, + } + + te := newTestEngine(t) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := te.Engine.ValidateWorkflow(context.Background(), tt.def) + if tt.errMsg == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } + }) + } +} + +func TestWorkflowEngine_ExecuteWorkflow_Timeout(t *testing.T) { + t.Parallel() + te := newTestEngine(t) + + def := &WorkflowDefinition{ + Name: "timeout-test", + Timeout: 50 * time.Millisecond, + Steps: []WorkflowStep{ + toolStep("s1", "test.tool", nil), + toolStep("s2", "test.tool", nil), + }, + } + + target := &vmcp.BackendTarget{WorkloadID: "test", BaseURL: "http://test:8080"} + te.Router.EXPECT().RouteTool(gomock.Any(), "test.tool").Return(target, nil) + te.Backend.EXPECT().CallTool(gomock.Any(), target, "test.tool", gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + time.Sleep(60 * time.Millisecond) // Exceed workflow timeout + return map[string]any{"ok": true}, nil + }) + + result, err := execute(t, te.Engine, def, nil) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrWorkflowTimeout) + assert.Equal(t, WorkflowStatusTimedOut, result.Status) +} diff --git a/pkg/vmcp/composer/workflow_errors.go b/pkg/vmcp/composer/workflow_errors.go new file mode 100644 index 000000000..35e805c63 --- /dev/null +++ b/pkg/vmcp/composer/workflow_errors.go @@ -0,0 +1,109 @@ +// Package composer provides composite tool workflow execution for Virtual MCP Server. +package composer + +import ( + "errors" + "fmt" +) + +// Common workflow execution errors. +var ( + // ErrWorkflowNotFound indicates the workflow doesn't exist. + ErrWorkflowNotFound = errors.New("workflow not found") + + // ErrWorkflowTimeout indicates the workflow exceeded its timeout. + ErrWorkflowTimeout = errors.New("workflow timed out") + + // ErrWorkflowCancelled indicates the workflow was cancelled. + ErrWorkflowCancelled = errors.New("workflow cancelled") + + // ErrInvalidWorkflowDefinition indicates the workflow definition is invalid. + ErrInvalidWorkflowDefinition = errors.New("invalid workflow definition") + + // ErrStepFailed indicates a workflow step failed. + ErrStepFailed = errors.New("step failed") + + // ErrTemplateExpansion indicates template expansion failed. + ErrTemplateExpansion = errors.New("template expansion failed") + + // ErrCircularDependency indicates a circular dependency in step dependencies. + ErrCircularDependency = errors.New("circular dependency detected") + + // ErrDependencyNotMet indicates a step dependency hasn't completed. + ErrDependencyNotMet = errors.New("dependency not met") + + // ErrToolCallFailed indicates a tool call failed. + ErrToolCallFailed = errors.New("tool call failed") +) + +// WorkflowError wraps workflow execution errors with context. +type WorkflowError struct { + // WorkflowID is the workflow execution ID. + WorkflowID string + + // StepID is the step that caused the error (if applicable). + StepID string + + // Message is the error message. + Message string + + // Cause is the underlying error. + Cause error +} + +// Error implements the error interface. +func (e *WorkflowError) Error() string { + if e.StepID != "" { + return fmt.Sprintf("workflow %s, step %s: %s: %v", e.WorkflowID, e.StepID, e.Message, e.Cause) + } + return fmt.Sprintf("workflow %s: %s: %v", e.WorkflowID, e.Message, e.Cause) +} + +// Unwrap returns the underlying error for errors.Is and errors.As. +func (e *WorkflowError) Unwrap() error { + return e.Cause +} + +// NewWorkflowError creates a new workflow error. +func NewWorkflowError(workflowID, stepID, message string, cause error) *WorkflowError { + return &WorkflowError{ + WorkflowID: workflowID, + StepID: stepID, + Message: message, + Cause: cause, + } +} + +// ValidationError wraps workflow validation errors. +type ValidationError struct { + // Field is the field that failed validation. + Field string + + // Message is the error message. + Message string + + // Cause is the underlying error. + Cause error +} + +// Error implements the error interface. +func (e *ValidationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("validation error for %s: %s: %v", e.Field, e.Message, e.Cause) + } + return fmt.Sprintf("validation error for %s: %s", e.Field, e.Message) +} + +// Unwrap returns the underlying error. +func (e *ValidationError) Unwrap() error { + return e.Cause +} + +// NewValidationError creates a new validation error. +func NewValidationError(field, message string, cause error) *ValidationError { + return &ValidationError{ + Field: field, + Message: message, + Cause: cause, + } +}