diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 9b7a19c..22b0a48 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -48,6 +48,9 @@ //forward the following ports "forwardPorts": [8084], + //use host networking + "runArgs": ["--network=host"], + //mount docker directly on the host "mounts": ["source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind"], diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a857e6e..4d1c4d6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -53,4 +53,21 @@ jobs: - name: Run cmd/main.go tests working-directory: . run: | - go test -v ./... + make test + + go-e2e-tests: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.24" + cache: true + + - name: Run e2e tests + working-directory: . + run: | + make e2e diff --git a/.gitignore b/.gitignore index d146fbf..2496f99 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,8 @@ bin/ .env.local .env.development.local .env.test.local -.env.production.local \ No newline at end of file +.env.production.local +/logs/ +/kagent-tools +/*.out +*.html diff --git a/Makefile b/Makefile index 90924db..1d9fdd1 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,10 @@ LDFLAGS := -X github.com/kagent-dev/tools/internal/version.Version=$(VERSION) -X ## Location to install dependencies to LOCALBIN ?= $(shell pwd)/bin +.PHONY: clean +clean: + rm -rf ./bin/kagent-tools-* + .PHONY: fmt fmt: ## Run go fmt against code. go fmt ./... @@ -23,11 +27,11 @@ vet: ## Run go vet against code. 
.PHONY: lint lint: golangci-lint ## Run golangci-lint linter - $(GOLANGCI_LINT) run + $(GOLANGCI_LINT) run --build-tags=test .PHONY: lint-fix lint-fix: golangci-lint ## Run golangci-lint linter and perform fixes - $(GOLANGCI_LINT) run --fix + $(GOLANGCI_LINT) run --build-tags=test --fix .PHONY: lint-config lint-config: golangci-lint ## Verify golangci-lint linter configuration @@ -43,8 +47,16 @@ tidy: ## Run go mod tidy to ensure dependencies are up to date. go mod tidy .PHONY: test -test: - go test -v -cover ./... +test: build lint ## Run all tests with build, lint, and coverage + go test -tags=test -v -cover ./pkg/... ./internal/... + +.PHONY: test-only +test-only: ## Run tests only (without build/lint for faster iteration) + go test -tags=test -v -cover ./pkg/... ./internal/... + +.PHONY: e2e +e2e: test docker-build + go test -tags=test -v -cover ./e2e/... bin/kagent-tools-linux-amd64: CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/kagent-tools-linux-amd64 ./cmd @@ -143,6 +155,12 @@ docker-build-all: DOCKER_BUILD_ARGS = --progress=plain --builder $(BUILDX_BUILDE docker-build-all: $(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -f Dockerfile ./ +.PHONY: kind-update-kagent +kind-update-kagent: docker-build + kind get clusters | grep -q $(KIND_CLUSTER_NAME) || kind create cluster --name $(KIND_CLUSTER_NAME) + kind load docker-image --name $(KIND_CLUSTER_NAME) $(TOOLS_IMG) + kubectl patch --namespace kagent deployment/kagent --type='json' -p='[{"op": "replace", "path": "/spec/template/spec/containers/3/image", "value": "$(TOOLS_IMG)"}]' + ## Tool Binaries ## Location to install dependencies t diff --git a/cmd/main.go b/cmd/main.go index 5c44309..6ea40ae 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -7,24 +7,29 @@ import ( "net/http" "os" "os/signal" + "runtime" "strings" "sync" "syscall" "time" "github.com/joho/godotenv" + "github.com/kagent-dev/tools/internal/logger" + 
"github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/internal/version" - "github.com/kagent-dev/tools/pkg/logger" - "github.com/kagent-dev/tools/pkg/utils" - "github.com/kagent-dev/tools/pkg/argo" "github.com/kagent-dev/tools/pkg/cilium" "github.com/kagent-dev/tools/pkg/helm" "github.com/kagent-dev/tools/pkg/istio" "github.com/kagent-dev/tools/pkg/k8s" "github.com/kagent-dev/tools/pkg/prometheus" - "github.com/mark3labs/mcp-go/server" + "github.com/kagent-dev/tools/pkg/utils" "github.com/spf13/cobra" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + + "github.com/mark3labs/mcp-go/server" ) var ( @@ -69,12 +74,36 @@ func run(cmd *cobra.Command, args []string) { logger.Init() defer logger.Sync() - logger.Get().Info("Starting "+Name, "version", Version, "git_commit", GitCommit, "build_date", BuildDate) - // Setup context with cancellation for graceful shutdown ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Initialize OpenTelemetry tracing + cfg := telemetry.LoadOtelCfg() + + err := telemetry.SetupOTelSDK(ctx) + if err != nil { + logger.Get().Error("Failed to setup OpenTelemetry SDK", "error", err) + os.Exit(1) + } + + // Start root span for server lifecycle + tracer := otel.Tracer("kagent-tools/server") + ctx, rootSpan := tracer.Start(ctx, "server.lifecycle") + defer rootSpan.End() + + rootSpan.SetAttributes( + attribute.String("server.name", Name), + attribute.String("server.version", cfg.Telemetry.ServiceVersion), + attribute.String("server.git_commit", GitCommit), + attribute.String("server.build_date", BuildDate), + attribute.Bool("server.stdio_mode", stdio), + attribute.Int("server.port", port), + attribute.StringSlice("server.tools", tools), + ) + + logger.Get().Info("Starting "+Name, "version", Version, "git_commit", GitCommit, "build_date", BuildDate) + mcp := server.NewMCPServer( Name, Version, @@ -91,7 +120,7 @@ func run(cmd *cobra.Command, args 
[]string) { signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) // HTTP server reference (only used when not in stdio mode) - var sseServer *server.StreamableHTTPServer + var httpServer *http.Server // Start server based on chosen mode wg.Add(1) @@ -101,16 +130,49 @@ func run(cmd *cobra.Command, args []string) { runStdioServer(ctx, mcp) }() } else { - sseServer = server.NewStreamableHTTPServer(mcp) + sseServer := server.NewStreamableHTTPServer(mcp) + + // Create a mux to handle different routes + mux := http.NewServeMux() + + // Add health endpoint + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if err := writeResponse(w, []byte("OK")); err != nil { + logger.Get().Error("Failed to write health response", "error", err) + } + }) + + // Add metrics endpoint (basic implementation for e2e tests) + mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + + // Generate real runtime metrics instead of hardcoded values + metrics := generateRuntimeMetrics() + if err := writeResponse(w, []byte(metrics)); err != nil { + logger.Get().Error("Failed to write metrics response", "error", err) + } + }) + + // Handle all other routes with the MCP server wrapped in telemetry middleware + mux.Handle("/", telemetry.HTTPMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sseServer.ServeHTTP(w, r) + }))) + + httpServer = &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: mux, + } + go func() { defer wg.Done() - addr := fmt.Sprintf(":%d", port) - logger.Get().Info("Running KAgent Tools Server", "port", addr, "tools", strings.Join(tools, ",")) - if err := sseServer.Start(addr); err != nil { + logger.Get().Info("Running KAgent Tools Server", "port", fmt.Sprintf(":%d", port), "tools", strings.Join(tools, ",")) + if err := httpServer.ListenAndServe(); err != nil { if !errors.Is(err, 
http.ErrServerClosed) { - logger.Get().Error(err, "Failed to start SSE server") + logger.Get().Error("Failed to start HTTP server", "error", err) } else { - logger.Get().Info("SSE server closed gracefully.") + logger.Get().Info("HTTP server closed gracefully.") } } }() @@ -121,16 +183,23 @@ func run(cmd *cobra.Command, args []string) { <-signalChan logger.Get().Info("Received termination signal, shutting down server...") + // Mark root span as shutting down + rootSpan.AddEvent("server.shutdown.initiated") + // Cancel context to notify any context-aware operations cancel() // Gracefully shutdown HTTP server if running - if !stdio && sseServer != nil { + if !stdio && httpServer != nil { shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) defer shutdownCancel() - if err := sseServer.Shutdown(shutdownCtx); err != nil { - logger.Get().Error(err, "Failed to shutdown server gracefully") + if err := httpServer.Shutdown(shutdownCtx); err != nil { + logger.Get().Error("Failed to shutdown server gracefully", "error", err) + rootSpan.RecordError(err) + rootSpan.SetStatus(codes.Error, "Server shutdown failed") + } else { + rootSpan.AddEvent("server.shutdown.completed") } } }() @@ -140,6 +209,53 @@ func run(cmd *cobra.Command, args []string) { logger.Get().Info("Server shutdown complete") } +// writeResponse writes data to an HTTP response writer with proper error handling +func writeResponse(w http.ResponseWriter, data []byte) error { + _, err := w.Write(data) + return err +} + +// generateRuntimeMetrics generates real runtime metrics for the /metrics endpoint +func generateRuntimeMetrics() string { + var m runtime.MemStats + runtime.ReadMemStats(&m) + + now := time.Now().Unix() + + // Build metrics in Prometheus format + metrics := strings.Builder{} + + // Go runtime info + metrics.WriteString("# HELP go_info Information about the Go environment.\n") + metrics.WriteString("# TYPE go_info gauge\n") + 
metrics.WriteString(fmt.Sprintf("go_info{version=\"%s\"} 1\n", runtime.Version())) + + // Process start time + metrics.WriteString("# HELP process_start_time_seconds Start time of the process since unix epoch in seconds.\n") + metrics.WriteString("# TYPE process_start_time_seconds gauge\n") + metrics.WriteString(fmt.Sprintf("process_start_time_seconds %d\n", now)) + + // Memory metrics + metrics.WriteString("# HELP go_memstats_alloc_bytes Number of bytes allocated and still in use.\n") + metrics.WriteString("# TYPE go_memstats_alloc_bytes gauge\n") + metrics.WriteString(fmt.Sprintf("go_memstats_alloc_bytes %d\n", m.Alloc)) + + metrics.WriteString("# HELP go_memstats_total_alloc_bytes Total number of bytes allocated, even if freed.\n") + metrics.WriteString("# TYPE go_memstats_total_alloc_bytes counter\n") + metrics.WriteString(fmt.Sprintf("go_memstats_total_alloc_bytes %d\n", m.TotalAlloc)) + + metrics.WriteString("# HELP go_memstats_sys_bytes Number of bytes obtained from system.\n") + metrics.WriteString("# TYPE go_memstats_sys_bytes gauge\n") + metrics.WriteString(fmt.Sprintf("go_memstats_sys_bytes %d\n", m.Sys)) + + // Goroutine count + metrics.WriteString("# HELP go_goroutines Number of goroutines that currently exist.\n") + metrics.WriteString("# TYPE go_goroutines gauge\n") + metrics.WriteString(fmt.Sprintf("go_goroutines %d\n", runtime.NumGoroutine())) + + return metrics.String() +} + func runStdioServer(ctx context.Context, mcp *server.MCPServer) { logger.Get().Info("Running KAgent Tools Server STDIO:", "tools", strings.Join(tools, ",")) stdioServer := server.NewStdioServer(mcp) @@ -149,39 +265,28 @@ func runStdioServer(ctx context.Context, mcp *server.MCPServer) { } func registerMCP(mcp *server.MCPServer, enabledToolProviders []string, kubeconfig string) { - - var toolProviderMap = map[string]func(*server.MCPServer, string){ - "utils": utils.RegisterDateTimeTools, - "k8s": k8s.RegisterK8sTools, - "prometheus": prometheus.RegisterPrometheusTools, - "helm": 
helm.RegisterHelmTools, - "istio": istio.RegisterIstioTools, - "argo": argo.RegisterArgoTools, - "cilium": cilium.RegisterCiliumTools, + // A map to hold tool providers and their registration functions + toolProviderMap := map[string]func(*server.MCPServer){ + "argo": argo.RegisterTools, + "cilium": cilium.RegisterTools, + "helm": helm.RegisterTools, + "istio": istio.RegisterTools, + "k8s": func(s *server.MCPServer) { k8s.RegisterTools(s, nil, kubeconfig) }, + "prometheus": prometheus.RegisterTools, + "utils": utils.RegisterTools, } - if len(kubeconfig) > 0 { - logger.Get().Info("Using kubeconfig file", "path", kubeconfig) - } - - // If no tools specified, register all tools + // If no specific tools are specified, register all available tools. if len(enabledToolProviders) == 0 { - logger.Get().Info("No specific tools provided, registering all tools") - for toolProvider, registerFunc := range toolProviderMap { - logger.Get().Info("Registering tools", "provider", toolProvider) - registerFunc(mcp, kubeconfig) + for name := range toolProviderMap { + enabledToolProviders = append(enabledToolProviders, name) } - return } - - // Register only the specified tools - logger.Get().Info("provider list", "tools", enabledToolProviders) for _, toolProviderName := range enabledToolProviders { - if registerFunc, ok := toolProviderMap[strings.ToLower(toolProviderName)]; ok { - logger.Get().Info("Registering tool", "provider", toolProviderName) - registerFunc(mcp, kubeconfig) + if registerFunc, ok := toolProviderMap[strings.ToLower(toolProviderName)]; ok { + registerFunc(mcp) } else { - logger.Get().Error(nil, "Unknown tool specified", "provider", toolProviderName) + logger.Get().Error("Unknown tool specified", "provider", toolProviderName) } } } diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go new file mode 100644 index 0000000..ec04751 --- /dev/null +++ b/e2e/e2e_test.go @@ -0,0 +1,1005 @@ +package e2e + +import ( + "bufio" + "context" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" 
+ "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// getBinaryName returns the platform-specific binary name +func getBinaryName() string { + osName := runtime.GOOS + archName := runtime.GOARCH + return fmt.Sprintf("kagent-tools-%s-%s", osName, archName) +} + +// TestServerConfig holds configuration for server tests +type TestServerConfig struct { + Port int + Tools []string + Kubeconfig string + Stdio bool + Timeout time.Duration +} + +// ServerTestResult holds the result of a server test +type ServerTestResult struct { + Output string + Error error + Duration time.Duration +} + +// TestServer represents a test server instance +type TestServer struct { + cmd *exec.Cmd + port int + stdio bool + cancel context.CancelFunc + done chan struct{} + output strings.Builder + mu sync.RWMutex +} + +// NewTestServer creates a new test server instance +func NewTestServer(config TestServerConfig) *TestServer { + return &TestServer{ + port: config.Port, + stdio: config.Stdio, + done: make(chan struct{}), + } +} + +// Start starts the test server +func (ts *TestServer) Start(ctx context.Context, config TestServerConfig) error { + ts.mu.Lock() + defer ts.mu.Unlock() + + // Build command arguments + args := []string{} + if config.Stdio { + args = append(args, "--stdio") + } else { + args = append(args, "--port", fmt.Sprintf("%d", config.Port)) + } + + if len(config.Tools) > 0 { + args = append(args, "--tools", strings.Join(config.Tools, ",")) + } + + if config.Kubeconfig != "" { + args = append(args, "--kubeconfig", config.Kubeconfig) + } + + // Create context with cancellation + ctx, cancel := context.WithCancel(ctx) + ts.cancel = cancel + + // Start server process + binaryName := getBinaryName() + ts.cmd = exec.CommandContext(ctx, fmt.Sprintf("../bin/%s", binaryName), args...) 
+ ts.cmd.Env = append(os.Environ(), "LOG_LEVEL=debug") + + // Set up output capture + stdout, err := ts.cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := ts.cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + // Start the command + if err := ts.cmd.Start(); err != nil { + return fmt.Errorf("failed to start server: %w", err) + } + + // Start goroutines to capture output + go ts.captureOutput(stdout, "STDOUT") + go ts.captureOutput(stderr, "STDERR") + + // Wait for server to start + if !config.Stdio { + return ts.waitForHTTPServer(ctx, config.Timeout) + } + + return nil +} + +// Stop stops the test server +func (ts *TestServer) Stop() error { + ts.mu.Lock() + defer ts.mu.Unlock() + + if ts.cancel != nil { + ts.cancel() + } + + if ts.cmd != nil && ts.cmd.Process != nil { + // Send interrupt signal for graceful shutdown + if err := ts.cmd.Process.Signal(os.Interrupt); err != nil { + // If interrupt fails, kill the process + _ = ts.cmd.Process.Kill() + } + + // Wait for process to exit with timeout + done := make(chan error, 1) + go func() { + done <- ts.cmd.Wait() + }() + + select { + case <-done: + // Process exited + case <-time.After(5 * time.Second): + // Timeout, force kill + _ = ts.cmd.Process.Kill() + } + } + + close(ts.done) + return nil +} + +// GetOutput returns the captured output +func (ts *TestServer) GetOutput() string { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.output.String() +} + +// captureOutput captures output from the server +func (ts *TestServer) captureOutput(reader io.Reader, prefix string) { + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := scanner.Text() + ts.mu.Lock() + ts.output.WriteString(fmt.Sprintf("[%s] %s\n", prefix, line)) + ts.mu.Unlock() + } +} + +// waitForHTTPServer waits for the HTTP server to become available +func (ts *TestServer) waitForHTTPServer(ctx context.Context, timeout 
time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + url := fmt.Sprintf("http://localhost:%d/health", ts.port) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for server to start") + case <-ticker.C: + resp, err := http.Get(url) + if err == nil { + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + } + } + } +} + +// TestHTTPServerStartup tests basic HTTP server startup and shutdown +func TestHTTPServerStartup(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8085, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait a bit for server to be fully ready + time.Sleep(3 * time.Second) + + // Test health endpoint + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + require.NoError(t, err, "Health endpoint should be accessible") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + // Check server output + output := server.GetOutput() + assert.Contains(t, output, "Running KAgent Tools Server") + assert.Contains(t, output, fmt.Sprintf(":%d", config.Port)) + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") + + // Verify server is stopped + time.Sleep(1 * time.Second) + _, err = http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + assert.Error(t, err, "Server should not be accessible after stop") +} + +// TestHTTPServerWithSpecificTools tests server with specific tools enabled +func TestHTTPServerWithSpecificTools(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8086, + Tools: []string{"utils", "k8s"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := 
NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for tool registration + output := server.GetOutput() + assert.Contains(t, output, "RegisterTools initialized", "Should register specified tools") + assert.Contains(t, output, "utils", "Should register utils tools") + assert.Contains(t, output, "k8s", "Should register k8s tools") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestHTTPServerWithAllTools tests server with all tools enabled (default) +func TestHTTPServerWithAllTools(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8087, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for all tools registration + output := server.GetOutput() + assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools") + + // Verify server is running (tools are implicitly registered when no specific tools are provided) + assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with all tools") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestHTTPServerWithKubeconfig tests server with kubeconfig parameter +func TestHTTPServerWithKubeconfig(t *testing.T) { + ctx := context.Background() + + // Create a temporary kubeconfig file + tempDir := t.TempDir() + kubeconfigPath := filepath.Join(tempDir, "kubeconfig") + + kubeconfigContent := `apiVersion: v1 +kind: Config +clusters: +- cluster: + server: https://test-cluster + name: test-cluster +contexts: +- context: + 
cluster: test-cluster + user: test-user + name: test-context +current-context: test-context +users: +- name: test-user + user: + token: test-token +` + + err := os.WriteFile(kubeconfigPath, []byte(kubeconfigContent), 0644) + require.NoError(t, err, "Should create temporary kubeconfig file") + + config := TestServerConfig{ + Port: 8088, + Kubeconfig: kubeconfigPath, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err = server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for kubeconfig setting + output := server.GetOutput() + assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools") + assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with kubeconfig") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestStdioServer tests STDIO server mode +func TestStdioServer(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Stdio: true, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for STDIO mode + output := server.GetOutput() + assert.Contains(t, output, "Running KAgent Tools Server STDIO") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestServerGracefulShutdown tests graceful shutdown behavior +func TestServerGracefulShutdown(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8100, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + 
require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Stop server and measure shutdown time + start := time.Now() + err = server.Stop() + duration := time.Since(start) + + require.NoError(t, err, "Server should stop gracefully") + assert.Less(t, duration, 10*time.Second, "Shutdown should complete within reasonable time") + + // Wait a bit for shutdown logs to be captured + time.Sleep(3 * time.Second) + + // Check server output for graceful shutdown + output := server.GetOutput() + // The main test is that the server started successfully and stopped without error + assert.Contains(t, output, "Running KAgent Tools Server", "Server should have started successfully") + + // Try to verify the server is actually stopped by attempting to connect + _, err = http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + assert.Error(t, err, "Server should not be accessible after stop") +} + +// TestServerWithInvalidTool tests server behavior with invalid tool names +func TestServerWithInvalidTool(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8090, + Tools: []string{"invalid-tool", "utils"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start even with invalid tools") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for error about invalid tool + output := server.GetOutput() + assert.Contains(t, output, "Unknown tool specified") + assert.Contains(t, output, "invalid-tool") + + // Valid tools should still be registered + assert.Contains(t, output, "RegisterTools initialized") + assert.Contains(t, output, "utils") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestServerVersionAndBuildInfo tests server version and build 
information +func TestServerVersionAndBuildInfo(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8091, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for version information + output := server.GetOutput() + assert.Contains(t, output, "Starting kagent-tools-server") + assert.Contains(t, output, "version") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestConcurrentServerInstances tests running multiple server instances +func TestConcurrentServerInstances(t *testing.T) { + ctx := context.Background() + + var wg sync.WaitGroup + numServers := 3 + servers := make([]*TestServer, numServers) + + // Start multiple servers on different ports + for i := 0; i < numServers; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + config := TestServerConfig{ + Port: 8092 + index, + Tools: []string{"utils"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + servers[index] = server + + err := server.Start(ctx, config) + assert.NoError(t, err, fmt.Sprintf("Server %d should start successfully", index)) + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Test health endpoint + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + assert.NoError(t, err, fmt.Sprintf("Health endpoint should be accessible for server %d", index)) + if resp != nil { + resp.Body.Close() + } + }(i) + } + + wg.Wait() + + // Stop all servers + for i, server := range servers { + if server != nil { + err := server.Stop() + assert.NoError(t, err, fmt.Sprintf("Server %d should stop gracefully", i)) + } + } +} + +// TestServerEnvironmentVariables tests server with environment variables 
+func TestServerEnvironmentVariables(t *testing.T) { + ctx := context.Background() + + // Set environment variables + originalEnv := os.Environ() + defer func() { + os.Clearenv() + for _, env := range originalEnv { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + os.Setenv(parts[0], parts[1]) + } + } + }() + + os.Setenv("LOG_LEVEL", "info") + os.Setenv("OTEL_SERVICE_NAME", "test-kagent-tools") + + config := TestServerConfig{ + Port: 8095, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + // Start server + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output + output := server.GetOutput() + assert.Contains(t, output, "Starting kagent-tools-server") + + // Stop server + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestServerBuildAndExecution tests that the server binary exists and is executable +func TestServerBuildAndExecution(t *testing.T) { + // Check if server binary exists + binaryName := getBinaryName() + binaryPath := fmt.Sprintf("../bin/%s", binaryName) + _, err := os.Stat(binaryPath) + if os.IsNotExist(err) { + t.Skip("Server binary not found, skipping test. 
Run 'make build' first.") + } + require.NoError(t, err, "Server binary should exist") + + // Test --help flag + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, binaryPath, "--help") + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Server should respond to --help flag") + + outputStr := string(output) + assert.Contains(t, outputStr, "KAgent tool server") + assert.Contains(t, outputStr, "--port") + assert.Contains(t, outputStr, "--stdio") + assert.Contains(t, outputStr, "--tools") + assert.Contains(t, outputStr, "--kubeconfig") +} + +// Benchmark tests +func BenchmarkServerStartup(b *testing.B) { + ctx := context.Background() + + for i := 0; i < b.N; i++ { + config := TestServerConfig{ + Port: 8096 + i, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + + start := time.Now() + err := server.Start(ctx, config) + if err != nil { + b.Fatalf("Server startup failed: %v", err) + } + + // Wait for server to be ready + time.Sleep(1 * time.Second) + + duration := time.Since(start) + b.ReportMetric(float64(duration.Nanoseconds()), "startup_time_ns") + + // Stop server + _ = server.Stop() + } +} + +// Helper functions for test setup +func init() { + // Ensure the binary exists before running tests + binaryName := getBinaryName() + binaryPath := fmt.Sprintf("../bin/%s", binaryName) + if _, err := os.Stat(binaryPath); os.IsNotExist(err) { + // Try to build the binary + cmd := exec.Command("make", "build") + cmd.Dir = ".." 
+ if err := cmd.Run(); err != nil { + panic(fmt.Sprintf("Failed to build server binary: %v", err)) + } + } +} + +// TestToolRegistrationValidation tests that tool registration works correctly +func TestToolRegistrationValidation(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + config TestServerConfig + expectedTools []string + shouldFail bool + }{ + { + name: "Register single tool", + config: TestServerConfig{ + Port: 8087, + Tools: []string{"k8s"}, + Timeout: 30 * time.Second, + }, + expectedTools: []string{"k8s"}, + shouldFail: false, + }, + { + name: "Register multiple tools", + config: TestServerConfig{ + Port: 8088, + Tools: []string{"k8s", "prometheus", "utils"}, + Timeout: 30 * time.Second, + }, + expectedTools: []string{"k8s", "prometheus", "utils"}, + shouldFail: false, + }, + { + name: "Register invalid tool", + config: TestServerConfig{ + Port: 8089, + Tools: []string{"invalid-tool"}, + Timeout: 30 * time.Second, + }, + shouldFail: false, + }, + { + name: "Register all tools implicitly", + config: TestServerConfig{ + Port: 8090, + Tools: []string{}, + Timeout: 30 * time.Second, + }, + expectedTools: []string{"utils", "k8s", "prometheus", "helm", "istio", "argo", "cilium"}, + shouldFail: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := NewTestServer(tc.config) + err := server.Start(ctx, tc.config) + + if tc.shouldFail { + require.Error(t, err, "Server should fail to start with invalid configuration") + return + } + + require.NoError(t, err, "Server should start successfully") + defer func() { + if err := server.Stop(); err != nil { + t.Errorf("Failed to stop server: %v", err) + } + }() + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Verify registered tools + output := server.GetOutput() + + // Special handling for invalid tool test case + if tc.name == "Register invalid tool" { + assert.Contains(t, output, "Unknown tool specified", "Should 
warn about invalid tool") + assert.Contains(t, output, "invalid-tool", "Should mention the invalid tool name") + } else { + if tc.name == "Register all tools implicitly" { + // For implicit all tools registration, check for RegisterTools initialized + assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools") + // Don't check for individual tool names as they're not logged individually + assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with all tools") + } else { + // For specific tools, check for Running server message and tool names + assert.Contains(t, output, "Running KAgent Tools Server", "Should be running server") + for _, tool := range tc.expectedTools { + assert.Contains(t, output, tool, fmt.Sprintf("Should register %s tool", tool)) + } + } + } + + // Test health endpoint + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", tc.config.Port)) + require.NoError(t, err, "Health endpoint should be accessible") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + }) + } +} + +// TestToolExecutionFlow tests the complete flow of tool execution +func TestToolExecutionFlow(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8091, + Tools: []string{"utils"}, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + defer func() { + if err := server.Stop(); err != nil { + t.Errorf("Failed to stop server: %v", err) + } + }() + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Test health endpoint (MCP server doesn't have REST endpoints for tool execution) + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + require.NoError(t, err, "Should execute request successfully") + defer resp.Body.Close() + + // Check response + assert.Equal(t, http.StatusOK, resp.StatusCode, "Should return OK 
status") + + // Read response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Should read response body") + + // Response should contain "OK" + assert.Equal(t, "OK", string(body), "Should return OK response") +} + +// TestServerTelemetry tests that telemetry is properly initialized and working +func TestServerTelemetry(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8092, + Tools: []string{"utils"}, + Timeout: 30 * time.Second, + } + + // Set test environment variables for telemetry + os.Setenv("OTEL_SERVICE_NAME", "kagent-tools-test") + os.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317") + defer os.Unsetenv("OTEL_SERVICE_NAME") + defer os.Unsetenv("OTEL_EXPORTER_OTLP_ENDPOINT") + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + defer func() { + if err := server.Stop(); err != nil { + t.Errorf("Failed to stop server: %v", err) + } + }() + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for telemetry initialization + output := server.GetOutput() + assert.Contains(t, output, "Starting kagent-tools-server", "Server should start with telemetry") + + // Make a request to generate telemetry + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + require.NoError(t, err, "Health endpoint should be accessible") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + // Check server output for successful startup (telemetry is initialized internally) + output = server.GetOutput() + assert.Contains(t, output, "Running KAgent Tools Server", "Server should be running with telemetry enabled") +} + +// TestToolRegistrationWithInvalidNames tests server behavior with invalid tool names +func TestToolRegistrationWithInvalidNames(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8087, + Tools: []string{"invalid-tool", 
"not-exists", "k8s"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully despite invalid tools") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Check server output for warning messages about invalid tools + output := server.GetOutput() + assert.Contains(t, output, "Unknown tool specified") + assert.Contains(t, output, "invalid-tool") + assert.Contains(t, output, "not-exists") + + // Verify that valid tools were still registered + assert.Contains(t, output, "Running KAgent Tools Server") + assert.Contains(t, output, "k8s") + + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestConcurrentToolExecution tests concurrent tool execution +func TestConcurrentToolExecution(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8088, + Tools: []string{"utils", "k8s"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Create multiple concurrent requests + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + require.NoError(t, err, "Concurrent request %d should succeed", id) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + }(i) + } + + wg.Wait() + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestServerErrorHandling tests server's error handling capabilities +func TestServerErrorHandling(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8089, + Tools: []string{"utils"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server 
:= NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Test malformed request + req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/nonexistent", config.Port), strings.NewReader("invalid json")) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + resp.Body.Close() + + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestServerMetricsEndpoint tests the metrics endpoint functionality +func TestServerMetricsEndpoint(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8090, + Tools: []string{"utils"}, + Stdio: false, + Timeout: 30 * time.Second, + } + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Test metrics endpoint + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/metrics", config.Port)) + require.NoError(t, err, "Metrics endpoint should be accessible") + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Read and verify metrics content + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + metricsContent := string(body) + assert.Contains(t, metricsContent, "go_") + assert.Contains(t, metricsContent, "process_") + + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} + +// TestToolSpecificFunctionality tests specific functionality of registered tools +func TestToolSpecificFunctionality(t *testing.T) { + ctx := context.Background() + + config := TestServerConfig{ + Port: 8091, + Tools: []string{"utils", "k8s"}, + Stdio: false, + Timeout: 30 * 
time.Second, + } + + server := NewTestServer(config) + err := server.Start(ctx, config) + require.NoError(t, err, "Server should start successfully") + + // Wait for server to be ready + time.Sleep(3 * time.Second) + + // Test utils tool endpoint + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + // Verify response format matches expected OK response + assert.Equal(t, "OK", string(body), "Should return OK response") + + err = server.Stop() + require.NoError(t, err, "Server should stop gracefully") +} diff --git a/go.mod b/go.mod index 08a9e04..220ea7f 100644 --- a/go.mod +++ b/go.mod @@ -1,67 +1,47 @@ module github.com/kagent-dev/tools -go 1.24.4 +go 1.24.5 require ( github.com/go-logr/logr v1.4.3 github.com/go-logr/stdr v1.2.2 github.com/joho/godotenv v1.5.1 - github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3 github.com/mark3labs/mcp-go v0.32.0 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 github.com/tmc/langchaingo v0.1.13 go.opentelemetry.io/otel v1.37.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 go.opentelemetry.io/otel/metric v1.37.0 + go.opentelemetry.io/otel/sdk v1.37.0 + go.opentelemetry.io/otel/trace v1.37.0 ) require ( + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dlclark/regexp2 v1.10.0 // indirect - github.com/emicklei/go-restful/v3 v3.12.2 // indirect - github.com/fxamacker/cbor/v2 v2.8.0 // indirect - github.com/go-openapi/jsonpointer v0.21.1 // indirect - github.com/go-openapi/jsonreference v0.21.0 // indirect - github.com/go-openapi/swag v0.23.1 // indirect - github.com/gogo/protobuf 
v1.3.2 // indirect - github.com/google/gnostic-models v0.6.9 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/josharian/intern v1.0.0 // indirect - github.com/json-iterator/go v1.1.12 // indirect - github.com/mailru/easyjson v0.9.0 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/spf13/cast v1.9.2 // indirect github.com/spf13/pflag v1.0.6 // indirect - github.com/x448/float16 v0.8.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/otel/trace v1.37.0 // indirect - go.uber.org/automaxprocs v1.6.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 // indirect + go.opentelemetry.io/proto/otlp v1.7.0 // indirect golang.org/x/net v0.41.0 // indirect - golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/term v0.32.0 // indirect golang.org/x/text v0.26.0 // indirect - golang.org/x/time v0.12.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/grpc v1.73.0 // indirect google.golang.org/protobuf v1.36.6 // indirect - gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect - gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/api 
v0.33.2 // indirect - k8s.io/apimachinery v0.33.2 // indirect - k8s.io/client-go v0.33.2 // indirect - k8s.io/klog/v2 v2.130.1 // indirect - k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a // indirect - k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 // indirect - sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect - sigs.k8s.io/randfill v1.0.0 // indirect - sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect sigs.k8s.io/yaml v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index 15455ac..3a1cf49 100644 --- a/go.sum +++ b/go.sum @@ -1,77 +1,40 @@ +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= +github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/emicklei/go-restful/v3 v3.12.2 h1:DhwDP0vY3k8ZzE0RunuJy8GhNpPL6zqLkDf9B/a0/xU= -github.com/emicklei/go-restful/v3 v3.12.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 
-github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= -github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic= -github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= -github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= -github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= -github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= -github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= -github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= -github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/google/gnostic-models v0.6.9 h1:MU/8wDLif2qCXZmzncUQ/BOfxWfthHi63KqpoNbWqVw= -github.com/google/gnostic-models v0.6.9/go.mod h1:CiWsm0s6BSQd1hRn8/QmxqB6BesYcbSZxsz9b0KuDBw= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 
h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a h1://KbezygeMJZCSHH+HgUZiTeSoiuFspbMg1ge+eFj18= -github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3 h1:B5EkhSmYMG6bgn7DTsOfhal8sl1MmhjixSXP1PP/jNw= -github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3/go.mod h1:hwTH7K+UkePRxA6DhXOXavNyXRK3nPmvipA07DSRUxI= -github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.3.1 
h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= -github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= -github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= -github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= 
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -83,98 +46,54 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= -github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= -github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 
-github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 h1:Ahq7pZmv87yiyn3jeFz/LekZmPLLdKejuO3NcK9MssM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0/go.mod h1:MJTqhM0im3mRLw1i8uGHnCvUEeS7VwRyxlLC78PA18M= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 h1:EtFWSnwW9hGObjkIdmlnWSydO+Qs8OwzfzXLUPg4xOc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0/go.mod h1:QjUEoiGCPkvFZ/MjK6ZZfNOS6mfVEVKYE99dFhuN2LI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU= go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= go.opentelemetry.io/otel/trace v1.37.0 
h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= -go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= -go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +go.opentelemetry.io/proto/otlp v1.7.0 h1:jX1VolD6nHuFzOYso2E73H85i92Mv8JQYk0K9vz09os= +go.opentelemetry.io/proto/otlp v1.7.0/go.mod h1:fSKjH6YJ7HDlwzltzyMj036AJ3ejJLCgCSHGj4efDDo= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 
-golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= +google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4= -gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= -gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.33.2 h1:YgwIS5jKfA+BZg//OQhkJNIfie/kmRsO0BmNaVSimvY= -k8s.io/api v0.33.2/go.mod h1:fhrbphQJSM2cXzCWgqU29xLDuks4mu7ti9vveEnpSXs= -k8s.io/apimachinery v0.33.2 h1:IHFVhqg59mb8PJWTLi8m1mAoepkUNYmptHsV+Z1m5jY= -k8s.io/apimachinery v0.33.2/go.mod h1:BHW0YOu7n22fFv/JkYOEfkUYNRN0fj0BlvMFWA7b+SM= -k8s.io/client-go v0.33.2 h1:z8CIcc0P581x/J1ZYf4CNzRKxRvQAwoAolYPbtQes+E= -k8s.io/client-go v0.33.2/go.mod h1:9mCgT4wROvL948w6f6ArJNb7yQd7QsvqavDeZHvNmHo= -k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= -k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= -k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a h1:ZV3Zr+/7s7aVbjNGICQt+ppKWsF1tehxggNfbM7XnG8= -k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a/go.mod h1:5jIi+8yX4RIb8wk3XwBo5Pq2ccx4FP10ohkbSKCZoK8= -k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 h1:hwvWFiBzdWw1FhfY1FooPn3kzWuJ8tmbZBHi4zVsl1Y= -k8s.io/utils v0.0.0-20250604170112-4c0f3b243397/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= -sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7npUVDpVXUUOFfW6CgAqE= -sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= -sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= -sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= -sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= -sigs.k8s.io/structured-merge-diff/v4 v4.6.0 h1:IUA9nvMmnKWcj5jl84xn+T5MnlZKThmUW1TdblaLVAc= -sigs.k8s.io/structured-merge-diff/v4 v4.6.0/go.mod h1:dDy58f92j70zLsuZVuUX5Wp9vtxXpaZnkPGWeqDfCps= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 
0000000..4c2c105 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,545 @@ +package cache + +import ( + "context" + "fmt" + "sync" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/kagent-dev/tools/internal/logger" + "github.com/kagent-dev/tools/internal/telemetry" +) + +// CacheType represents the type of cache using enum pattern +type CacheType int + +const ( + CacheTypeKubernetes CacheType = iota + CacheTypeCommand + CacheTypeHelm + CacheTypeIstio +) + +// String returns the string representation of CacheType +func (ct CacheType) String() string { + switch ct { + case CacheTypeKubernetes: + return "kubernetes" + case CacheTypeCommand: + return "command" + case CacheTypeHelm: + return "helm" + case CacheTypeIstio: + return "istio" + default: + return "unknown" + } +} + +// Command to cache type mapping +var commandToCacheType = map[string]CacheType{ + "kubectl": CacheTypeKubernetes, + "helm": CacheTypeHelm, + "istioctl": CacheTypeIstio, + "cilium": CacheTypeCommand, // Use command cache for cilium + "argo": CacheTypeCommand, // Use command cache for argo +} + +// CacheEntry represents a cached item with TTL +type CacheEntry[T any] struct { + Value T + CreatedAt time.Time + ExpiresAt time.Time + AccessedAt time.Time + AccessCount int64 +} + +// IsExpired checks if the cache entry has expired +func (e *CacheEntry[T]) IsExpired() bool { + return time.Now().After(e.ExpiresAt) +} + +// Cache is a thread-safe cache with TTL support +type Cache[T any] struct { + mu sync.RWMutex + data map[string]*CacheEntry[T] + name string + defaultTTL time.Duration + maxSize int + cleanupInterval time.Duration + stopCleanup chan struct{} + + // Metrics + hits metric.Int64Counter + misses metric.Int64Counter + evictions metric.Int64Counter + size metric.Int64UpDownCounter +} + +// NewCache creates a new cache with specified configuration and name +func NewCache[T any](name string, defaultTTL 
time.Duration, maxSize int, cleanupInterval time.Duration) *Cache[T] { + meter := otel.Meter(fmt.Sprintf("kagent-tools/cache/%s", name)) + + // Create metrics with cache name as a label + hits, _ := meter.Int64Counter( + "cache_hits_total", + metric.WithDescription("Total number of cache hits"), + ) + + misses, _ := meter.Int64Counter( + "cache_misses_total", + metric.WithDescription("Total number of cache misses"), + ) + + evictions, _ := meter.Int64Counter( + "cache_evictions_total", + metric.WithDescription("Total number of cache evictions"), + ) + + size, _ := meter.Int64UpDownCounter( + "cache_size", + metric.WithDescription("Current number of items in cache"), + ) + + cache := &Cache[T]{ + data: make(map[string]*CacheEntry[T]), + name: name, + defaultTTL: defaultTTL, + maxSize: maxSize, + cleanupInterval: cleanupInterval, + stopCleanup: make(chan struct{}), + hits: hits, + misses: misses, + evictions: evictions, + size: size, + } + + // Start background cleanup + go cache.cleanupExpired() + + return cache +} + +// Get retrieves a value from the cache +func (c *Cache[T]) Get(key string) (T, bool) { + ctx := context.Background() + _, span := telemetry.StartSpan(ctx, "cache.get", + attribute.String("cache.name", c.name), + attribute.String("cache.key", key), + ) + defer span.End() + + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.data[key] + if !exists { + var zero T + c.recordMiss(key) + telemetry.AddEvent(span, "cache.miss", + attribute.String("cache.result", "miss"), + ) + span.SetAttributes(attribute.String("cache.result", "miss")) + return zero, false + } + + if entry.IsExpired() { + var zero T + c.recordMiss(key) + telemetry.AddEvent(span, "cache.miss", + attribute.String("cache.result", "miss"), + attribute.String("cache.miss_reason", "expired"), + ) + span.SetAttributes( + attribute.String("cache.result", "miss"), + attribute.String("cache.miss_reason", "expired"), + ) + return zero, false + } + + // Update access time and count + 
entry.AccessedAt = time.Now() + entry.AccessCount++ + + c.recordHit(key) + telemetry.AddEvent(span, "cache.hit", + attribute.String("cache.result", "hit"), + attribute.Int64("cache.access_count", entry.AccessCount), + ) + span.SetAttributes( + attribute.String("cache.result", "hit"), + attribute.Int64("cache.access_count", entry.AccessCount), + ) + + logger.Get().Debug("Cache hit", "key", key, "access_count", entry.AccessCount) + return entry.Value, true +} + +// Set stores a value in the cache with default TTL +func (c *Cache[T]) Set(key string, value T) { + c.SetWithTTL(key, value, c.defaultTTL) +} + +// SetWithTTL stores a value in the cache with specified TTL +func (c *Cache[T]) SetWithTTL(key string, value T, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + + // Check if we need to evict items to make room + if len(c.data) >= c.maxSize { + c.evictLRU() + } + + entry := &CacheEntry[T]{ + Value: value, + CreatedAt: now, + ExpiresAt: now.Add(ttl), + AccessedAt: now, + AccessCount: 1, + } + + // Check if key already exists + if _, exists := c.data[key]; !exists { + c.size.Add(context.Background(), 1) + } + + c.data[key] = entry + + logger.Get().Debug("Cache set", "key", key, "ttl", ttl) +} + +// Delete removes a value from the cache +func (c *Cache[T]) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if _, exists := c.data[key]; exists { + delete(c.data, key) + c.size.Add(context.Background(), -1) + logger.Get().Debug("Cache delete", "key", key) + } +} + +// Clear removes all items from the cache +func (c *Cache[T]) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + count := len(c.data) + c.data = make(map[string]*CacheEntry[T]) + c.size.Add(context.Background(), -int64(count)) + + logger.Get().Info("Cache cleared", "items_removed", count) +} + +// Size returns the current number of items in the cache +func (c *Cache[T]) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.data) +} + +// Name returns the name 
of the cache +func (c *Cache[T]) Name() string { + return c.name +} + +// Stats returns cache statistics +func (c *Cache[T]) Stats() CacheStats { + c.mu.RLock() + defer c.mu.RUnlock() + + stats := CacheStats{ + Size: len(c.data), + MaxSize: c.maxSize, + Expired: 0, + Oldest: time.Now(), + Newest: time.Time{}, + } + + for _, entry := range c.data { + if entry.IsExpired() { + stats.Expired++ + } + + if entry.CreatedAt.Before(stats.Oldest) { + stats.Oldest = entry.CreatedAt + } + + if entry.CreatedAt.After(stats.Newest) { + stats.Newest = entry.CreatedAt + } + } + + return stats +} + +// CacheStats represents cache statistics +type CacheStats struct { + Size int `json:"size"` + MaxSize int `json:"max_size"` + Expired int `json:"expired"` + Oldest time.Time `json:"oldest"` + Newest time.Time `json:"newest"` +} + +// cleanupExpired removes expired entries from the cache +func (c *Cache[T]) cleanupExpired() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.performCleanup() + case <-c.stopCleanup: + return + } + } +} + +// performCleanup removes expired entries +func (c *Cache[T]) performCleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + keysToDelete := make([]string, 0) + + for key, entry := range c.data { + if entry.IsExpired() { + keysToDelete = append(keysToDelete, key) + } + } + + if len(keysToDelete) > 0 { + for _, key := range keysToDelete { + delete(c.data, key) + c.evictions.Add(context.Background(), 1) + } + + c.size.Add(context.Background(), -int64(len(keysToDelete))) + logger.Get().Debug("Cache cleanup", "expired_items", len(keysToDelete)) + } +} + +// evictLRU removes the least recently used item +func (c *Cache[T]) evictLRU() { + var oldestKey string + var oldestTime time.Time = time.Now() + + for key, entry := range c.data { + if entry.AccessedAt.Before(oldestTime) { + oldestTime = entry.AccessedAt + oldestKey = key + } + } + + if oldestKey != "" { + delete(c.data, oldestKey) + 
c.evictions.Add(context.Background(), 1) + c.size.Add(context.Background(), -1) + logger.Get().Debug("Cache LRU eviction", "key", oldestKey) + } +} + +// recordHit records a cache hit +func (c *Cache[T]) recordHit(key string) { + c.hits.Add(context.Background(), 1, metric.WithAttributes( + attribute.String("cache.key", key), + attribute.String("cache.result", "hit"), + attribute.String("cache.name", c.name), + )) +} + +// recordMiss records a cache miss +func (c *Cache[T]) recordMiss(key string) { + c.misses.Add(context.Background(), 1, metric.WithAttributes( + attribute.String("cache.key", key), + attribute.String("cache.result", "miss"), + attribute.String("cache.name", c.name), + )) +} + +// Close stops the cache cleanup goroutine +func (c *Cache[T]) Close() { + close(c.stopCleanup) +} + +// InvalidateByType clears the entire cache for a specific cache type +func InvalidateByType(cacheType CacheType) { + ctx := context.Background() + _, span := telemetry.StartSpan(ctx, "cache.invalidate", + attribute.String("cache.type", cacheType.String()), + attribute.String("cache.operation", "invalidate"), + ) + defer span.End() + + InitCaches() + if cache, exists := cacheRegistry[cacheType]; exists { + oldSize := cache.Size() + cache.Clear() + + telemetry.AddEvent(span, "cache.invalidated", + attribute.String("cache.name", cache.name), + attribute.Int("cache.items_cleared", oldSize), + ) + span.SetAttributes( + attribute.String("cache.name", cache.name), + attribute.Int("cache.items_cleared", oldSize), + ) + telemetry.RecordSuccess(span, "Cache invalidated successfully") + + logger.Get().Info("Cache invalidated", "cache_type", cacheType.String(), "reason", "modification_command", "items_cleared", oldSize) + } else { + telemetry.RecordError(span, fmt.Errorf("cache type not found: %s", cacheType.String()), "Cache type not found") + } +} + +// InvalidateKubernetesCache clears the Kubernetes cache +func InvalidateKubernetesCache() { + InvalidateByType(CacheTypeKubernetes) +} + 
+// InvalidateHelmCache clears the Helm cache +func InvalidateHelmCache() { + InvalidateByType(CacheTypeHelm) +} + +// InvalidateIstioCache clears the Istio cache +func InvalidateIstioCache() { + InvalidateByType(CacheTypeIstio) +} + +// InvalidateCommandCache clears the Command cache +func InvalidateCommandCache() { + InvalidateByType(CacheTypeCommand) +} + +// InvalidateCacheForCommand invalidates the appropriate cache based on command type +func InvalidateCacheForCommand(command string) { + if cacheType, exists := commandToCacheType[command]; exists { + InvalidateByType(cacheType) + } else { + // Default to command cache for unknown commands + InvalidateCommandCache() + } +} + +// Global cache instances for different use cases +var ( + // cacheRegistry holds all cache instances by type + cacheRegistry = make(map[CacheType]*Cache[string]) + once sync.Once +) + +// InitCaches initializes all global cache instances +func InitCaches() { + once.Do(func() { + // Initialize caches with optimized TTL values based on use case + // Kubernetes: 45s - K8s resources change frequently, users expect fresh data + cacheRegistry[CacheTypeKubernetes] = NewCache[string](CacheTypeKubernetes.String(), 45*time.Second, 1000, 1*time.Minute) + + // Istio: 1m - Service mesh config more stable than pods, but proxy status can change + cacheRegistry[CacheTypeIstio] = NewCache[string](CacheTypeIstio.String(), 1*time.Minute, 500, 1*time.Minute) + + // Helm: 2m - Releases change less frequently, chart info is stable + cacheRegistry[CacheTypeHelm] = NewCache[string](CacheTypeHelm.String(), 2*time.Minute, 300, 2*time.Minute) + + // Command: 3m - General CLI commands have stable output, status commands don't change rapidly + cacheRegistry[CacheTypeCommand] = NewCache[string](CacheTypeCommand.String(), 3*time.Minute, 200, 1*time.Minute) + + logger.Get().Info("Caches initialized") + }) +} + +// GetCacheByType returns a cache instance by cache type +func GetCacheByType(cacheType CacheType) 
*Cache[string] { + InitCaches() + if cache, exists := cacheRegistry[cacheType]; exists { + return cache + } + // Fallback to command cache if type not found + return cacheRegistry[CacheTypeCommand] +} + +// GetCacheByCommand returns a cache instance based on the command name +func GetCacheByCommand(command string) *Cache[string] { + InitCaches() + if cacheType, exists := commandToCacheType[command]; exists { + return GetCacheByType(cacheType) + } + // Default to command cache for unknown commands + return GetCacheByType(CacheTypeCommand) +} + +// CacheKey generates a consistent cache key from components +func CacheKey(components ...string) string { + result := "" + for i, component := range components { + if i > 0 { + result += ":" + } + result += component + } + return result +} + +// CacheResult is a helper function to cache the result of a function +func CacheResult[T any](cache *Cache[T], key string, ttl time.Duration, fn func() (T, error)) (T, error) { + ctx := context.Background() + _, span := telemetry.StartSpan(ctx, "cache.result", + attribute.String("cache.name", cache.name), + attribute.String("cache.key", key), + attribute.String("cache.ttl", ttl.String()), + ) + defer span.End() + + var zero T + + // Try to get from cache first + if cachedResult, found := cache.Get(key); found { + telemetry.AddEvent(span, "cache.result.hit", + attribute.String("cache.operation", "get"), + attribute.String("cache.result", "hit"), + ) + span.SetAttributes( + attribute.String("cache.operation", "get"), + attribute.String("cache.result", "hit"), + ) + telemetry.RecordSuccess(span, "Cache hit - returning cached result") + return cachedResult, nil + } + + // Not in cache, execute function + telemetry.AddEvent(span, "cache.result.miss", + attribute.String("cache.operation", "compute"), + attribute.String("cache.result", "miss"), + ) + span.SetAttributes( + attribute.String("cache.operation", "compute"), + attribute.String("cache.result", "miss"), + ) + + result, err := fn() + 
if err != nil { + telemetry.RecordError(span, err, "Function execution failed") + return zero, err + } + + // Store in cache + cache.SetWithTTL(key, result, ttl) + + telemetry.AddEvent(span, "cache.result.stored", + attribute.String("cache.operation", "set"), + ) + span.SetAttributes(attribute.String("cache.operation", "set")) + telemetry.RecordSuccess(span, "Function executed and result cached") + + return result, nil +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..cc7cf64 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,488 @@ +package cache + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewCache(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + + if cache.defaultTTL != 1*time.Minute { + t.Errorf("Expected default TTL of 1 minute, got %v", cache.defaultTTL) + } + + if cache.maxSize != 100 { + t.Errorf("Expected max size of 100, got %d", cache.maxSize) + } + + if cache.cleanupInterval != 10*time.Second { + t.Errorf("Expected cleanup interval of 10 seconds, got %v", cache.cleanupInterval) + } + + if cache.name != "test-cache" { + t.Errorf("Expected cache name 'test-cache', got %s", cache.name) + } + + cache.Close() +} + +func TestCacheName(t *testing.T) { + cache := NewCache[string]("my-test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + if cache.Name() != "my-test-cache" { + t.Errorf("Expected cache name 'my-test-cache', got %s", cache.Name()) + } +} + +func TestCacheSetAndGet(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + // Test set and get + cache.Set("key1", "value1") + value, found := cache.Get("key1") + + if !found { + t.Error("Expected to find key1") + } + + if value != "value1" { + t.Errorf("Expected value1, got %v", value) + } +} + +func TestCacheSetWithTTL(t *testing.T) { + cache := 
NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + // Test set with custom TTL + cache.SetWithTTL("key1", "value1", 100*time.Millisecond) + + // Should be found immediately + value, found := cache.Get("key1") + if !found { + t.Error("Expected to find key1") + } + if value != "value1" { + t.Errorf("Expected value1, got %v", value) + } + + // Wait for expiration + time.Sleep(150 * time.Millisecond) + + // Should not be found after expiration + _, found = cache.Get("key1") + if found { + t.Error("Expected key1 to be expired") + } +} + +func TestCacheDelete(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + cache.Set("key1", "value1") + cache.Delete("key1") + + _, found := cache.Get("key1") + if found { + t.Error("Expected key1 to be deleted") + } +} + +func TestCacheClear(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + cache.Set("key1", "value1") + cache.Set("key2", "value2") + + if cache.Size() != 2 { + t.Errorf("Expected size 2, got %d", cache.Size()) + } + + cache.Clear() + + if cache.Size() != 0 { + t.Errorf("Expected size 0 after clear, got %d", cache.Size()) + } +} + +func TestCacheEviction(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 2, 10*time.Second) // Small cache + defer cache.Close() + + // Fill cache to capacity + cache.Set("key1", "value1") + cache.Set("key2", "value2") + + // Add one more item - should evict LRU + cache.Set("key3", "value3") + + // key1 should be evicted (oldest) + _, found := cache.Get("key1") + if found { + t.Error("Expected key1 to be evicted") + } + + // key2 and key3 should still be there + _, found = cache.Get("key2") + if !found { + t.Error("Expected key2 to be present") + } + + _, found = cache.Get("key3") + if !found { + t.Error("Expected key3 to be present") + } +} + +func TestCacheExpiration(t *testing.T) { + cache := 
NewCache[string]("test-cache", 1*time.Minute, 100, 50*time.Millisecond) // Fast cleanup + defer cache.Close() + + // Set item with short TTL + cache.SetWithTTL("key1", "value1", 100*time.Millisecond) + + // Wait for cleanup to run + time.Sleep(200 * time.Millisecond) + + // Item should be cleaned up + _, found := cache.Get("key1") + if found { + t.Error("Expected key1 to be cleaned up") + } +} + +func TestCacheStats(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + cache.Set("key1", "value1") + cache.Set("key2", "value2") + + stats := cache.Stats() + + if stats.Size != 2 { + t.Errorf("Expected stats size 2, got %d", stats.Size) + } + + if stats.MaxSize != 100 { + t.Errorf("Expected stats max size 100, got %d", stats.MaxSize) + } + + if stats.Expired != 0 { + t.Errorf("Expected 0 expired items, got %d", stats.Expired) + } +} + +func TestCacheKey(t *testing.T) { + tests := []struct { + name string + components []string + expected string + }{ + {"single component", []string{"key1"}, "key1"}, + {"multiple components", []string{"key1", "key2", "key3"}, "key1:key2:key3"}, + {"empty components", []string{}, ""}, + {"empty string component", []string{"key1", "", "key3"}, "key1::key3"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CacheKey(tt.components...) 
+ if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestCacheResult(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + callCount := 0 + testFunction := func() (string, error) { + callCount++ + return "result", nil + } + + // First call should execute function + result, err := CacheResult(cache, "test-key", 1*time.Minute, testFunction) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result != "result" { + t.Errorf("Expected 'result', got %q", result) + } + if callCount != 1 { + t.Errorf("Expected function to be called once, got %d", callCount) + } + + // Second call should use cache + result, err = CacheResult(cache, "test-key", 1*time.Minute, testFunction) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result != "result" { + t.Errorf("Expected 'result', got %q", result) + } + if callCount != 1 { + t.Errorf("Expected function to be called once (cached), got %d", callCount) + } +} + +func TestCacheResultWithError(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + testFunction := func() (string, error) { + return "", &testError{message: "test error"} + } + + result, err := CacheResult(cache, "test-key", 1*time.Minute, testFunction) + if err == nil { + t.Error("Expected error") + } + if result != "" { + t.Errorf("Expected empty result, got %q", result) + } + + // Check that error result is not cached + _, found := cache.Get("test-key") + if found { + t.Error("Expected error result not to be cached") + } +} + +func TestCacheInitialization(t *testing.T) { + // Test that all cache types are properly initialized + types := []CacheType{ + CacheTypeKubernetes, + CacheTypeCommand, + CacheTypeHelm, + CacheTypeIstio, + } + + for _, cacheType := range types { + t.Run(cacheType.String(), func(t *testing.T) { + cache := GetCacheByType(cacheType) + if 
cache == nil { + t.Errorf("Expected cache for type %s to be initialized", cacheType.String()) + } + if cache.Name() != cacheType.String() { + t.Errorf("Expected cache name %s, got %s", cacheType.String(), cache.Name()) + } + }) + } +} + +func TestCacheEntry(t *testing.T) { + now := time.Now() + entry := &CacheEntry[string]{ + Value: "test", + CreatedAt: now, + ExpiresAt: now.Add(1 * time.Minute), + AccessedAt: now, + AccessCount: 1, + } + + // Should not be expired + if entry.IsExpired() { + t.Error("Expected entry not to be expired") + } + + // Make it expired + entry.ExpiresAt = now.Add(-1 * time.Minute) + + // Should be expired + if !entry.IsExpired() { + t.Error("Expected entry to be expired") + } +} + +func TestCachePerformCleanup(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) + defer cache.Close() + + // Add expired item + cache.SetWithTTL("expired", "value", -1*time.Minute) + + // Add valid item + cache.Set("valid", "value") + + // Perform cleanup + cache.performCleanup() + + // Expired item should be removed + _, found := cache.Get("expired") + if found { + t.Error("Expected expired item to be removed") + } + + // Valid item should remain + _, found = cache.Get("valid") + if !found { + t.Error("Expected valid item to remain") + } +} + +func TestCacheConcurrency(t *testing.T) { + cache := NewCache[string]("test-cache", 1*time.Minute, 1000, 10*time.Second) + defer cache.Close() + + // Test concurrent operations + done := make(chan bool) + + // Writer goroutine + go func() { + for i := 0; i < 100; i++ { + cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i)) + } + done <- true + }() + + // Reader goroutine + go func() { + for i := 0; i < 100; i++ { + cache.Get(fmt.Sprintf("key%d", i)) + } + done <- true + }() + + // Wait for both goroutines + <-done + <-done + + // Cache should have items + if cache.Size() == 0 { + t.Error("Expected cache to have items") + } +} + +// Helper types for testing +type 
testError struct { + message string +} + +func (e *testError) Error() string { + return e.message +} + +func TestCacheTypeString(t *testing.T) { + tests := []struct { + cacheType CacheType + expected string + }{ + {CacheTypeKubernetes, "kubernetes"}, + {CacheTypeCommand, "command"}, + {CacheTypeHelm, "helm"}, + {CacheTypeIstio, "istio"}, + {CacheType(999), "unknown"}, // Test unknown type + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := tt.cacheType.String() + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestGetCacheByType(t *testing.T) { + // Test all valid cache types + types := []CacheType{ + CacheTypeKubernetes, + CacheTypeCommand, + CacheTypeHelm, + CacheTypeIstio, + } + + for _, cacheType := range types { + t.Run(cacheType.String(), func(t *testing.T) { + cache := GetCacheByType(cacheType) + if cache == nil { + t.Errorf("Expected cache for type %s, got nil", cacheType.String()) + } + if cache.Name() != cacheType.String() { + t.Errorf("Expected cache name %s, got %s", cacheType.String(), cache.Name()) + } + }) + } +} + +func TestGetCacheByCommand(t *testing.T) { + tests := []struct { + command string + expectedType CacheType + }{ + {"kubectl", CacheTypeKubernetes}, + {"helm", CacheTypeHelm}, + {"istioctl", CacheTypeIstio}, + {"cilium", CacheTypeCommand}, + {"argo", CacheTypeCommand}, + {"unknown-command", CacheTypeCommand}, // Should default to command cache + } + + for _, tt := range tests { + t.Run(tt.command, func(t *testing.T) { + cache := GetCacheByCommand(tt.command) + if cache == nil { + t.Errorf("Expected cache for command %s, got nil", tt.command) + } + if cache.Name() != tt.expectedType.String() { + t.Errorf("Expected cache name %s for command %s, got %s", + tt.expectedType.String(), tt.command, cache.Name()) + } + }) + } +} + +func TestCacheOTelTracing(t *testing.T) { + // This test verifies that OTEL tracing calls don't panic + // The actual tracing 
verification would require setting up an OTEL test environment + cache := NewCache[string]("test-tracing", 1*time.Minute, 10, 5*time.Minute) + defer cache.Close() + + // Test cache miss with tracing + _, found := cache.Get("missing-key") + assert.False(t, found) + + // Test cache hit with tracing + cache.Set("test-key", "test-value") + value, found := cache.Get("test-key") + assert.True(t, found) + assert.Equal(t, "test-value", value) + + // Test CacheResult with tracing + callCount := 0 + result, err := CacheResult(cache, "result-key", 1*time.Minute, func() (string, error) { + callCount++ + return "computed-value", nil + }) + assert.NoError(t, err) + assert.Equal(t, "computed-value", result) + assert.Equal(t, 1, callCount) + + // Test cache hit on second call + result2, err := CacheResult(cache, "result-key", 1*time.Minute, func() (string, error) { + callCount++ + return "computed-value", nil + }) + assert.NoError(t, err) + assert.Equal(t, "computed-value", result2) + assert.Equal(t, 1, callCount) // Should not increment due to cache hit + + // Test cache invalidation with tracing + oldSize := cache.Size() + InvalidateByType(CacheTypeCommand) + assert.True(t, oldSize > 0) // Verify we had items to clear +} diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go new file mode 100644 index 0000000..3061006 --- /dev/null +++ b/internal/cmd/cmd.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "context" + "os/exec" + "time" + + "github.com/kagent-dev/tools/internal/logger" +) + +// ShellExecutor defines the interface for executing shell commands +type ShellExecutor interface { + Exec(ctx context.Context, command string, args ...string) (output []byte, err error) +} + +// DefaultShellExecutor implements ShellExecutor using os/exec +type DefaultShellExecutor struct{} + +// Exec executes a command using os/exec.CommandContext +func (e *DefaultShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) { + log := logger.WithContext(ctx) + 
startTime := time.Now() + + log.Info("executing command", + "command", command, + "args", args, + ) + + cmd := exec.CommandContext(ctx, command, args...) + output, err := cmd.CombinedOutput() + + duration := time.Since(startTime) + + if err != nil { + log.Error("command execution failed", + "command", command, + "args", args, + "error", err, + "output", string(output), + "duration", duration.Seconds(), + ) + } else { + log.Info("command execution successful", + "command", command, + "args", args, + "duration", duration.Seconds(), + ) + } + + return output, err +} + +// Context key for shell executor injection +type contextKey string + +const shellExecutorKey contextKey = "shellExecutor" + +// WithShellExecutor returns a context with the given shell executor +func WithShellExecutor(ctx context.Context, executor ShellExecutor) context.Context { + return context.WithValue(ctx, shellExecutorKey, executor) +} + +// GetShellExecutor retrieves the shell executor from context, or returns default +func GetShellExecutor(ctx context.Context) ShellExecutor { + if executor, ok := ctx.Value(shellExecutorKey).(ShellExecutor); ok { + return executor + } + return &DefaultShellExecutor{} +} diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go new file mode 100644 index 0000000..f902d4c --- /dev/null +++ b/internal/cmd/cmd_test.go @@ -0,0 +1,58 @@ +package cmd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultShellExecutor(t *testing.T) { + executor := &DefaultShellExecutor{} + + // Test successful command + output, err := executor.Exec(context.Background(), "echo", "hello") + assert.NoError(t, err) + assert.Equal(t, "hello\n", string(output)) + + // Test command with error + _, err = executor.Exec(context.Background(), "nonexistent-command") + assert.Error(t, err) +} + +func TestMockShellExecutor(t *testing.T) { + mock := NewMockShellExecutor() + + t.Run("unmocked command returns error", func(t *testing.T) { + _, err := 
mock.Exec(context.Background(), "unmocked", "command") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no mock found for command") + }) + + t.Run("mocked command returns expected result", func(t *testing.T) { + expectedOutput := "mocked output" + mock.AddCommandString("kubectl", []string{"get", "pods"}, expectedOutput, nil) + + output, err := mock.Exec(context.Background(), "kubectl", "get", "pods") + assert.NoError(t, err) + assert.Equal(t, expectedOutput, string(output)) + }) +} + +func TestContextShellExecutor(t *testing.T) { + t.Run("default executor when no context value", func(t *testing.T) { + ctx := context.Background() + executor := GetShellExecutor(ctx) + + _, ok := executor.(*DefaultShellExecutor) + assert.True(t, ok, "should return DefaultShellExecutor when no context value") + }) + + t.Run("mock executor from context", func(t *testing.T) { + mock := NewMockShellExecutor() + ctx := WithShellExecutor(context.Background(), mock) + + executor := GetShellExecutor(ctx) + assert.Equal(t, mock, executor, "should return the mock executor from context") + }) +} diff --git a/internal/cmd/mock.go b/internal/cmd/mock.go new file mode 100644 index 0000000..3f13c47 --- /dev/null +++ b/internal/cmd/mock.go @@ -0,0 +1,120 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + "sync" +) + +// MockCall represents a recorded command execution for testing +type MockCall struct { + Command string + Args []string +} + +// MockShellExecutor is a mock implementation of ShellExecutor for testing +type MockShellExecutor struct { + mu sync.Mutex + callLog []MockCall + commandMocks map[string]map[string]struct { + output string + err error + } + partialMatchers []struct { + command string + args []string + output string + err error + } +} + +// NewMockShellExecutor creates a new mock shell executor +func NewMockShellExecutor() *MockShellExecutor { + return &MockShellExecutor{ + commandMocks: make(map[string]map[string]struct { + output string + err error + }), + } 
+} + +// AddCommandString mocks a command with specific arguments and a string output +func (m *MockShellExecutor) AddCommandString(command string, args []string, output string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + argsKey := strings.Join(args, " ") + if _, ok := m.commandMocks[command]; !ok { + m.commandMocks[command] = make(map[string]struct { + output string + err error + }) + } + m.commandMocks[command][argsKey] = struct { + output string + err error + }{output, err} +} + +// AddPartialMatcherString mocks a command with partial argument matching +func (m *MockShellExecutor) AddPartialMatcherString(command string, args []string, output string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.partialMatchers = append(m.partialMatchers, struct { + command string + args []string + output string + err error + }{command, args, output, err}) +} + +// Exec records the call and returns a mocked output or error +func (m *MockShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.callLog = append(m.callLog, MockCall{Command: command, Args: args}) + + // Check for exact match first + argsKey := strings.Join(args, " ") + if mocks, ok := m.commandMocks[command]; ok { + if mock, ok := mocks[argsKey]; ok { + return []byte(mock.output), mock.err + } + } + + // Check for partial match + for _, matcher := range m.partialMatchers { + if matcher.command == command && argsContain(args, matcher.args) { + return []byte(matcher.output), matcher.err + } + } + + return nil, fmt.Errorf("no mock found for command: %s %v", command, args) +} + +// GetCallLog returns the history of commands executed +func (m *MockShellExecutor) GetCallLog() []MockCall { + m.mu.Lock() + defer m.mu.Unlock() + return m.callLog +} + +// argsContain checks if all elements of subset are in set +func argsContain(set, subset []string) bool { + for _, sub := range subset { + found := false + for _, s := range set { + if 
strings.Contains(s, sub) { + found = true + break + } + } + if !found { + return false + } + } + return true +} diff --git a/internal/commands/builder.go b/internal/commands/builder.go new file mode 100644 index 0000000..a6bd8e6 --- /dev/null +++ b/internal/commands/builder.go @@ -0,0 +1,747 @@ +package commands + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/kagent-dev/tools/internal/cache" + "github.com/kagent-dev/tools/internal/cmd" + "github.com/kagent-dev/tools/internal/errors" + "github.com/kagent-dev/tools/internal/logger" + "github.com/kagent-dev/tools/internal/security" + "github.com/kagent-dev/tools/internal/telemetry" + "go.opentelemetry.io/otel/attribute" +) + +// CommandBuilder provides a fluent interface for building CLI commands +type CommandBuilder struct { + command string + args []string + namespace string + context string + kubeconfig string + output string + labels map[string]string + annotations map[string]string + timeout time.Duration + dryRun bool + force bool + wait bool + validate bool + cached bool + cacheTTL time.Duration + cacheKey string +} + +// NewCommandBuilder creates a new command builder +func NewCommandBuilder(command string) *CommandBuilder { + return &CommandBuilder{ + command: command, + args: make([]string, 0), + labels: make(map[string]string), + annotations: make(map[string]string), + timeout: 30 * time.Second, + validate: true, + cacheTTL: 5 * time.Minute, + } +} + +// KubectlBuilder creates a kubectl command builder +func KubectlBuilder() *CommandBuilder { + return NewCommandBuilder("kubectl") +} + +// HelmBuilder creates a helm command builder +func HelmBuilder() *CommandBuilder { + return NewCommandBuilder("helm") +} + +// IstioCtlBuilder creates an istioctl command builder +func IstioCtlBuilder() *CommandBuilder { + return NewCommandBuilder("istioctl") +} + +// CiliumBuilder creates a cilium command builder +func CiliumBuilder() *CommandBuilder { + return NewCommandBuilder("cilium") +} + +// 
ArgoRolloutsBuilder creates an argo rollouts command builder +func ArgoRolloutsBuilder() *CommandBuilder { + return NewCommandBuilder("kubectl").WithArgs("argo", "rollouts") +} + +// WithArgs adds arguments to the command +func (cb *CommandBuilder) WithArgs(args ...string) *CommandBuilder { + cb.args = append(cb.args, args...) + return cb +} + +// WithNamespace sets the namespace +func (cb *CommandBuilder) WithNamespace(namespace string) *CommandBuilder { + if err := security.ValidateNamespace(namespace); err != nil { + logger.Get().Error("Invalid namespace", "namespace", namespace, "error", err) + return cb + } + cb.namespace = namespace + return cb +} + +// WithContext sets the Kubernetes context +func (cb *CommandBuilder) WithContext(context string) *CommandBuilder { + if err := security.ValidateCommandInput(context); err != nil { + logger.Get().Error("Invalid context", "context", context, "error", err) + return cb + } + cb.context = context + return cb +} + +// WithKubeconfig sets the kubeconfig file +func (cb *CommandBuilder) WithKubeconfig(kubeconfig string) *CommandBuilder { + if err := security.ValidateFilePath(kubeconfig); err != nil { + logger.Get().Error("Invalid kubeconfig path", "kubeconfig", kubeconfig, "error", err) + return cb + } + cb.kubeconfig = kubeconfig + return cb +} + +// WithOutput sets the output format +func (cb *CommandBuilder) WithOutput(output string) *CommandBuilder { + validOutputs := []string{"json", "yaml", "wide", "name", "custom-columns", "custom-columns-file", "go-template", "go-template-file", "jsonpath", "jsonpath-file"} + + valid := false + for _, validOutput := range validOutputs { + if output == validOutput { + valid = true + break + } + } + + if !valid { + logger.Get().Error("Invalid output format", "output", output) + return cb + } + + cb.output = output + return cb +} + +// WithLabel adds a label selector +func (cb *CommandBuilder) WithLabel(key, value string) *CommandBuilder { + if err := security.ValidateK8sLabel(key, 
value); err != nil { + logger.Get().Error("Invalid label", "key", key, "value", value, "error", err) + return cb + } + cb.labels[key] = value + return cb +} + +// WithLabels adds multiple label selectors +func (cb *CommandBuilder) WithLabels(labels map[string]string) *CommandBuilder { + for key, value := range labels { + cb.WithLabel(key, value) + } + return cb +} + +// WithAnnotation adds an annotation +func (cb *CommandBuilder) WithAnnotation(key, value string) *CommandBuilder { + if err := security.ValidateK8sLabel(key, value); err != nil { + logger.Get().Error("Invalid annotation", "key", key, "value", value, "error", err) + return cb + } + cb.annotations[key] = value + return cb +} + +// WithTimeout sets the command timeout +func (cb *CommandBuilder) WithTimeout(timeout time.Duration) *CommandBuilder { + cb.timeout = timeout + return cb +} + +// WithDryRun enables dry run mode +func (cb *CommandBuilder) WithDryRun(dryRun bool) *CommandBuilder { + cb.dryRun = dryRun + return cb +} + +// WithForce enables force mode +func (cb *CommandBuilder) WithForce(force bool) *CommandBuilder { + cb.force = force + return cb +} + +// WithWait enables wait mode +func (cb *CommandBuilder) WithWait(wait bool) *CommandBuilder { + cb.wait = wait + return cb +} + +// WithValidation enables/disables validation +func (cb *CommandBuilder) WithValidation(validate bool) *CommandBuilder { + cb.validate = validate + return cb +} + +// WithCache enables caching of the command result +func (cb *CommandBuilder) WithCache(cached bool) *CommandBuilder { + cb.cached = cached + return cb +} + +// WithCacheTTL sets the cache TTL +func (cb *CommandBuilder) WithCacheTTL(ttl time.Duration) *CommandBuilder { + cb.cacheTTL = ttl + return cb +} + +// WithCacheKey sets a custom cache key +func (cb *CommandBuilder) WithCacheKey(key string) *CommandBuilder { + cb.cacheKey = key + return cb +} + +// Build constructs the final command arguments +func (cb *CommandBuilder) Build() (string, []string, error) { 
+ args := make([]string, 0, len(cb.args)+20) + + // Add main arguments + args = append(args, cb.args...) + + // Add namespace if specified + if cb.namespace != "" { + args = append(args, "--namespace", cb.namespace) + } + + // Add context if specified + if cb.context != "" { + args = append(args, "--context", cb.context) + } + + // Add kubeconfig if specified + if cb.kubeconfig != "" { + args = append(args, "--kubeconfig", cb.kubeconfig) + } + + // Add output format + if cb.output != "" { + args = append(args, "--output", cb.output) + } + + // Add label selectors + if len(cb.labels) > 0 { + var labelSelectors []string + for key, value := range cb.labels { + if value != "" { + labelSelectors = append(labelSelectors, fmt.Sprintf("%s=%s", key, value)) + } else { + labelSelectors = append(labelSelectors, key) + } + } + if len(labelSelectors) > 0 { + args = append(args, "--selector", strings.Join(labelSelectors, ",")) + } + } + + // Add timeout only for commands that support it + if cb.timeout > 0 { + if cb.supportsTimeout() { + args = append(args, "--timeout", cb.timeout.String()) + } + } + + // Add dry run + if cb.dryRun { + args = append(args, "--dry-run=client") + } + + // Add force + if cb.force { + args = append(args, "--force") + } + + // Add wait + if cb.wait { + args = append(args, "--wait") + } + + // Add validation + if !cb.validate { + args = append(args, "--validate=false") + } + + return cb.command, args, nil +} + +// supportsTimeout checks if the command supports the --timeout flag +func (cb *CommandBuilder) supportsTimeout() bool { + // For kubectl, many commands support --timeout + if cb.command == "kubectl" { + if len(cb.args) == 0 { + return false + } + + // Check the first argument (subcommand) + subcommand := cb.args[0] + switch subcommand { + case "wait": + return true + case "delete": + // kubectl delete supports --timeout when waiting for deletion + return true + case "rollout": + // kubectl rollout status supports --timeout + if len(cb.args) > 1 
&& cb.args[1] == "status" { + return true + } + return false + case "apply": + // kubectl apply supports --timeout when used with --wait + return cb.wait + case "annotate", "label": + // kubectl annotate and label support --timeout + return true + case "create": + // kubectl create supports --timeout + return true + case "argo": + // kubectl argo rollouts commands support --timeout + if len(cb.args) > 1 && cb.args[1] == "rollouts" { + return true + } + return false + case "get": + // kubectl get supports --timeout for some operations + return false // Most get operations don't need timeout, they're read-only + default: + return false + } + } + + // For other commands (helm, istioctl, cilium), assume they support timeout + // unless we find specific cases where they don't + return true +} + +// Execute runs the command +func (cb *CommandBuilder) Execute(ctx context.Context) (string, error) { + log := logger.WithContext(ctx) + _, span := telemetry.StartSpan(ctx, "commands.execute", + attribute.String("command", cb.command), + attribute.StringSlice("args", cb.args), + attribute.Bool("cached", cb.cached), + ) + defer span.End() + + command, args, err := cb.Build() + if err != nil { + telemetry.RecordError(span, err, "Command build failed") + log.Error("failed to build command", + "command", cb.command, + "error", err, + ) + return "", err + } + + span.SetAttributes( + attribute.String("built_command", command), + attribute.StringSlice("built_args", args), + ) + + log.Debug("executing command", + "command", command, + "args", args, + "cached", cb.cached, + ) + + // Generate cache key if caching is enabled + if cb.cached { + telemetry.AddEvent(span, "execution.cached") + return cb.executeWithCache(ctx, command, args) + } + + // Execute the command + telemetry.AddEvent(span, "execution.direct") + result, err := cb.executeCommand(ctx, command, args) + if err != nil { + telemetry.RecordError(span, err, "Command execution failed") + return "", err + } + + 
telemetry.RecordSuccess(span, "Command executed successfully") + span.SetAttributes( + attribute.Int("result_length", len(result)), + ) + + return result, nil +} + +func (cb *CommandBuilder) executeWithCache(ctx context.Context, command string, args []string) (string, error) { + log := logger.WithContext(ctx) + _, span := telemetry.StartSpan(ctx, "commands.executeWithCache", + attribute.String("command", command), + attribute.StringSlice("args", args), + attribute.Bool("cached", true), + ) + defer span.End() + + cacheKey := cb.cacheKey + if cacheKey == "" { + cacheKey = cache.CacheKey(append([]string{command}, args...)...) + } + + log.Info("executing cached command", + "command", command, + "args", args, + "cache_key", cacheKey, + "cache_ttl", cb.cacheTTL.String(), + ) + + // Try to get from cache first + cacheInstance := cache.GetCacheByCommand(command) + + telemetry.AddEvent(span, "cache.lookup", + attribute.String("cache_key", cacheKey), + attribute.String("cache_ttl", cb.cacheTTL.String()), + ) + + result, err := cache.CacheResult(cacheInstance, cacheKey, cb.cacheTTL, func() (string, error) { + telemetry.AddEvent(span, "cache.miss.executing_command") + log.Debug("cache miss, executing command", + "command", command, + "args", args, + ) + return cb.executeCommand(ctx, command, args) + }) + + if err != nil { + telemetry.RecordError(span, err, "Cached command execution failed") + log.Error("cached command execution failed", + "command", command, + "args", args, + "cache_key", cacheKey, + "error", err, + ) + return "", err + } + + telemetry.RecordSuccess(span, "Cached command executed successfully") + log.Info("cached command execution successful", + "command", command, + "args", args, + "cache_key", cacheKey, + "result_length", len(result), + ) + + span.SetAttributes( + attribute.String("cache_key", cacheKey), + attribute.Int("result_length", len(result)), + ) + + return result, nil +} + +// executeCommand executes the actual command +func (cb *CommandBuilder) 
executeCommand(ctx context.Context, command string, args []string) (string, error) { + executor := cmd.GetShellExecutor(ctx) + output, err := executor.Exec(ctx, command, args...) + if err != nil { + // Create appropriate error based on command type + var toolError *errors.ToolError + switch command { + case "kubectl": + toolError = errors.NewKubernetesError(strings.Join(args, " "), err) + case "helm": + toolError = errors.NewHelmError(strings.Join(args, " "), err) + case "istioctl": + toolError = errors.NewIstioError(strings.Join(args, " "), err) + case "cilium": + toolError = errors.NewCiliumError(strings.Join(args, " "), err) + default: + toolError = errors.NewCommandError(command, err) + } + + return "", toolError + } + + return string(output), nil +} + +// Common command patterns as helper functions + +// GetPods creates a command to get pods +func GetPods(namespace string, labels map[string]string) *CommandBuilder { + builder := KubectlBuilder().WithArgs("get", "pods") + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if len(labels) > 0 { + builder = builder.WithLabels(labels) + } + + return builder.WithCache(true).WithOutput("json") +} + +// GetServices creates a command to get services +func GetServices(namespace string, labels map[string]string) *CommandBuilder { + builder := KubectlBuilder().WithArgs("get", "services") + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if len(labels) > 0 { + builder = builder.WithLabels(labels) + } + + return builder.WithCache(true).WithOutput("json") +} + +// GetDeployments creates a command to get deployments +func GetDeployments(namespace string, labels map[string]string) *CommandBuilder { + builder := KubectlBuilder().WithArgs("get", "deployments") + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if len(labels) > 0 { + builder = builder.WithLabels(labels) + } + + return builder.WithCache(true).WithOutput("json") +} + +// 
DescribeResource creates a command to describe a resource +func DescribeResource(resourceType, resourceName, namespace string) *CommandBuilder { + builder := KubectlBuilder().WithArgs("describe", resourceType, resourceName) + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + return builder.WithCache(true).WithCacheTTL(2 * time.Minute) +} + +// GetLogs creates a command to get logs +func GetLogs(podName, namespace string, options LogOptions) *CommandBuilder { + builder := KubectlBuilder().WithArgs("logs", podName) + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if options.Container != "" { + builder = builder.WithArgs("--container", options.Container) + } + + if options.Follow { + builder = builder.WithArgs("--follow") + } + + if options.Previous { + builder = builder.WithArgs("--previous") + } + + if options.Timestamps { + builder = builder.WithArgs("--timestamps") + } + + if options.TailLines > 0 { + builder = builder.WithArgs("--tail", fmt.Sprintf("%d", options.TailLines)) + } + + if options.SinceTime != "" { + builder = builder.WithArgs("--since-time", options.SinceTime) + } + + if options.SinceDuration != "" { + builder = builder.WithArgs("--since", options.SinceDuration) + } + + // Don't cache logs by default as they change frequently + return builder.WithCache(false) +} + +// LogOptions represents options for log commands +type LogOptions struct { + Container string + Follow bool + Previous bool + Timestamps bool + TailLines int + SinceTime string + SinceDuration string +} + +// ApplyResource creates a command to apply a resource +func ApplyResource(filename string, namespace string, options ApplyOptions) *CommandBuilder { + builder := KubectlBuilder().WithArgs("apply", "-f", filename) + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if options.DryRun { + builder = builder.WithDryRun(true) + } + + if options.Force { + builder = builder.WithForce(true) + } + + if options.Wait 
{ + builder = builder.WithWait(true) + } + + if !options.Validate { + builder = builder.WithValidation(false) + } + + return builder.WithCache(false) // Don't cache apply operations +} + +// ApplyOptions represents options for apply commands +type ApplyOptions struct { + DryRun bool + Force bool + Wait bool + Validate bool +} + +// DeleteResource creates a command to delete a resource +func DeleteResource(resourceType, resourceName, namespace string, options DeleteOptions) *CommandBuilder { + builder := KubectlBuilder().WithArgs("delete", resourceType, resourceName) + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if options.Force { + builder = builder.WithForce(true) + } + + if options.GracePeriod >= 0 { + builder = builder.WithArgs("--grace-period", fmt.Sprintf("%d", options.GracePeriod)) + } + + if options.Wait { + builder = builder.WithWait(true) + } + + return builder.WithCache(false) // Don't cache delete operations +} + +// DeleteOptions represents options for delete commands +type DeleteOptions struct { + Force bool + GracePeriod int + Wait bool +} + +// HelmInstall creates a command to install a Helm chart +func HelmInstall(releaseName, chart, namespace string, options HelmInstallOptions) *CommandBuilder { + builder := HelmBuilder().WithArgs("install", releaseName, chart) + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if options.CreateNamespace { + builder = builder.WithArgs("--create-namespace") + } + + if options.DryRun { + builder = builder.WithDryRun(true) + } + + if options.Wait { + builder = builder.WithWait(true) + } + + if options.ValuesFile != "" { + builder = builder.WithArgs("--values", options.ValuesFile) + } + + for key, value := range options.SetValues { + builder = builder.WithArgs("--set", fmt.Sprintf("%s=%s", key, value)) + } + + return builder.WithCache(false) // Don't cache install operations +} + +// HelmInstallOptions represents options for Helm install commands +type 
HelmInstallOptions struct { + CreateNamespace bool + DryRun bool + Wait bool + ValuesFile string + SetValues map[string]string +} + +// HelmList creates a command to list Helm releases +func HelmList(namespace string, options HelmListOptions) *CommandBuilder { + builder := HelmBuilder().WithArgs("list") + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if options.AllNamespaces { + builder = builder.WithArgs("--all-namespaces") + } + + if options.Output != "" { + builder = builder.WithOutput(options.Output) + } + + return builder.WithCache(true).WithCacheTTL(2 * time.Minute) +} + +// HelmListOptions represents options for Helm list commands +type HelmListOptions struct { + AllNamespaces bool + Output string +} + +// IstioProxyStatus creates a command to get Istio proxy status +func IstioProxyStatus(podName, namespace string) *CommandBuilder { + builder := IstioCtlBuilder().WithArgs("proxy-status") + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + if podName != "" { + builder = builder.WithArgs(podName) + } + + return builder.WithCache(true).WithCacheTTL(30 * time.Second) +} + +// CiliumStatus creates a command to get Cilium status +func CiliumStatus() *CommandBuilder { + return CiliumBuilder().WithArgs("status").WithCache(true).WithCacheTTL(30 * time.Second) +} + +// ArgoRolloutsGet creates a command to get Argo rollouts +func ArgoRolloutsGet(rolloutName, namespace string) *CommandBuilder { + builder := ArgoRolloutsBuilder().WithArgs("get", "rollout") + + if rolloutName != "" { + builder = builder.WithArgs(rolloutName) + } + + if namespace != "" { + builder = builder.WithNamespace(namespace) + } + + return builder.WithCache(true).WithCacheTTL(1 * time.Minute) +} diff --git a/internal/commands/builder_test.go b/internal/commands/builder_test.go new file mode 100644 index 0000000..e326a76 --- /dev/null +++ b/internal/commands/builder_test.go @@ -0,0 +1,585 @@ +package commands + +import ( + "testing" + "time" + + 
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Constructor defaults: 30s timeout, 5m cache TTL, validation on, all other
// fields zero-valued.
func TestNewCommandBuilder(t *testing.T) {
	cb := NewCommandBuilder("test-command")

	assert.Equal(t, "test-command", cb.command)
	assert.Empty(t, cb.args)
	assert.Empty(t, cb.namespace)
	assert.Empty(t, cb.context)
	assert.Empty(t, cb.kubeconfig)
	assert.Empty(t, cb.output)
	assert.NotNil(t, cb.labels)
	assert.NotNil(t, cb.annotations)
	assert.Equal(t, 30*time.Second, cb.timeout)
	assert.Equal(t, 5*time.Minute, cb.cacheTTL)
	assert.True(t, cb.validate)
	assert.False(t, cb.cached)
	assert.False(t, cb.dryRun)
	assert.False(t, cb.force)
	assert.False(t, cb.wait)
}

// Each factory pins its builder to the matching executable name.
func TestCommandBuilderFactories(t *testing.T) {
	tests := []struct {
		name     string
		factory  func() *CommandBuilder
		expected string
	}{
		{"kubectl", KubectlBuilder, "kubectl"},
		{"helm", HelmBuilder, "helm"},
		{"istioctl", IstioCtlBuilder, "istioctl"},
		{"cilium", CiliumBuilder, "cilium"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cb := tt.factory()
			assert.Equal(t, tt.expected, cb.command)
		})
	}
}

// Argo rollouts is kubectl plus the "argo rollouts" prefix args.
func TestArgoRolloutsBuilder(t *testing.T) {
	cb := ArgoRolloutsBuilder()

	assert.Equal(t, "kubectl", cb.command)
	assert.Equal(t, []string{"argo", "rollouts"}, cb.args)
}

func TestCommandBuilderWithArgs(t *testing.T) {
	cb := NewCommandBuilder("test").WithArgs("arg1", "arg2")

	assert.Equal(t, []string{"arg1", "arg2"}, cb.args)

	// Test chaining
	cb.WithArgs("arg3")
	assert.Equal(t, []string{"arg1", "arg2", "arg3"}, cb.args)
}

func TestCommandBuilderWithNamespace(t *testing.T) {
	cb := NewCommandBuilder("test").WithNamespace("default")

	assert.Equal(t, "default", cb.namespace)

	// Test invalid namespace - should not set the namespace
	cb.WithNamespace("invalid..namespace")
	assert.Equal(t, "default", cb.namespace) // Should remain unchanged
}

func TestCommandBuilderWithContext(t *testing.T) {
	cb := NewCommandBuilder("test").WithContext("minikube")

	assert.Equal(t, "minikube", cb.context)
}

func TestCommandBuilderWithKubeconfig(t *testing.T) {
	cb := NewCommandBuilder("test").WithKubeconfig("/path/to/config")

	assert.Equal(t, "/path/to/config", cb.kubeconfig)
}

// Output format whitelist: accepted values stick, unknown values are dropped.
func TestCommandBuilderWithOutput(t *testing.T) {
	validOutputs := []string{"json", "yaml", "wide", "name"}

	for _, output := range validOutputs {
		cb := NewCommandBuilder("test").WithOutput(output)
		assert.Equal(t, output, cb.output)
	}

	// Test invalid output
	cb := NewCommandBuilder("test").WithOutput("invalid")
	assert.Empty(t, cb.output)
}

func TestCommandBuilderWithLabel(t *testing.T) {
	cb := NewCommandBuilder("test").WithLabel("app", "web")

	assert.Equal(t, "web", cb.labels["app"])
}

func TestCommandBuilderWithLabels(t *testing.T) {
	labels := map[string]string{
		"app":     "web",
		"version": "v1.0.0",
	}

	cb := NewCommandBuilder("test").WithLabels(labels)

	assert.Equal(t, labels["app"], cb.labels["app"])
	assert.Equal(t, labels["version"], cb.labels["version"])
}

func TestCommandBuilderWithAnnotation(t *testing.T) {
	cb := NewCommandBuilder("test").WithAnnotation("simple-key", "value")

	// The annotation should be accepted if it's a valid format
	assert.Equal(t, "value", cb.annotations["simple-key"])

	// Test with invalid annotation - still gets added but logs an error
	cb2 := NewCommandBuilder("test").WithAnnotation("invalid..key", "value")
	assert.Equal(t, "value", cb2.annotations["invalid..key"]) // Invalid annotations are still added but logged
}

func TestCommandBuilderWithTimeout(t *testing.T) {
	timeout := 60 * time.Second
	cb := NewCommandBuilder("test").WithTimeout(timeout)

	assert.Equal(t, timeout, cb.timeout)
}

// Boolean flag setters: dry-run, force, wait, validation-off.
func TestCommandBuilderWithFlags(t *testing.T) {
	cb := NewCommandBuilder("test").
		WithDryRun(true).
		WithForce(true).
		WithWait(true).
		WithValidation(false)

	assert.True(t, cb.dryRun)
	assert.True(t, cb.force)
	assert.True(t, cb.wait)
	assert.False(t, cb.validate)
}

func TestCommandBuilderWithCache(t *testing.T) {
	cb := NewCommandBuilder("test").WithCache(true)

	assert.True(t, cb.cached)
}

func TestCommandBuilderWithCacheTTL(t *testing.T) {
	ttl := 10 * time.Minute
	cb := NewCommandBuilder("test").WithCacheTTL(ttl)

	assert.Equal(t, ttl, cb.cacheTTL)
}

func TestCommandBuilderWithCacheKey(t *testing.T) {
	cb := NewCommandBuilder("test").WithCacheKey("custom-key")

	assert.Equal(t, "custom-key", cb.cacheKey)
}

// Build() emits every configured option as the expected CLI flag.
func TestCommandBuilderBuild(t *testing.T) {
	cb := NewCommandBuilder("kubectl").
		WithArgs("get", "pods").
		WithNamespace("default").
		WithContext("minikube").
		WithKubeconfig("/path/to/config").
		WithOutput("json").
		WithLabel("app", "web").
		WithDryRun(true).
		WithForce(true).
		WithWait(true).
		WithValidation(false)

	command, args, err := cb.Build()
	require.NoError(t, err)

	assert.Equal(t, "kubectl", command)
	assert.Contains(t, args, "get")
	assert.Contains(t, args, "pods")
	assert.Contains(t, args, "--namespace")
	assert.Contains(t, args, "default")
	assert.Contains(t, args, "--context")
	assert.Contains(t, args, "minikube")
	assert.Contains(t, args, "--kubeconfig")
	assert.Contains(t, args, "/path/to/config")
	assert.Contains(t, args, "--output")
	assert.Contains(t, args, "json")
	assert.Contains(t, args, "--selector")
	assert.Contains(t, args, "app=web")
	assert.Contains(t, args, "--dry-run=client")
	assert.Contains(t, args, "--force")
	assert.Contains(t, args, "--wait")
	assert.Contains(t, args, "--validate=false")
}

// kubectl delete is one of the subcommands that accepts --timeout.
func TestCommandBuilderBuildWithTimeout(t *testing.T) {
	cb := NewCommandBuilder("kubectl").
		WithArgs("delete", "pod", "test-pod").
		WithTimeout(45 * time.Second)

	command, args, err := cb.Build()
	require.NoError(t, err)

	assert.Equal(t, "kubectl", command)
	assert.Contains(t, args, "--timeout")
	assert.Contains(t, args, "45s")
}

// Multiple labels collapse into a single comma-joined --selector value.
func TestCommandBuilderBuildWithMultipleLabels(t *testing.T) {
	cb := NewCommandBuilder("kubectl").
		WithArgs("get", "pods").
		WithLabel("app", "web").
		WithLabel("version", "v1.0.0")

	command, args, err := cb.Build()
	require.NoError(t, err)

	assert.Equal(t, "kubectl", command)
	assert.Contains(t, args, "--selector")

	// Find the selector argument
	var selectorValue string
	for i, arg := range args {
		if arg == "--selector" && i+1 < len(args) {
			selectorValue = args[i+1]
			break
		}
	}

	assert.Contains(t, selectorValue, "app=web")
	assert.Contains(t, selectorValue, "version=v1.0.0")
}

func TestGetPods(t *testing.T) {
	namespace := "default"
	labels := map[string]string{"app": "web"}

	cb := GetPods(namespace, labels)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "get")
	assert.Contains(t, cb.args, "pods")
	assert.Equal(t, namespace, cb.namespace)
	assert.Equal(t, labels, cb.labels)
	assert.True(t, cb.cached)
	assert.Equal(t, "json", cb.output)
}

func TestGetServices(t *testing.T) {
	namespace := "default"
	labels := map[string]string{"app": "web"}

	cb := GetServices(namespace, labels)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "get")
	assert.Contains(t, cb.args, "services")
	assert.Equal(t, namespace, cb.namespace)
	assert.Equal(t, labels, cb.labels)
	assert.True(t, cb.cached)
	assert.Equal(t, "json", cb.output)
}

func TestGetDeployments(t *testing.T) {
	namespace := "default"
	labels := map[string]string{"app": "web"}

	cb := GetDeployments(namespace, labels)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "get")
	assert.Contains(t, cb.args, "deployments")
	assert.Equal(t, namespace, cb.namespace)
	assert.Equal(t, labels, cb.labels)
	assert.True(t, cb.cached)
	assert.Equal(t, "json", cb.output)
}

func TestDescribeResource(t *testing.T) {
	resourceType := "pod"
	resourceName := "test-pod"
	namespace := "default"

	cb := DescribeResource(resourceType, resourceName, namespace)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "describe")
	assert.Contains(t, cb.args, resourceType)
	assert.Contains(t, cb.args, resourceName)
	assert.Equal(t, namespace, cb.namespace)
	assert.True(t, cb.cached)
	assert.Equal(t, 2*time.Minute, cb.cacheTTL)
}

// Every LogOptions field maps to its kubectl logs flag; logs are uncached.
func TestGetLogs(t *testing.T) {
	podName := "test-pod"
	namespace := "default"
	options := LogOptions{
		Container:     "app",
		Follow:        true,
		Previous:      false,
		Timestamps:    true,
		TailLines:     100,
		SinceTime:     "2023-01-01T00:00:00Z",
		SinceDuration: "1h",
	}

	cb := GetLogs(podName, namespace, options)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "logs")
	assert.Contains(t, cb.args, podName)
	assert.Equal(t, namespace, cb.namespace)
	assert.Contains(t, cb.args, "--container")
	assert.Contains(t, cb.args, "app")
	assert.Contains(t, cb.args, "--follow")
	assert.Contains(t, cb.args, "--timestamps")
	assert.Contains(t, cb.args, "--tail")
	assert.Contains(t, cb.args, "100")
	assert.Contains(t, cb.args, "--since-time")
	assert.Contains(t, cb.args, "2023-01-01T00:00:00Z")
	assert.Contains(t, cb.args, "--since")
	assert.Contains(t, cb.args, "1h")
	assert.False(t, cb.cached)
}

func TestGetLogsWithPrevious(t *testing.T) {
	podName := "test-pod"
	namespace := "default"
	options := LogOptions{
		Previous: true,
	}

	cb := GetLogs(podName, namespace, options)

	assert.Contains(t, cb.args, "--previous")
}

func TestApplyResource(t *testing.T) {
	filename := "/path/to/resource.yaml"
	namespace := "default"
	options := ApplyOptions{
		DryRun:   true,
		Force:    true,
		Wait:     true,
		Validate: false,
	}

	cb := ApplyResource(filename, namespace, options)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "apply")
	assert.Contains(t, cb.args, "-f")
	assert.Contains(t, cb.args, filename)
	assert.Equal(t, namespace, cb.namespace)
	assert.True(t, cb.dryRun)
	assert.True(t, cb.force)
	assert.True(t, cb.wait)
	assert.False(t, cb.validate)
	assert.False(t, cb.cached)
}

func TestDeleteResource(t *testing.T) {
	resourceType := "pod"
	resourceName := "test-pod"
	namespace := "default"
	options := DeleteOptions{
		Force:       true,
		GracePeriod: 30,
		Wait:        true,
	}

	cb := DeleteResource(resourceType, resourceName, namespace, options)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "delete")
	assert.Contains(t, cb.args, resourceType)
	assert.Contains(t, cb.args, resourceName)
	assert.Equal(t, namespace, cb.namespace)
	assert.True(t, cb.force)
	assert.True(t, cb.wait)
	assert.False(t, cb.cached)
}

func TestHelmInstall(t *testing.T) {
	releaseName := "test-release"
	chart := "bitnami/nginx"
	namespace := "default"
	options := HelmInstallOptions{
		CreateNamespace: true,
		DryRun:          true,
		Wait:            true,
		ValuesFile:      "/path/to/values.yaml",
		SetValues:       map[string]string{"image.tag": "1.20"},
	}

	cb := HelmInstall(releaseName, chart, namespace, options)

	assert.Equal(t, "helm", cb.command)
	assert.Contains(t, cb.args, "install")
	assert.Contains(t, cb.args, releaseName)
	assert.Contains(t, cb.args, chart)
	assert.Equal(t, namespace, cb.namespace)
	assert.True(t, cb.dryRun)
	assert.True(t, cb.wait)
	assert.False(t, cb.cached)
}

func TestHelmList(t *testing.T) {
	namespace := "default"
	options := HelmListOptions{
		AllNamespaces: true,
		Output:        "json",
	}

	cb := HelmList(namespace, options)

	assert.Equal(t, "helm", cb.command)
	assert.Contains(t, cb.args, "list")
	assert.Equal(t, namespace, cb.namespace)
	assert.Equal(t, "json", cb.output)
	assert.True(t, cb.cached)
}

func TestIstioProxyStatus(t *testing.T) {
	podName := "test-pod"
	namespace := "default"

	cb := IstioProxyStatus(podName, namespace)

	assert.Equal(t, "istioctl", cb.command)
	assert.Contains(t, cb.args, "proxy-status")
	assert.Contains(t, cb.args, podName)
	assert.Equal(t, namespace, cb.namespace)
	assert.True(t, cb.cached)
}

func TestCiliumStatus(t *testing.T) {
	cb := CiliumStatus()

	assert.Equal(t, "cilium", cb.command)
	assert.Contains(t, cb.args, "status")
	assert.Empty(t, cb.output) // CiliumStatus doesn't set output format
	assert.True(t, cb.cached)
}

func TestArgoRolloutsGet(t *testing.T) {
	rolloutName := "test-rollout"
	namespace := "default"

	cb := ArgoRolloutsGet(rolloutName, namespace)

	assert.Equal(t, "kubectl", cb.command)
	assert.Contains(t, cb.args, "argo")
	assert.Contains(t, cb.args, "rollouts")
	assert.Contains(t, cb.args, "get")
	assert.Contains(t, cb.args, "rollout")
	assert.Contains(t, cb.args, rolloutName)
	assert.Equal(t, namespace, cb.namespace)
	assert.Empty(t, cb.output) // ArgoRolloutsGet doesn't set output format
	assert.True(t, cb.cached)
}

// Full fluent chain: every setter returns the same builder instance.
func TestCommandBuilderChaining(t *testing.T) {
	cb := NewCommandBuilder("kubectl").
		WithArgs("get", "pods").
		WithNamespace("default").
		WithOutput("json").
		WithLabel("app", "web").
		WithTimeout(60 * time.Second).
		WithCache(true).
		WithCacheTTL(10 * time.Minute)

	assert.Equal(t, "kubectl", cb.command)
	assert.Equal(t, []string{"get", "pods"}, cb.args)
	assert.Equal(t, "default", cb.namespace)
	assert.Equal(t, "json", cb.output)
	assert.Equal(t, "web", cb.labels["app"])
	assert.Equal(t, 60*time.Second, cb.timeout)
	assert.True(t, cb.cached)
	assert.Equal(t, 10*time.Minute, cb.cacheTTL)
}

func TestCommandBuilderEmptyNamespace(t *testing.T) {
	cb := GetPods("", nil)

	assert.Empty(t, cb.namespace)
}

func TestCommandBuilderEmptyLabels(t *testing.T) {
	cb := GetPods("default", nil)

	assert.Empty(t, cb.labels)
}

// Zero-value option structs: these pin the (sometimes surprising) defaults,
// e.g. ApplyOptions.Validate defaults to false.
func TestLogOptionsDefaults(t *testing.T) {
	options := LogOptions{}

	assert.False(t, options.Follow)
	assert.False(t, options.Previous)
	assert.False(t, options.Timestamps)
	assert.Equal(t, 0, options.TailLines)
	assert.Empty(t, options.SinceTime)
	assert.Empty(t, options.SinceDuration)
}

func TestApplyOptionsDefaults(t *testing.T) {
	options := ApplyOptions{}

	assert.False(t, options.DryRun)
	assert.False(t, options.Force)
	assert.False(t, options.Wait)
	assert.False(t, options.Validate)
}

func TestDeleteOptionsDefaults(t *testing.T) {
	options := DeleteOptions{}

	assert.False(t, options.Force)
	assert.Equal(t, 0, options.GracePeriod)
	assert.False(t, options.Wait)
}

func TestHelmInstallOptionsDefaults(t *testing.T) {
	options := HelmInstallOptions{}

	assert.False(t, options.CreateNamespace)
	assert.False(t, options.DryRun)
	assert.False(t, options.Wait)
	assert.Empty(t, options.ValuesFile)
	assert.Nil(t, options.SetValues)
}

func TestHelmListOptionsDefaults(t *testing.T) {
	options := HelmListOptions{}

	assert.False(t, options.AllNamespaces)
	assert.Empty(t, options.Output)
}

// Mock tests for Execute method - these would need a mock for utils.RunCommandWithContext
func TestCommandBuilderExecuteWithoutCache(t *testing.T) {
	cb := NewCommandBuilder("echo").
		WithArgs("hello", "world").
		WithCache(false)

	// This test would need mocking to work properly
	// For now, we'll just verify the command building part
	// (non-kubectl commands get --timeout unconditionally, hence "30s")
	command, args, err := cb.Build()
	require.NoError(t, err)

	assert.Equal(t, "echo", command)
	assert.Contains(t, args, "hello")
	assert.Contains(t, args, "world")
	assert.Contains(t, args, "--timeout")
	assert.Contains(t, args, "30s")
}

func TestCommandBuilderExecuteWithCache(t *testing.T) {
	cb := NewCommandBuilder("echo").
		WithArgs("hello", "world").
		WithCache(true)

	// This test would need mocking to work properly
	// For now, we'll just verify the command building part
	command, args, err := cb.Build()
	require.NoError(t, err)

	assert.Equal(t, "echo", command)
	assert.Contains(t, args, "hello")
	assert.Contains(t, args, "world")
	assert.Contains(t, args, "--timeout")
	assert.Contains(t, args, "30s")
	assert.True(t, cb.cached)
}

// --- internal/errors/tool_errors.go ---

package errors

import (
	"fmt"
	"strings"
	"time"

	"github.com/mark3labs/mcp-go/mcp"
)

// ToolError represents a structured error with context and recovery suggestions
type ToolError struct {
	Operation    string                 `json:"operation"`
	Cause        error                  `json:"cause"`
	Suggestions  []string               `json:"suggestions"`
	IsRetryable  bool                   `json:"is_retryable"`
	Timestamp    time.Time              `json:"timestamp"`
	ErrorCode    string                 `json:"error_code"`
	Component    string                 `json:"component"`
	ResourceType string                 `json:"resource_type,omitempty"`
	ResourceName string                 `json:"resource_name,omitempty"`
	Context      map[string]interface{} `json:"context,omitempty"`
}

// Error implements the error interface
func (e *ToolError) Error() string {
	return fmt.Sprintf("[%s] %s failed: %v", e.Component, e.Operation, e.Cause)
}

// ToMCPResult converts the error to an MCP result with rich context
ToMCPResult() *mcp.CallToolResult { + var message strings.Builder + + // Format the error message with context + message.WriteString(fmt.Sprintf("❌ **%s Error**\n\n", e.Component)) + message.WriteString(fmt.Sprintf("**Operation**: %s\n", e.Operation)) + message.WriteString(fmt.Sprintf("**Error**: %s\n", e.Cause.Error())) + + if e.ResourceType != "" { + message.WriteString(fmt.Sprintf("**Resource Type**: %s\n", e.ResourceType)) + } + + if e.ResourceName != "" { + message.WriteString(fmt.Sprintf("**Resource Name**: %s\n", e.ResourceName)) + } + + message.WriteString(fmt.Sprintf("**Error Code**: %s\n", e.ErrorCode)) + message.WriteString(fmt.Sprintf("**Timestamp**: %s\n", e.Timestamp.Format(time.RFC3339))) + + if e.IsRetryable { + message.WriteString("**Retryable**: Yes\n") + } else { + message.WriteString("**Retryable**: No\n") + } + + if len(e.Suggestions) > 0 { + message.WriteString("\n**💡 Suggestions**:\n") + for i, suggestion := range e.Suggestions { + message.WriteString(fmt.Sprintf("%d. %s\n", i+1, suggestion)) + } + } + + if len(e.Context) > 0 { + message.WriteString("\n**📋 Context**:\n") + for key, value := range e.Context { + message.WriteString(fmt.Sprintf("- %s: %v\n", key, value)) + } + } + + return mcp.NewToolResultError(message.String()) +} + +// NewToolError creates a new structured tool error +func NewToolError(component, operation string, cause error) *ToolError { + return &ToolError{ + Operation: operation, + Cause: cause, + Suggestions: []string{}, + IsRetryable: false, + Timestamp: time.Now(), + ErrorCode: "UNKNOWN", + Component: component, + Context: make(map[string]interface{}), + } +} + +// WithSuggestions adds recovery suggestions to the error +func (e *ToolError) WithSuggestions(suggestions ...string) *ToolError { + e.Suggestions = append(e.Suggestions, suggestions...) 
+ return e +} + +// WithRetryable sets whether the error is retryable +func (e *ToolError) WithRetryable(retryable bool) *ToolError { + e.IsRetryable = retryable + return e +} + +// WithErrorCode sets the error code +func (e *ToolError) WithErrorCode(code string) *ToolError { + e.ErrorCode = code + return e +} + +// WithResource adds resource information to the error +func (e *ToolError) WithResource(resourceType, resourceName string) *ToolError { + e.ResourceType = resourceType + e.ResourceName = resourceName + return e +} + +// WithContext adds contextual information to the error +func (e *ToolError) WithContext(key string, value interface{}) *ToolError { + e.Context[key] = value + return e +} + +// Common error creators for different components + +// NewKubernetesError creates a Kubernetes-specific error +func NewKubernetesError(operation string, cause error) *ToolError { + err := NewToolError("Kubernetes", operation, cause) + + // Add Kubernetes-specific suggestions based on common errors + if strings.Contains(cause.Error(), "connection refused") { + err = err.WithSuggestions( + "Check if the Kubernetes cluster is running", + "Verify your kubeconfig is correct", + "Ensure network connectivity to the cluster", + ).WithRetryable(true).WithErrorCode("K8S_CONNECTION_ERROR") + } else if strings.Contains(cause.Error(), "forbidden") { + err = err.WithSuggestions( + "Check your RBAC permissions", + "Verify your service account has the required permissions", + "Contact your cluster administrator", + ).WithRetryable(false).WithErrorCode("K8S_PERMISSION_ERROR") + } else if strings.Contains(cause.Error(), "not found") { + err = err.WithSuggestions( + "Check if the resource exists", + "Verify the resource name and namespace", + "List available resources to confirm", + ).WithRetryable(false).WithErrorCode("K8S_RESOURCE_NOT_FOUND") + } else if strings.Contains(cause.Error(), "already exists") { + err = err.WithSuggestions( + "Use a different name for the resource", + "Delete 
the existing resource first", + "Use 'kubectl apply' instead of 'kubectl create'", + ).WithRetryable(false).WithErrorCode("K8S_RESOURCE_EXISTS") + } else { + err = err.WithSuggestions( + "Check the kubectl command syntax", + "Verify your kubeconfig is valid", + "Check cluster connectivity", + ).WithRetryable(true).WithErrorCode("K8S_GENERIC_ERROR") + } + + return err +} + +// NewHelmError creates a Helm-specific error +func NewHelmError(operation string, cause error) *ToolError { + err := NewToolError("Helm", operation, cause) + + if strings.Contains(cause.Error(), "not found") { + err = err.WithSuggestions( + "Check if the Helm release exists", + "Verify the release name and namespace", + "Use 'helm list' to see available releases", + ).WithRetryable(false).WithErrorCode("HELM_RELEASE_NOT_FOUND") + } else if strings.Contains(cause.Error(), "already exists") { + err = err.WithSuggestions( + "Use a different release name", + "Upgrade the existing release instead", + "Uninstall the existing release first", + ).WithRetryable(false).WithErrorCode("HELM_RELEASE_EXISTS") + } else if strings.Contains(cause.Error(), "repository") { + err = err.WithSuggestions( + "Add the required Helm repository", + "Update your Helm repositories", + "Check repository URL and credentials", + ).WithRetryable(true).WithErrorCode("HELM_REPOSITORY_ERROR") + } else { + err = err.WithSuggestions( + "Check the Helm command syntax", + "Verify your kubeconfig is valid", + "Ensure Helm is properly installed", + ).WithRetryable(true).WithErrorCode("HELM_GENERIC_ERROR") + } + + return err +} + +// NewIstioError creates an Istio-specific error +func NewIstioError(operation string, cause error) *ToolError { + err := NewToolError("Istio", operation, cause) + + if strings.Contains(cause.Error(), "not found") { + err = err.WithSuggestions( + "Check if Istio is installed in the cluster", + "Verify the pod/service name and namespace", + "Ensure Istio sidecar is injected", + 
).WithRetryable(false).WithErrorCode("ISTIO_RESOURCE_NOT_FOUND") + } else if strings.Contains(cause.Error(), "connection refused") { + err = err.WithSuggestions( + "Check if Istio control plane is running", + "Verify Istio proxy is healthy", + "Check network policies", + ).WithRetryable(true).WithErrorCode("ISTIO_CONNECTION_ERROR") + } else { + err = err.WithSuggestions( + "Check istioctl command syntax", + "Verify Istio installation", + "Check Istio proxy status", + ).WithRetryable(true).WithErrorCode("ISTIO_GENERIC_ERROR") + } + + return err +} + +// NewPrometheusError creates a Prometheus-specific error +func NewPrometheusError(operation string, cause error) *ToolError { + err := NewToolError("Prometheus", operation, cause) + + if strings.Contains(cause.Error(), "connection refused") { + err = err.WithSuggestions( + "Check if Prometheus server is running", + "Verify the Prometheus URL", + "Check network connectivity", + ).WithRetryable(true).WithErrorCode("PROMETHEUS_CONNECTION_ERROR") + } else if strings.Contains(cause.Error(), "parse error") { + err = err.WithSuggestions( + "Check your PromQL query syntax", + "Verify metric names and labels", + "Test the query in Prometheus UI", + ).WithRetryable(false).WithErrorCode("PROMETHEUS_QUERY_ERROR") + } else { + err = err.WithSuggestions( + "Check Prometheus server status", + "Verify the query format", + "Check authentication if required", + ).WithRetryable(true).WithErrorCode("PROMETHEUS_GENERIC_ERROR") + } + + return err +} + +// NewArgoError creates an Argo-specific error +func NewArgoError(operation string, cause error) *ToolError { + err := NewToolError("Argo Rollouts", operation, cause) + + if strings.Contains(cause.Error(), "not found") { + err = err.WithSuggestions( + "Check if Argo Rollouts is installed", + "Verify the rollout name and namespace", + "Use 'kubectl get rollouts' to list available rollouts", + ).WithRetryable(false).WithErrorCode("ARGO_ROLLOUT_NOT_FOUND") + } else if 
strings.Contains(cause.Error(), "plugin") { + err = err.WithSuggestions( + "Install the kubectl argo rollouts plugin", + "Check plugin version compatibility", + "Verify plugin installation path", + ).WithRetryable(true).WithErrorCode("ARGO_PLUGIN_ERROR") + } else { + err = err.WithSuggestions( + "Check Argo Rollouts installation", + "Verify the command syntax", + "Check RBAC permissions", + ).WithRetryable(true).WithErrorCode("ARGO_GENERIC_ERROR") + } + + return err +} + +// NewCiliumError creates a Cilium-specific error +func NewCiliumError(operation string, cause error) *ToolError { + err := NewToolError("Cilium", operation, cause) + + if strings.Contains(cause.Error(), "not found") { + err = err.WithSuggestions( + "Check if Cilium is installed", + "Verify the cilium CLI is installed", + "Check Cilium agent status", + ).WithRetryable(false).WithErrorCode("CILIUM_NOT_FOUND") + } else if strings.Contains(cause.Error(), "connection") { + err = err.WithSuggestions( + "Check Cilium agent connectivity", + "Verify cluster mesh configuration", + "Check Cilium operator status", + ).WithRetryable(true).WithErrorCode("CILIUM_CONNECTION_ERROR") + } else { + err = err.WithSuggestions( + "Check Cilium installation", + "Verify cilium CLI version", + "Check Cilium system pods", + ).WithRetryable(true).WithErrorCode("CILIUM_GENERIC_ERROR") + } + + return err +} + +// NewValidationError creates a validation error +func NewValidationError(field, message string) *ToolError { + err := NewToolError("Validation", fmt.Sprintf("validate %s", field), fmt.Errorf("%s", message)) + + err = err.WithSuggestions( + "Check the input format", + "Refer to the documentation for valid values", + "Verify the parameter requirements", + ).WithRetryable(false).WithErrorCode("VALIDATION_ERROR") + + return err +} + +// NewSecurityError creates a security-related error +func NewSecurityError(operation string, cause error) *ToolError { + err := NewToolError("Security", operation, cause) + + err = 
err.WithSuggestions( + "Review the input for potentially dangerous content", + "Use only trusted input sources", + "Contact security team if needed", + ).WithRetryable(false).WithErrorCode("SECURITY_ERROR") + + return err +} + +// NewTimeoutError creates a timeout error +func NewTimeoutError(operation string, timeout time.Duration) *ToolError { + cause := fmt.Errorf("operation timed out after %v", timeout) + err := NewToolError("Timeout", operation, cause) + + err = err.WithSuggestions( + "Try the operation again", + "Check network connectivity", + "Increase timeout if possible", + ).WithRetryable(true).WithErrorCode("TIMEOUT_ERROR") + + return err +} + +// NewCommandError creates a command execution error +func NewCommandError(command string, cause error) *ToolError { + err := NewToolError("Command", fmt.Sprintf("execute %s", command), cause) + + err = err.WithSuggestions( + "Check if the command exists in PATH", + "Verify command syntax and arguments", + "Check system permissions", + ).WithRetryable(true).WithErrorCode("COMMAND_ERROR") + + return err +} diff --git a/internal/errors/tool_errors_test.go b/internal/errors/tool_errors_test.go new file mode 100644 index 0000000..bfa2f24 --- /dev/null +++ b/internal/errors/tool_errors_test.go @@ -0,0 +1,366 @@ +package errors + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewToolError(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) + assert.Equal(t, "TestComponent", err.Component) + assert.Equal(t, "UNKNOWN", err.ErrorCode) + assert.False(t, err.IsRetryable) + assert.Empty(t, err.Suggestions) + assert.NotNil(t, err.Context) + assert.WithinDuration(t, time.Now(), err.Timestamp, time.Second) +} + +func TestToolErrorError(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", 
"test operation", cause) + + result := err.Error() + expected := "[TestComponent] test operation failed: test error" + assert.Equal(t, expected, result) +} + +func TestToolErrorWithSuggestions(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + err = err.WithSuggestions("suggestion 1", "suggestion 2") + + assert.Equal(t, []string{"suggestion 1", "suggestion 2"}, err.Suggestions) + + // Test chaining + err = err.WithSuggestions("suggestion 3") + assert.Equal(t, []string{"suggestion 1", "suggestion 2", "suggestion 3"}, err.Suggestions) +} + +func TestToolErrorWithRetryable(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + err = err.WithRetryable(true) + assert.True(t, err.IsRetryable) + + err = err.WithRetryable(false) + assert.False(t, err.IsRetryable) +} + +func TestToolErrorWithErrorCode(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + err = err.WithErrorCode("TEST_ERROR") + assert.Equal(t, "TEST_ERROR", err.ErrorCode) +} + +func TestToolErrorWithResource(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + err = err.WithResource("Pod", "test-pod") + assert.Equal(t, "Pod", err.ResourceType) + assert.Equal(t, "test-pod", err.ResourceName) +} + +func TestToolErrorWithContext(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + err = err.WithContext("key1", "value1") + err = err.WithContext("key2", 42) + + assert.Equal(t, "value1", err.Context["key1"]) + assert.Equal(t, 42, err.Context["key2"]) +} + +func TestToolErrorToMCPResult(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause). + WithErrorCode("TEST_ERROR"). + WithResource("Pod", "test-pod"). 
+ WithSuggestions("suggestion 1", "suggestion 2"). + WithContext("key1", "value1"). + WithRetryable(true) + + result := err.ToMCPResult() + + assert.NotNil(t, result) + assert.True(t, result.IsError) + assert.NotEmpty(t, result.Content) + + // Check content (assuming it's text content) + if len(result.Content) > 0 { + content := result.Content[0] + // This depends on the actual MCP implementation + // We'll just check that it's not empty + assert.NotNil(t, content) + } +} + +func TestNewKubernetesError(t *testing.T) { + tests := []struct { + name string + causeError string + expectedCode string + expectedRetry bool + expectedSuggs int + }{ + { + name: "connection refused", + causeError: "connection refused", + expectedCode: "K8S_CONNECTION_ERROR", + expectedRetry: true, + expectedSuggs: 3, + }, + { + name: "forbidden", + causeError: "forbidden", + expectedCode: "K8S_PERMISSION_ERROR", + expectedRetry: false, + expectedSuggs: 3, + }, + { + name: "not found", + causeError: "not found", + expectedCode: "K8S_RESOURCE_NOT_FOUND", + expectedRetry: false, + expectedSuggs: 3, + }, + { + name: "already exists", + causeError: "already exists", + expectedCode: "K8S_RESOURCE_EXISTS", + expectedRetry: false, + expectedSuggs: 3, + }, + { + name: "generic error", + causeError: "some other error", + expectedCode: "K8S_GENERIC_ERROR", + expectedRetry: true, + expectedSuggs: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cause := errors.New(tt.causeError) + err := NewKubernetesError("test operation", cause) + + assert.Equal(t, "Kubernetes", err.Component) + assert.Equal(t, tt.expectedCode, err.ErrorCode) + assert.Equal(t, tt.expectedRetry, err.IsRetryable) + assert.Len(t, err.Suggestions, tt.expectedSuggs) + }) + } +} + +func TestNewHelmError(t *testing.T) { + tests := []struct { + name string + causeError string + expectedCode string + expectedRetry bool + expectedSuggs int + }{ + { + name: "not found", + causeError: "not found", + expectedCode: 
"HELM_RELEASE_NOT_FOUND", + expectedRetry: false, + expectedSuggs: 3, + }, + { + name: "already exists", + causeError: "already exists", + expectedCode: "HELM_RELEASE_EXISTS", + expectedRetry: false, + expectedSuggs: 3, + }, + { + name: "repository error", + causeError: "repository error", + expectedCode: "HELM_REPOSITORY_ERROR", + expectedRetry: true, + expectedSuggs: 3, + }, + { + name: "generic error", + causeError: "some other error", + expectedCode: "HELM_GENERIC_ERROR", + expectedRetry: true, + expectedSuggs: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cause := errors.New(tt.causeError) + err := NewHelmError("test operation", cause) + + assert.Equal(t, "Helm", err.Component) + assert.Equal(t, tt.expectedCode, err.ErrorCode) + assert.Equal(t, tt.expectedRetry, err.IsRetryable) + assert.Len(t, err.Suggestions, tt.expectedSuggs) + }) + } +} + +func TestNewIstioError(t *testing.T) { + cause := errors.New("test error") + err := NewIstioError("test operation", cause) + + assert.Equal(t, "Istio", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) +} + +func TestNewPrometheusError(t *testing.T) { + cause := errors.New("test error") + err := NewPrometheusError("test operation", cause) + + assert.Equal(t, "Prometheus", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) +} + +func TestNewArgoError(t *testing.T) { + cause := errors.New("test error") + err := NewArgoError("test operation", cause) + + assert.Equal(t, "Argo Rollouts", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) +} + +func TestNewCiliumError(t *testing.T) { + cause := errors.New("test error") + err := NewCiliumError("test operation", cause) + + assert.Equal(t, "Cilium", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) +} + +func TestNewValidationError(t 
*testing.T) { + err := NewValidationError("test-field", "validation failed") + + assert.Equal(t, "Validation", err.Component) + assert.Equal(t, "validate test-field", err.Operation) + assert.Equal(t, "VALIDATION_ERROR", err.ErrorCode) + assert.False(t, err.IsRetryable) + assert.Contains(t, err.Cause.Error(), "validation failed") +} + +func TestNewSecurityError(t *testing.T) { + cause := errors.New("security violation") + err := NewSecurityError("test operation", cause) + + assert.Equal(t, "Security", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, cause, err.Cause) + assert.Equal(t, "SECURITY_ERROR", err.ErrorCode) + assert.False(t, err.IsRetryable) +} + +func TestNewTimeoutError(t *testing.T) { + timeout := 30 * time.Second + err := NewTimeoutError("test operation", timeout) + + assert.Equal(t, "Timeout", err.Component) + assert.Equal(t, "test operation", err.Operation) + assert.Equal(t, "TIMEOUT_ERROR", err.ErrorCode) + assert.True(t, err.IsRetryable) + assert.Contains(t, err.Cause.Error(), "30s") +} + +func TestNewCommandError(t *testing.T) { + cause := errors.New("command failed") + err := NewCommandError("test-command", cause) + + assert.Equal(t, "Command", err.Component) + assert.Equal(t, "execute test-command", err.Operation) + assert.Equal(t, cause, err.Cause) + assert.Equal(t, "COMMAND_ERROR", err.ErrorCode) + assert.True(t, err.IsRetryable) +} + +func TestToolErrorChaining(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause). + WithErrorCode("TEST_ERROR"). + WithResource("Pod", "test-pod"). + WithSuggestions("suggestion 1"). + WithContext("key1", "value1"). 
+ WithRetryable(true) + + // Test that all methods return the same instance for chaining + assert.Equal(t, "TEST_ERROR", err.ErrorCode) + assert.Equal(t, "Pod", err.ResourceType) + assert.Equal(t, "test-pod", err.ResourceName) + assert.Equal(t, []string{"suggestion 1"}, err.Suggestions) + assert.Equal(t, "value1", err.Context["key1"]) + assert.True(t, err.IsRetryable) +} + +func TestToolErrorStringRepresentation(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + errorStr := err.Error() + assert.Contains(t, errorStr, "TestComponent") + assert.Contains(t, errorStr, "test operation") + assert.Contains(t, errorStr, "test error") + assert.Contains(t, errorStr, "failed") +} + +func TestToolErrorTimestamp(t *testing.T) { + before := time.Now() + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + after := time.Now() + + assert.True(t, err.Timestamp.After(before) || err.Timestamp.Equal(before)) + assert.True(t, err.Timestamp.Before(after) || err.Timestamp.Equal(after)) +} + +func TestToolErrorContextInitialization(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause) + + // Context should be initialized but empty + assert.NotNil(t, err.Context) + assert.Empty(t, err.Context) + + // Should be able to add to context + err = err.WithContext("test", "value") + assert.Equal(t, "value", err.Context["test"]) +} + +func TestMCPResultContainsExpectedFields(t *testing.T) { + cause := errors.New("test error") + err := NewToolError("TestComponent", "test operation", cause). + WithErrorCode("TEST_ERROR"). + WithResource("Pod", "test-pod"). + WithSuggestions("suggestion 1"). + WithContext("key1", "value1"). 
+ WithRetryable(true) + + result := err.ToMCPResult() + + // The result should be an error result + assert.True(t, result.IsError) + + // Should have content + assert.NotEmpty(t, result.Content) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..041d499 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,76 @@ +package logger + +import ( + "context" + "log/slog" + "os" + + "go.opentelemetry.io/otel/trace" +) + +var globalLogger *slog.Logger + +func Init() { + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + } + + if os.Getenv("KAGENT_LOG_FORMAT") == "json" { + globalLogger = slog.New(slog.NewJSONHandler(os.Stdout, opts)) + } else { + globalLogger = slog.New(slog.NewTextHandler(os.Stdout, opts)) + } + + slog.SetDefault(globalLogger) +} + +func Get() *slog.Logger { + if globalLogger == nil { + Init() + } + return globalLogger +} + +func WithContext(ctx context.Context) *slog.Logger { + logger := Get() + span := trace.SpanFromContext(ctx) + if span.SpanContext().IsValid() { + logger = logger.With( + "trace_id", span.SpanContext().TraceID().String(), + "span_id", span.SpanContext().SpanID().String(), + ) + } + return logger +} + +func LogExecCommand(ctx context.Context, logger *slog.Logger, command string, args []string, caller string) { + logger.Info("executing command", + "command", command, + "args", args, + "caller", caller, + ) +} + +func LogExecCommandResult(ctx context.Context, logger *slog.Logger, command string, args []string, output string, err error, duration float64, caller string) { + if err != nil { + logger.Error("command execution failed", + "command", command, + "args", args, + "error", err.Error(), + "duration_seconds", duration, + "caller", caller, + ) + } else { + logger.Info("command execution successful", + "command", command, + "args", args, + "output", output, + "duration_seconds", duration, + "caller", caller, + ) + } +} + +func Sync() { + // No-op for slog, but kept for 
compatibility +} diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go new file mode 100644 index 0000000..ad5c988 --- /dev/null +++ b/internal/logger/logger_test.go @@ -0,0 +1,72 @@ +package logger + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" +) + +func TestLogExecCommand(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + ctx := context.Background() + LogExecCommand(ctx, logger, "test-command", []string{"arg1", "arg2"}, "test.go:123") + + output := buf.String() + assert.Contains(t, output, "executing command") + assert.Contains(t, output, "test-command") + assert.Contains(t, output, "arg1") + assert.Contains(t, output, "arg2") +} + +func TestLogExecCommandResult(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + ctx := context.Background() + LogExecCommandResult(ctx, logger, "test-command", []string{"arg1"}, "success output", nil, 1.5, "test.go:123") + assert.Contains(t, buf.String(), "command execution successful") + + buf.Reset() + LogExecCommandResult(ctx, logger, "test-command", []string{"arg1"}, "error output", assert.AnError, 0.5, "test.go:123") + assert.Contains(t, buf.String(), "command execution failed") +} + +func TestWithContextAddsTraceID(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, nil)) + + // Create a context with a mock span + tp := noop.NewTracerProvider() + ctx, span := tp.Tracer("test").Start(context.Background(), "test-span") + defer span.End() + + loggerWithTrace := logger.With("trace_id", span.SpanContext().TraceID().String()) + loggerWithTrace.InfoContext(ctx, "test message") + + var logOutput map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logOutput) + require.NoError(t, err) + + traceID := 
span.SpanContext().TraceID().String() + assert.Equal(t, traceID, logOutput["trace_id"]) +} + +func TestGet(t *testing.T) { + assert.NotNil(t, Get()) +} + +func TestInit(t *testing.T) { + assert.NotPanics(t, Init) +} + +func TestSync(t *testing.T) { + assert.NotPanics(t, Sync) +} diff --git a/internal/security/validation.go b/internal/security/validation.go new file mode 100644 index 0000000..6aadc38 --- /dev/null +++ b/internal/security/validation.go @@ -0,0 +1,291 @@ +package security + +import ( + "fmt" + "regexp" + "strings" +) + +// ValidationError represents a validation error +type ValidationError struct { + Field string + Message string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("validation error in field '%s': %s", e.Field, e.Message) +} + +// Common validation patterns +var ( + // K8s resource name pattern (RFC 1123) + k8sNamePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`) + + // Namespace pattern + namespacePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`) + + // Container image pattern + imagePattern = regexp.MustCompile(`^[a-z0-9]+(([._-][a-z0-9]+)*(/[a-z0-9]+(([._-][a-z0-9]+)*)?)*)?(:([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?))$`) + + // Path pattern (no directory traversal) + pathPattern = regexp.MustCompile(`^[a-zA-Z0-9._/-]+$`) + + // Command injection patterns to reject + commandInjectionPatterns = []*regexp.Regexp{ + regexp.MustCompile(`[;&|` + "`" + `$(){}[\]\\<>*?~!#\n\r\t]`), + regexp.MustCompile(`\.\./`), + regexp.MustCompile(`\$\{`), + regexp.MustCompile(`\$\(`), + regexp.MustCompile(`\|\|`), + regexp.MustCompile(`&&`), + } +) + +// ValidateK8sResourceName validates a Kubernetes resource name +func ValidateK8sResourceName(name string) error { + if name == "" { + return ValidationError{Field: "name", Message: "cannot be empty"} + } + + if len(name) > 63 { + return ValidationError{Field: "name", Message: "cannot exceed 63 characters"} + } + + if !k8sNamePattern.MatchString(name) { + return 
ValidationError{Field: "name", Message: "must follow RFC 1123 naming convention"} + } + + return nil +} + +// ValidateNamespace validates a Kubernetes namespace +func ValidateNamespace(namespace string) error { + if namespace == "" { + return nil // Empty namespace is allowed (defaults to 'default') + } + + if len(namespace) > 63 { + return ValidationError{Field: "namespace", Message: "cannot exceed 63 characters"} + } + + if !namespacePattern.MatchString(namespace) { + return ValidationError{Field: "namespace", Message: "must follow RFC 1123 naming convention"} + } + + // Reserved namespaces + reserved := []string{"kube-system", "kube-public", "kube-node-lease"} + for _, res := range reserved { + if namespace == res { + return ValidationError{Field: "namespace", Message: fmt.Sprintf("'%s' is a reserved namespace", namespace)} + } + } + + return nil +} + +// ValidateContainerImage validates a container image reference +func ValidateContainerImage(image string) error { + if image == "" { + return ValidationError{Field: "image", Message: "cannot be empty"} + } + + if len(image) > 255 { + return ValidationError{Field: "image", Message: "cannot exceed 255 characters"} + } + + if !imagePattern.MatchString(image) { + return ValidationError{Field: "image", Message: "invalid image format"} + } + + return nil +} + +// ValidateFilePath validates a file path for security +func ValidateFilePath(path string) error { + if path == "" { + return ValidationError{Field: "path", Message: "cannot be empty"} + } + + if len(path) > 4096 { + return ValidationError{Field: "path", Message: "path too long"} + } + + if strings.Contains(path, "..") { + return ValidationError{Field: "path", Message: "path traversal not allowed"} + } + + if !pathPattern.MatchString(path) { + return ValidationError{Field: "path", Message: "contains invalid characters"} + } + + return nil +} + +// ValidateCommandInput validates command inputs for injection attacks +func ValidateCommandInput(input string) error { 
+ if input == "" { + return ValidationError{Field: "input", Message: "cannot be empty"} + } + + if len(input) > 1024 { + return ValidationError{Field: "input", Message: "input too long"} + } + + for _, pattern := range commandInjectionPatterns { + if pattern.MatchString(input) { + return ValidationError{Field: "input", Message: "potentially dangerous characters detected"} + } + } + + return nil +} + +// SanitizeInput sanitizes input strings by replacing potentially dangerous characters +func SanitizeInput(input string) string { + // Replace dangerous characters with safe alternatives + sanitized := strings.ReplaceAll(input, "\n", " ") + sanitized = strings.ReplaceAll(sanitized, "\r", " ") + sanitized = strings.ReplaceAll(sanitized, "\t", " ") + + // Replace multiple spaces with single space + spacePattern := regexp.MustCompile(`\s+`) + sanitized = spacePattern.ReplaceAllString(sanitized, " ") + + sanitized = strings.TrimSpace(sanitized) + + return sanitized +} + +// ValidateK8sLabel validates a Kubernetes label key and value +func ValidateK8sLabel(key, value string) error { + if key == "" { + return ValidationError{Field: "label_key", Message: "cannot be empty"} + } + + if len(key) > 63 { + return ValidationError{Field: "label_key", Message: "cannot exceed 63 characters"} + } + + if len(value) > 63 { + return ValidationError{Field: "label_value", Message: "cannot exceed 63 characters"} + } + + // Label key validation + labelKeyPattern := regexp.MustCompile(`^[a-z0-9A-Z]([a-z0-9A-Z._-]*[a-z0-9A-Z])?$`) + if !labelKeyPattern.MatchString(key) { + return ValidationError{Field: "label_key", Message: "invalid label key format"} + } + + // Label value validation (can be empty) + if value != "" { + labelValuePattern := regexp.MustCompile(`^[a-z0-9A-Z]([a-z0-9A-Z._-]*[a-z0-9A-Z])?$`) + if !labelValuePattern.MatchString(value) { + return ValidationError{Field: "label_value", Message: "invalid label value format"} + } + } + + return nil +} + +// ValidatePromQLQuery validates 
a PromQL query for basic security +func ValidatePromQLQuery(query string) error { + if query == "" { + return ValidationError{Field: "query", Message: "cannot be empty"} + } + + if len(query) > 8192 { + return ValidationError{Field: "query", Message: "query too long"} + } + + // Basic PromQL validation - no shell commands + dangerousPatterns := []string{ + "`", "$", "$(", "${", "&&", "||", ";", "|", ">", "<", "&", + } + + for _, pattern := range dangerousPatterns { + if strings.Contains(query, pattern) { + return ValidationError{Field: "query", Message: "potentially dangerous characters in query"} + } + } + + return nil +} + +// ValidateYAMLContent validates YAML content for basic security +func ValidateYAMLContent(content string) error { + if content == "" { + return ValidationError{Field: "content", Message: "cannot be empty"} + } + + if len(content) > 1024*1024 { // 1MB limit + return ValidationError{Field: "content", Message: "content too large"} + } + + // Check for potentially dangerous YAML content + dangerousPatterns := []string{ + "!!python/object/apply", + "!!python/object/new", + "!!python/object", + "__import__", + "eval(", + "exec(", + } + + for _, pattern := range dangerousPatterns { + if strings.Contains(content, pattern) { + return ValidationError{Field: "content", Message: "potentially dangerous YAML content detected"} + } + } + + return nil +} + +// ValidateHelmReleaseName validates a Helm release name +func ValidateHelmReleaseName(name string) error { + if name == "" { + return ValidationError{Field: "release_name", Message: "cannot be empty"} + } + + if len(name) > 53 { + return ValidationError{Field: "release_name", Message: "cannot exceed 53 characters"} + } + + // Helm release name pattern + helmNamePattern := regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`) + if !helmNamePattern.MatchString(name) { + return ValidationError{Field: "release_name", Message: "must follow DNS naming convention"} + } + + return nil +} + +// ValidateURL 
validates a URL for basic security +func ValidateURL(url string) error { + if url == "" { + return ValidationError{Field: "url", Message: "cannot be empty"} + } + + if len(url) > 2048 { + return ValidationError{Field: "url", Message: "URL too long"} + } + + // Basic URL validation + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + return ValidationError{Field: "url", Message: "must start with http:// or https://"} + } + + // Check for dangerous URL patterns + dangerousPatterns := []string{ + "javascript:", "data:", "file:", "ftp:", + "", true}, + {"too long path", string(make([]byte, 5000)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateFilePath(tt.input) + if tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestValidateCommandInput(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + {"valid input", "my-service", false}, + {"empty input", "", true}, + {"command injection", "test; rm -rf /", true}, + {"pipe injection", "test | cat /etc/passwd", true}, + {"backtick injection", "test`whoami`", true}, + {"variable expansion", "test${USER}", true}, + {"too long input", string(make([]byte, 2000)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCommandInput(tt.input) + if tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestSanitizeInput(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"clean input", "hello world", "hello world"}, + {"with newlines", "hello\nworld", "hello world"}, + {"with tabs", "hello\tworld", 
"hello world"}, + {"with carriage returns", "hello\rworld", "hello world"}, + {"with spaces", " hello world ", "hello world"}, + {"mixed whitespace", "\n\t hello world \r\n", "hello world"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeInput(tt.input) + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestValidateK8sLabel(t *testing.T) { + tests := []struct { + name string + key string + value string + expectError bool + }{ + {"valid label", "app", "nginx", false}, + {"valid label with dash", "app-version", "1.0", false}, + {"valid label with underscore", "app_name", "nginx", false}, + {"empty key", "", "value", true}, + {"empty value", "key", "", false}, // Empty value is allowed + {"too long key", string(make([]byte, 70)), "value", true}, + {"too long value", "key", string(make([]byte, 70)), true}, + {"invalid key characters", "app/name", "nginx", true}, + {"invalid value characters", "app", "nginx/web", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateK8sLabel(tt.key, tt.value) + if tt.expectError && err == nil { + t.Errorf("Expected error for key %q, value %q, but got none", tt.key, tt.value) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for key %q, value %q: %v", tt.key, tt.value, err) + } + }) + } +} + +func TestValidatePromQLQuery(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + {"valid query", "up{job=\"prometheus\"}", false}, + {"valid aggregation", "sum(rate(http_requests_total[5m]))", false}, + {"empty query", "", true}, + {"command injection", "up; rm -rf /", true}, + {"backtick injection", "up`whoami`", true}, + {"variable expansion", "up${USER}", true}, + {"too long query", string(make([]byte, 10000)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePromQLQuery(tt.input) + if 
tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestValidateYAMLContent(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + {"valid YAML", "apiVersion: v1\nkind: Pod", false}, + {"empty content", "", true}, + {"python object", "!!python/object/apply", true}, + {"python import", "__import__('os').system('rm -rf /')", true}, + {"eval injection", "eval('print(1)')", true}, + {"too large content", string(make([]byte, 2*1024*1024)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateYAMLContent(tt.input) + if tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestValidateHelmReleaseName(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + {"valid release name", "my-release", false}, + {"valid with numbers", "release-123", false}, + {"empty name", "", true}, + {"too long name", "this-is-a-very-long-release-name-that-exceeds-the-maximum-allowed-length-of-53-characters", true}, + {"invalid characters", "my_release", true}, + {"starts with dash", "-release", true}, + {"ends with dash", "release-", true}, + {"uppercase", "Release", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateHelmReleaseName(tt.input) + if tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestValidateURL(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + {"valid http URL", 
"http://example.com", false}, + {"valid https URL", "https://example.com/path", false}, + {"empty URL", "", true}, + {"invalid protocol", "ftp://example.com", true}, + {"javascript injection", "javascript:alert('xss')", true}, + {"data URL", "data:text/html,", true}, + {"too long URL", "https://" + string(make([]byte, 3000)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateURL(tt.input) + if tt.expectError && err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for input %q: %v", tt.input, err) + } + }) + } +} + +func TestValidationError(t *testing.T) { + err := ValidationError{ + Field: "test_field", + Message: "test message", + } + + expected := "validation error in field 'test_field': test message" + if err.Error() != expected { + t.Errorf("Expected error message %q, got %q", expected, err.Error()) + } +} diff --git a/internal/telemetry/config.go b/internal/telemetry/config.go new file mode 100644 index 0000000..56b266e --- /dev/null +++ b/internal/telemetry/config.go @@ -0,0 +1,74 @@ +package telemetry + +import ( + "os" + "strconv" + "strings" + "sync" +) + +// Telemetry holds all telemetry-related configuration. +type Telemetry struct { + ServiceName string + ServiceVersion string + Environment string + Endpoint string + Protocol string + SamplingRatio float64 + Insecure bool + Disabled bool +} + +// Config holds all application configuration. +type Config struct { + Telemetry Telemetry +} + +var ( + once sync.Once + config *Config +) + +// LoadOtelCfg initializes and returns the application configuration. 
+func LoadOtelCfg() *Config { + once.Do(func() { + config = &Config{ + Telemetry: Telemetry{ + ServiceName: getEnv("OTEL_SERVICE_NAME", "kagent-tools"), + ServiceVersion: getEnv("OTEL_SERVICE_VERSION", "dev"), + Environment: getEnv("OTEL_ENVIRONMENT", "development"), + Endpoint: getEnv("OTEL_EXPORTER_OTLP_ENDPOINT", ""), + Protocol: getEnv("OTEL_EXPORTER_OTLP_PROTOCOL", "auto"), + SamplingRatio: getEnvFloat("OTEL_TRACES_SAMPLER_ARG", 1.0), + Insecure: getEnvBool("OTEL_EXPORTER_OTLP_TRACES_INSECURE", false), + Disabled: getEnvBool("OTEL_SDK_DISABLED", false), + }, + } + }) + return config +} + +func getEnv(key, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + return fallback +} + +func getEnvFloat(key string, fallback float64) float64 { + if valueStr, ok := os.LookupEnv(key); ok { + if value, err := strconv.ParseFloat(valueStr, 64); err == nil { + return value + } + } + return fallback +} + +func getEnvBool(key string, fallback bool) bool { + if valueStr, ok := os.LookupEnv(key); ok { + if value, err := strconv.ParseBool(strings.ToLower(valueStr)); err == nil { + return value + } + } + return fallback +} diff --git a/internal/telemetry/config_test.go b/internal/telemetry/config_test.go new file mode 100644 index 0000000..fe6454b --- /dev/null +++ b/internal/telemetry/config_test.go @@ -0,0 +1,49 @@ +package telemetry + +import ( + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoad(t *testing.T) { + // Reset singleton for testing + once = sync.Once{} + config = nil + + os.Setenv("OTEL_SERVICE_NAME", "test-service") + os.Setenv("OTEL_EXPORTER_OTLP_TRACES_INSECURE", "true") + defer func() { + os.Unsetenv("OTEL_SERVICE_NAME") + os.Unsetenv("OTEL_EXPORTER_OTLP_TRACES_INSECURE") + }() + + cfg := LoadOtelCfg() + assert.Equal(t, "test-service", cfg.Telemetry.ServiceName) + assert.True(t, cfg.Telemetry.Insecure) +} + +func TestLoadDefaults(t *testing.T) { + // Reset singleton for testing + once = 
sync.Once{} + config = nil + + cfg := LoadOtelCfg() + assert.Equal(t, "kagent-tools", cfg.Telemetry.ServiceName) + assert.False(t, cfg.Telemetry.Insecure) + assert.Equal(t, 1.0, cfg.Telemetry.SamplingRatio) +} + +func TestLoadDevelopmentSampling(t *testing.T) { + // Reset singleton for testing + once = sync.Once{} + config = nil + + os.Setenv("OTEL_ENVIRONMENT", "development") + defer os.Unsetenv("OTEL_ENVIRONMENT") + + cfg := LoadOtelCfg() + assert.Equal(t, 1.0, cfg.Telemetry.SamplingRatio) +} diff --git a/internal/telemetry/middleware.go b/internal/telemetry/middleware.go new file mode 100644 index 0000000..720a99b --- /dev/null +++ b/internal/telemetry/middleware.go @@ -0,0 +1,179 @@ +package telemetry + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" +) + +type ToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// contextKey is used for storing HTTP context in the request context +type contextKey string + +const ( + HTTPHeadersKey contextKey = "http_headers" + TraceIDKey contextKey = "trace_id" + SpanIDKey contextKey = "span_id" +) + +// HTTPMiddleware wraps an HTTP handler to extract headers and propagate context +func HTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract OpenTelemetry context from HTTP headers + propagator := otel.GetTextMapPropagator() + ctx = propagator.Extract(ctx, propagation.HeaderCarrier(r.Header)) + + // Store relevant HTTP headers in context for tool handlers + headers := make(map[string]string) + for name, values := range r.Header { + if len(values) > 0 { + // Store important headers for debugging/tracing + 
switch name { + case "X-Request-ID", "X-Correlation-ID", "X-Trace-ID", + "User-Agent", "Authorization", "X-Forwarded-For": + headers[name] = values[0] + } + } + } + + // Add headers to context + ctx = context.WithValue(ctx, HTTPHeadersKey, headers) + + // Extract trace information if available + span := trace.SpanFromContext(ctx) + if span.SpanContext().HasTraceID() { + ctx = context.WithValue(ctx, TraceIDKey, span.SpanContext().TraceID().String()) + ctx = context.WithValue(ctx, SpanIDKey, span.SpanContext().SpanID().String()) + } + + // Call next handler with enhanced context + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// ExtractHTTPHeaders retrieves HTTP headers from context +func ExtractHTTPHeaders(ctx context.Context) map[string]string { + if headers, ok := ctx.Value(HTTPHeadersKey).(map[string]string); ok { + return headers + } + return make(map[string]string) +} + +// ExtractTraceInfo retrieves trace information from context +func ExtractTraceInfo(ctx context.Context) (traceID, spanID string) { + if tid, ok := ctx.Value(TraceIDKey).(string); ok { + traceID = tid + } + if sid, ok := ctx.Value(SpanIDKey).(string); ok { + spanID = sid + } + return traceID, spanID +} + +func WithTracing(toolName string, handler ToolHandler) ToolHandler { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + tracer := otel.Tracer("kagent-tools/mcp") + + spanName := fmt.Sprintf("mcp.tool.%s", toolName) + ctx, span := tracer.Start(ctx, spanName) + defer span.End() + + // Extract HTTP headers from context and add as span attributes + headers := ExtractHTTPHeaders(ctx) + for key, value := range headers { + span.SetAttributes(attribute.String(fmt.Sprintf("http.header.%s", key), value)) + } + + // Extract parent trace information + parentTraceID, parentSpanID := ExtractTraceInfo(ctx) + if parentTraceID != "" { + span.SetAttributes( + attribute.String("http.parent_trace_id", parentTraceID), + attribute.String("http.parent_span_id", 
parentSpanID), + ) + } + + span.SetAttributes( + attribute.String("mcp.tool.name", toolName), + attribute.String("mcp.request.id", request.Params.Name), + ) + + if request.Params.Arguments != nil { + if argsJSON, err := json.Marshal(request.Params.Arguments); err == nil { + span.SetAttributes(attribute.String("mcp.request.arguments", string(argsJSON))) + } + } + + span.AddEvent("tool.execution.start") + startTime := time.Now() + + result, err := handler(ctx, request) + + duration := time.Since(startTime) + span.SetAttributes(attribute.Float64("mcp.tool.duration_seconds", duration.Seconds())) + + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + span.AddEvent("tool.execution.error", trace.WithAttributes( + attribute.String("error.message", err.Error()), + )) + } else { + span.SetStatus(codes.Ok, "tool execution completed successfully") + span.AddEvent("tool.execution.success") + + if result != nil { + span.SetAttributes(attribute.Bool("mcp.result.is_error", result.IsError)) + if result.Content != nil { + span.SetAttributes(attribute.Int("mcp.result.content_count", len(result.Content))) + } + } + } + + return result, err + } +} + +func StartSpan(ctx context.Context, operationName string, attrs ...attribute.KeyValue) (context.Context, trace.Span) { + tracer := otel.Tracer("kagent-tools") + ctx, span := tracer.Start(ctx, operationName) + + if len(attrs) > 0 { + span.SetAttributes(attrs...) + } + + return ctx, span +} + +func RecordError(span trace.Span, err error, message string) { + span.RecordError(err) + span.SetStatus(codes.Error, message) +} + +func RecordSuccess(span trace.Span, message string) { + span.SetStatus(codes.Ok, message) +} + +func AddEvent(span trace.Span, name string, attrs ...attribute.KeyValue) { + span.AddEvent(name, trace.WithAttributes(attrs...)) +} + +// AdaptToolHandler adapts a telemetry.ToolHandler to a server.ToolHandlerFunc. 
+func AdaptToolHandler(th ToolHandler) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return th(ctx, req) + } +} diff --git a/internal/telemetry/middleware_test.go b/internal/telemetry/middleware_test.go new file mode 100644 index 0000000..bcbf494 --- /dev/null +++ b/internal/telemetry/middleware_test.go @@ -0,0 +1,801 @@ +package telemetry + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// InMemoryExporter is a simple in-memory exporter for testing +type InMemoryExporter struct { + spans []trace.ReadOnlySpan +} + +func (e *InMemoryExporter) ExportSpans(ctx context.Context, spans []trace.ReadOnlySpan) error { + e.spans = append(e.spans, spans...) 
+ return nil +} + +func (e *InMemoryExporter) Shutdown(ctx context.Context) error { + return nil +} + +func (e *InMemoryExporter) GetSpans() []trace.ReadOnlySpan { + return e.spans +} + +// setupTracing initializes OpenTelemetry with in-memory exporter for testing +func setupTracing() (*trace.TracerProvider, *InMemoryExporter) { + exporter := &InMemoryExporter{} + provider := trace.NewTracerProvider( + trace.WithSampler(trace.AlwaysSample()), + trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter)), + ) + otel.SetTracerProvider(provider) + return provider, exporter +} + +func TestWithTracing(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]interface{}{ + "param1": "value1", + "param2": 42, + }, + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify result + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Len(t, result.Content, 1) + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + assert.Equal(t, "test response", textContent.Text) + + // Verify span was created + spans := exporter.GetSpans() 
+ assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + assert.Equal(t, codes.Ok, span.Status().Code) + // Note: SDK may not preserve description in test environment + // assert.Equal(t, "tool execution completed successfully", span.Status().Description) + + // Verify attributes + attributes := span.Attributes() + hasToolName := false + hasRequestID := false + hasIsError := false + hasContentCount := false + + for _, attr := range attributes { + if attr.Key == "mcp.tool.name" && attr.Value.AsString() == "test-tool" { + hasToolName = true + } + if attr.Key == "mcp.request.id" && attr.Value.AsString() == "test-tool" { + hasRequestID = true + } + if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == false { + hasIsError = true + } + if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 { + hasContentCount = true + } + } + + assert.True(t, hasToolName) + assert.True(t, hasRequestID) + assert.True(t, hasIsError) + assert.True(t, hasContentCount) + + // Verify events + events := span.Events() + assert.Len(t, events, 2) + assert.Equal(t, "tool.execution.start", events[0].Name) + assert.Equal(t, "tool.execution.success", events[1].Name) +} + +func TestWithTracingError(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler that returns an error + testError := errors.New("test error") + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, testError + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + 
// Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify result + assert.Error(t, err) + assert.Equal(t, testError, err) + assert.Nil(t, result) + + // Verify span was created with error + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + assert.Equal(t, codes.Error, span.Status().Code) + // Note: SDK may not preserve description in test environment + // assert.Equal(t, "test error", span.Status().Description) + + // Verify events - span.RecordError() adds an "exception" event, plus our custom events + events := span.Events() + assert.Len(t, events, 3) + assert.Equal(t, "tool.execution.start", events[0].Name) + assert.Equal(t, "exception", events[1].Name) // Added by span.RecordError() + assert.Equal(t, "tool.execution.error", events[2].Name) +} + +func TestWithTracingErrorResult(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler that returns an error result + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("error occurred") + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + 
// Verify result + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.IsError) + + // Verify span was created successfully (no error from handler) + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + assert.Equal(t, codes.Ok, span.Status().Code) + + // Verify attributes + attributes := span.Attributes() + hasIsError := false + hasContentCount := false + + for _, attr := range attributes { + if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == true { + hasIsError = true + } + if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 { + hasContentCount = true + } + } + + assert.True(t, hasIsError) + assert.True(t, hasContentCount) +} + +func TestWithTracingWithArguments(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request with arguments + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]interface{}{ + "string_param": "hello", + "number_param": 42, + "bool_param": true, + "array_param": []interface{}{"a", "b", "c"}, + "object_param": map[string]interface{}{ + "nested": "value", + }, + }, + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush 
provider: %v", err) + } + + // Verify result + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + + // Verify that arguments were added as an attribute (they are JSON-encoded) + attributes := span.Attributes() + hasArguments := false + + for _, attr := range attributes { + if attr.Key == "mcp.request.arguments" { + hasArguments = true + // Arguments should be JSON-encoded + assert.NotEmpty(t, attr.Value.AsString()) + } + } + + assert.True(t, hasArguments) +} + +func TestWithTracingNilArguments(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request without arguments + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify result + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) +} + +func TestStartSpan(t *testing.T) { 
+ // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span + _, span := StartSpan(context.Background(), "test-span", + attribute.String("key1", "value1"), + attribute.Int("key2", 42), + ) + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) +} + +func TestStartSpanNoAttributes(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span without attributes + _, span := StartSpan(context.Background(), "test-span") + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) +} + +func TestRecordError(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span + _, span := StartSpan(context.Background(), "test-span") + + // Record an error + testError := errors.New("test error") + RecordError(span, testError, "test error") + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := 
provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created with error + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) + assert.Equal(t, codes.Error, resultSpan.Status().Code) + assert.Equal(t, "test error", resultSpan.Status().Description) +} + +func TestRecordSuccess(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span + _, span := StartSpan(context.Background(), "test-span") + + // Record success + RecordSuccess(span, "operation completed successfully") + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created with success + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) + assert.Equal(t, codes.Ok, resultSpan.Status().Code) + // Note: SDK may not preserve description in test environment + // assert.Equal(t, "operation completed successfully", resultSpan.Status().Description) +} + +func TestAddEvent(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span + _, span := StartSpan(context.Background(), "test-span") + + // Add an event + AddEvent(span, "test-event", + attribute.String("event_key", "event_value"), + attribute.Int("event_num", 123), + ) + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := 
provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created with event + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) + + // Verify event + events := resultSpan.Events() + assert.Len(t, events, 1) + assert.Equal(t, "test-event", events[0].Name) +} + +func TestAddEventNoAttributes(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Start a span + _, span := StartSpan(context.Background(), "test-span") + + // Add an event without attributes + AddEvent(span, "test-event") + + // End the span + span.End() + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify span was created with event + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + resultSpan := spans[0] + assert.Equal(t, "test-span", resultSpan.Name()) + + // Verify event + events := resultSpan.Events() + assert.Len(t, events, 1) + assert.Equal(t, "test-event", events[0].Name) +} + +func TestAdaptToolHandler(t *testing.T) { + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Adapt the handler + adapted := AdaptToolHandler(testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the adapted handler + result, err := adapted(context.Background(), request) + + // Verify result + require.NoError(t, err) + 
assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Len(t, result.Content, 1) + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + assert.Equal(t, "test response", textContent.Text) +} + +func TestWithTracingNilResult(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler that returns nil result + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify result + require.NoError(t, err) + assert.Nil(t, result) + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + assert.Equal(t, codes.Ok, span.Status().Code) +} + +func TestWithTracingNoContent(t *testing.T) { + // Initialize OpenTelemetry + provider, exporter := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler that returns result with no content + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{}, + }, nil + } + + // Wrap with tracing + tracedHandler := 
WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Force flush to ensure spans are exported + if err := provider.ForceFlush(context.Background()); err != nil { + t.Errorf("Failed to flush provider: %v", err) + } + + // Verify result + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Len(t, result.Content, 0) + + // Verify span was created + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "mcp.tool.test-tool", span.Name()) + assert.Equal(t, codes.Ok, span.Status().Code) + + // Verify attributes + attributes := span.Attributes() + hasContentCount := false + + for _, attr := range attributes { + if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 0 { + hasContentCount = true + } + } + + assert.True(t, hasContentCount) +} + +func TestWithTracingNoopTracer(t *testing.T) { + // Set up noop tracer provider + otel.SetTracerProvider(noop.NewTracerProvider()) + + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Execute the handler + result, err := tracedHandler(context.Background(), request) + + // Verify result (should work normally with noop tracer) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Len(t, result.Content, 1) + textContent, ok := mcp.AsTextContent(result.Content[0]) + 
require.True(t, ok) + assert.Equal(t, "test response", textContent.Text) +} + +func TestWithTracingPerformance(t *testing.T) { + // Initialize OpenTelemetry + provider, _ := setupTracing() + defer func() { + if err := provider.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown provider: %v", err) + } + }() + + // Create a test handler + testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + textContent := mcp.NewTextContent("test response") + return &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{textContent}, + }, nil + } + + // Wrap with tracing + tracedHandler := WithTracing("test-tool", testHandler) + + // Create test request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test-tool", + }, + } + + // Time execution + start := time.Now() + for i := 0; i < 100; i++ { + _, err := tracedHandler(context.Background(), request) + require.NoError(t, err) + } + duration := time.Since(start) + + // Verify performance is reasonable (should complete in less than 1 second) + assert.Less(t, duration, time.Second) +} diff --git a/internal/telemetry/tracing.go b/internal/telemetry/tracing.go new file mode 100644 index 0000000..6b6f720 --- /dev/null +++ b/internal/telemetry/tracing.go @@ -0,0 +1,282 @@ +package telemetry + +import ( + "context" + "fmt" + "net/url" + "os" + "strings" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.32.0" + "go.opentelemetry.io/otel/trace/noop" + + "github.com/kagent-dev/tools/internal/logger" +) + +// Standard OpenTelemetry environment 
variable names +// These follow the official OTLP specification +const ( + // Service identification + OtelServiceName = "OTEL_SERVICE_NAME" + OtelServiceVersion = "OTEL_SERVICE_VERSION" + OtelEnvironment = "OTEL_ENVIRONMENT" // Custom extension, not in official spec + + // OTLP Exporter configuration + OtelExporterOtlpEndpoint = "OTEL_EXPORTER_OTLP_ENDPOINT" + OtelExporterOtlpProtocol = "OTEL_EXPORTER_OTLP_PROTOCOL" + OtelExporterOtlpHeaders = "OTEL_EXPORTER_OTLP_HEADERS" + + // Trace-specific OTLP configuration + OtelExporterOtlpTracesInsecure = "OTEL_EXPORTER_OTLP_TRACES_INSECURE" + + // Sampling configuration + OtelTracesSamplerArg = "OTEL_TRACES_SAMPLER_ARG" + + // SDK control + OtelSdkDisabled = "OTEL_SDK_DISABLED" +) + +// OTLP Protocol constants +const ( + ProtocolGRPC = "grpc" + ProtocolHTTP = "http/protobuf" + ProtocolAuto = "auto" // Custom extension for automatic protocol detection +) + +// Standard OTLP port numbers +// These are the official OTLP default ports as per OpenTelemetry specification +const ( + DefaultOtlpGrpcPort = "4317" // Standard OTLP/gRPC port + DefaultOtlpHttpPort = "4318" // Standard OTLP/HTTP port +) + +// Default endpoint paths +const ( + DefaultHttpTracesPath = "/v1/traces" +) + +// SetupOTelSDK initializes the OpenTelemetry SDK +func SetupOTelSDK(ctx context.Context) error { + log := logger.WithContext(ctx) + cfg := LoadOtelCfg() + telemetryConfig := cfg.Telemetry + + // If tracing is disabled, set a no-op tracer provider and return. + // This prevents further initialization and ensures no traces are exported. + if cfg.Telemetry.Disabled { + otel.SetTracerProvider(noop.NewTracerProvider()) + return nil + } + + res, err := resource.New(ctx, + resource.WithDetectors(), // Detectors for cloud provider, k8s, etc. 
+ resource.WithAttributes( + semconv.ServiceNameKey.String(telemetryConfig.ServiceName), + semconv.ServiceVersionKey.String(telemetryConfig.ServiceVersion), + attribute.String("deployment.environment", telemetryConfig.Environment), + ), + ) + if err != nil { + log.Error("failed to create resource", "error", err) + return fmt.Errorf("failed to create resource: %w", err) + } + + // Set up propagator + prop := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + otel.SetTextMapPropagator(prop) + + exporter, err := createExporter(ctx, &telemetryConfig) + if err != nil { + log.Error("failed to create exporter", "error", err) + return fmt.Errorf("failed to create exporter: %w", err) + } + + // Set up trace provider + tracerProvider, err := newTracerProvider(ctx, &telemetryConfig, exporter, res) + if err != nil { + log.Error("failed to create tracer provider", "error", err) + return fmt.Errorf("failed to create tracer provider: %w", err) + } + otel.SetTracerProvider(tracerProvider) + + log.Info("OpenTelemetry SDK successfully initialized") + //start goroutine and wait for ctx cancellation + go func() { + <-ctx.Done() + if err := tracerProvider.Shutdown(ctx); err != nil { + log.Error("failed to shutdown tracer provider", "error", err) + } else { + log.Info("OpenTelemetry SDK shutdown successfully") + } + }() + return nil +} + +// newTracerProvider creates a new trace provider +func newTracerProvider(ctx context.Context, cfg *Telemetry, exporter sdktrace.SpanExporter, res *resource.Resource) (*sdktrace.TracerProvider, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + sampler := sdktrace.AlwaysSample() + + tp := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sampler), + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(res), + ) + return tp, nil +} + +// createExporter creates a OTLP exporter +func createExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) { + if err := 
ctx.Err(); err != nil { + return nil, err + } + + if cfg.Endpoint == "" { + return stdouttrace.New(stdouttrace.WithPrettyPrint()) + } + + // Determine protocol + protocol := cfg.Protocol + if protocol == ProtocolAuto || protocol == "" { + protocol = detectProtocol(cfg.Endpoint) + } + + switch strings.ToLower(protocol) { + case ProtocolGRPC: + return createGRPCExporter(ctx, cfg) + case ProtocolHTTP: + return createHTTPExporter(ctx, cfg) + default: + return nil, fmt.Errorf("unsupported protocol: %s (supported: %s, %s)", protocol, ProtocolGRPC, ProtocolHTTP) + } +} + +// detectProtocol determines the protocol based on the endpoint URL +func detectProtocol(endpoint string) string { + // Parse URL to extract port + if parsedURL, err := url.Parse(endpoint); err == nil { + port := parsedURL.Port() + if port == "" { + // Check for default ports in hostname + if strings.Contains(parsedURL.Host, ":"+DefaultOtlpGrpcPort) { + return ProtocolGRPC + } + if strings.Contains(parsedURL.Host, ":"+DefaultOtlpHttpPort) { + return ProtocolHTTP + } + } else { + switch port { + case DefaultOtlpGrpcPort: + return ProtocolGRPC + case DefaultOtlpHttpPort: + return ProtocolHTTP + } + } + } + + // Check if endpoint contains port info directly + if strings.Contains(endpoint, ":"+DefaultOtlpGrpcPort) { + return ProtocolGRPC + } + if strings.Contains(endpoint, ":"+DefaultOtlpHttpPort) { + return ProtocolHTTP + } + + // Default to HTTP for backward compatibility + return ProtocolHTTP +} + +// createGRPCExporter creates a gRPC OTLP exporter +func createGRPCExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) { + opts := []otlptracegrpc.Option{ + otlptracegrpc.WithEndpoint(normalizeGRPCEndpoint(cfg.Endpoint)), + otlptracegrpc.WithTimeout(30 * time.Second), + } + + // Use insecure connection if explicitly configured + if cfg.Insecure { + opts = append(opts, otlptracegrpc.WithInsecure()) + } + + if authToken := os.Getenv(OtelExporterOtlpHeaders); authToken != "" { + opts = 
append(opts, otlptracegrpc.WithHeaders(parseHeaders(authToken))) + } + + return otlptracegrpc.New(ctx, opts...) +} + +// createHTTPExporter creates an HTTP OTLP exporter +func createHTTPExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) { + opts := []otlptracehttp.Option{ + otlptracehttp.WithEndpointURL(normalizeHTTPEndpoint(cfg.Endpoint, cfg.Insecure)), + otlptracehttp.WithTimeout(30 * time.Second), + } + + // Use insecure connection if explicitly configured + if cfg.Insecure { + opts = append(opts, otlptracehttp.WithInsecure()) + } + + if authToken := os.Getenv(OtelExporterOtlpHeaders); authToken != "" { + opts = append(opts, otlptracehttp.WithHeaders(parseHeaders(authToken))) + } + + return otlptracehttp.New(ctx, opts...) +} + +// normalizeGRPCEndpoint normalizes the endpoint for gRPC usage +func normalizeGRPCEndpoint(endpoint string) string { + if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { + return endpoint + } + + u, err := url.Parse(endpoint) + if err != nil { + return endpoint // Should not happen with the check above, but as a safeguard + } + + return u.Host + u.Path +} + +// normalizeHTTPEndpoint normalizes the endpoint for HTTP usage +func normalizeHTTPEndpoint(endpoint string, insecure bool) string { + // Ensure we have a proper HTTP URL + if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { + // Use HTTP if insecure is true or if endpoint contains localhost/127.0.0.1/docker.internal + if insecure || strings.Contains(endpoint, "localhost") || strings.Contains(endpoint, "127.0.0.1") || strings.Contains(endpoint, "docker.internal") { + endpoint = "http://" + endpoint + } else { + endpoint = "https://" + endpoint + } + } + + // Add /v1/traces suffix if not present + if !strings.HasSuffix(endpoint, DefaultHttpTracesPath) { + endpoint = strings.TrimSuffix(endpoint, "/") + DefaultHttpTracesPath + } + + return endpoint +} + +// parseHeaders parses a 
comma-separated string of headers into a map +func parseHeaders(headers string) map[string]string { + headerMap := make(map[string]string) + for _, h := range strings.Split(headers, ",") { + if parts := strings.SplitN(h, "=", 2); len(parts) == 2 { + headerMap[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + } + return headerMap +} diff --git a/internal/telemetry/tracing_test.go b/internal/telemetry/tracing_test.go new file mode 100644 index 0000000..f26f3bd --- /dev/null +++ b/internal/telemetry/tracing_test.go @@ -0,0 +1,374 @@ +package telemetry + +import ( + "context" + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/sdk/resource" + "go.opentelemetry.io/otel/trace/noop" +) + +// Test protocol constants for additional test scenarios +const ( + ProtocolInvalid = "invalid" +) + +// resetConfig is a helper to reset the singleton config for tests +func resetConfig() { + once = sync.Once{} + config = nil +} + +func TestSetupOTelSDK_Disabled(t *testing.T) { + resetConfig() + ctx := context.Background() + err := os.Setenv("OTEL_SDK_DISABLED", "true") + require.NoError(t, err) + defer func() { + _ = os.Unsetenv("OTEL_SDK_DISABLED") + }() + resetConfig() + + err = SetupOTelSDK(ctx) + require.NoError(t, err) + + // In a disabled state, the tracer provider should be a no-op provider + tp := otel.GetTracerProvider() + assert.IsType(t, noop.NewTracerProvider(), tp) + + // Shutdown should be a no-op function + assert.NoError(t, err) +} + +func TestSetupOTelSDKEnabled(t *testing.T) { + resetConfig() + ctx := context.Background() + err := os.Setenv(OtelSdkDisabled, "false") + require.NoError(t, err) + defer func() { + _ = os.Unsetenv(OtelSdkDisabled) + }() + + err = SetupOTelSDK(ctx) + require.NoError(t, err) +} + +func TestNewTracerProviderDevelopment(t *testing.T) { + resetConfig() + ctx := 
context.Background() + res := resource.NewSchemaless() + cfg := &Telemetry{ + Environment: "development", + } + exporter, _ := stdouttrace.New() + + tp, err := newTracerProvider(ctx, cfg, exporter, res) + require.NoError(t, err) + assert.NotNil(t, tp) +} + +func TestNewTracerProviderProduction(t *testing.T) { + resetConfig() + ctx := context.Background() + res := resource.NewSchemaless() + cfg := &Telemetry{ + Environment: "production", + SamplingRatio: 0.5, + } + exporter, _ := stdouttrace.New() + + tp, err := newTracerProvider(ctx, cfg, exporter, res) + require.NoError(t, err) + assert.NotNil(t, tp) +} + +func TestCreateExporterDevelopment(t *testing.T) { + resetConfig() + ctx := context.Background() + cfg := &Telemetry{ + Environment: "development", + } + + exporter, err := createExporter(ctx, cfg) + require.NoError(t, err) + assert.NotNil(t, exporter) + assert.IsType(t, &stdouttrace.Exporter{}, exporter) +} + +func TestCreateExporterNoEndpoint(t *testing.T) { + resetConfig() + ctx := context.Background() + cfg := &Telemetry{ + Environment: "production", + } + + exporter, err := createExporter(ctx, cfg) + require.NoError(t, err) + assert.NotNil(t, exporter) + assert.IsType(t, &stdouttrace.Exporter{}, exporter) +} + +func TestCreateExporterWithEndpoint(t *testing.T) { + resetConfig() + ctx := context.Background() + cfg := &Telemetry{ + Environment: "production", + Endpoint: "http://localhost:4317", + Protocol: ProtocolAuto, + } + + exporter, err := createExporter(ctx, cfg) + require.NoError(t, err) + assert.NotNil(t, exporter) +} + +func TestCreateExporterWithInsecure(t *testing.T) { + resetConfig() + ctx := context.Background() + cfg := &Telemetry{ + Environment: "production", + Endpoint: "localhost:4317", + Insecure: true, + } + + // This should not fail, as insecure is handled by the exporters + _, err := createExporter(ctx, cfg) + require.NoError(t, err) +} + +func TestCreateExporterWithAuthHeaders(t *testing.T) { + resetConfig() + ctx := context.Background() 
+ cfg := &Telemetry{ + Environment: "production", + Endpoint: "http://localhost:4317", + Protocol: ProtocolAuto, + } + + // Set auth header + err := os.Setenv(OtelExporterOtlpHeaders, "Authorization=Bearer token123") + require.NoError(t, err) + defer func() { + _ = os.Unsetenv(OtelExporterOtlpHeaders) + }() + + exporter, err := createExporter(ctx, cfg) + require.NoError(t, err) + assert.NotNil(t, exporter) + + // Clean up + err = exporter.Shutdown(ctx) + assert.NoError(t, err) +} + +func TestSetupOTelSDKWithCancellation(t *testing.T) { + resetConfig() + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel context immediately + + err := SetupOTelSDK(ctx) + require.Error(t, err) // Expect an error due to context cancellation +} + +func TestProtocolDetection(t *testing.T) { + tests := []struct { + name string + endpoint string + expected string + }{ + {"gRPC port 4317", "localhost:4317", ProtocolGRPC}, + {"HTTP port 4318", "localhost:4318", ProtocolHTTP}, + {"gRPC port 4317 without scheme", "localhost:4317", ProtocolGRPC}, + {"HTTP port 4318 without scheme", "localhost:4318", ProtocolHTTP}, + {"gRPC with docker internal", "host.docker.internal:4317", ProtocolGRPC}, + {"HTTP with docker internal", "host.docker.internal:4318", ProtocolHTTP}, + {"No port specified", "localhost", ProtocolHTTP}, + {"Unknown port", "localhost:1234", ProtocolHTTP}, + {"HTTPS with gRPC port", "https://localhost:4317", ProtocolGRPC}, + {"HTTPS with HTTP port", "https://localhost:4318", ProtocolHTTP}, + {"gRPC with path", "localhost:4317/v1/traces", ProtocolGRPC}, + {"HTTP with path", "localhost:4318/v1/traces", ProtocolHTTP}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := detectProtocol(tt.endpoint) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestEndpointNormalization(t *testing.T) { + tests := []struct { + name string + endpoint string + expected string + }{ + {"Basic gRPC endpoint", "localhost:4317", 
"localhost:4317"}, + {"gRPC with path", "localhost:4317/v1/traces", "localhost:4317/v1/traces"}, + {"gRPC without scheme", "localhost:4317", "localhost:4317"}, + {"gRPC with HTTPS", "https://localhost:4317", "localhost:4317"}, + {"Docker internal gRPC", "host.docker.internal:4317", "host.docker.internal:4317"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeGRPCEndpoint(tt.endpoint) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHTTPEndpointNormalization(t *testing.T) { + tests := []struct { + name string + endpoint string + insecure bool + expected string + }{ + {"Basic HTTP endpoint", "http://localhost:4318", false, "http://localhost:4318/v1/traces"}, + {"HTTP with path", "http://localhost:4318/v1/traces", false, "http://localhost:4318/v1/traces"}, + {"HTTP without scheme - secure localhost", "localhost:4318", false, "http://localhost:4318/v1/traces"}, + {"HTTP without scheme - insecure localhost", "localhost:4318", true, "http://localhost:4318/v1/traces"}, + {"HTTP with trailing slash", "http://localhost:4318/", false, "http://localhost:4318/v1/traces"}, + {"Docker internal HTTP - secure", "host.docker.internal:4318", false, "http://host.docker.internal:4318/v1/traces"}, + {"Docker internal HTTP - insecure", "host.docker.internal:4318", true, "http://host.docker.internal:4318/v1/traces"}, + {"Remote endpoint - secure", "collector.example.com:4318", false, "https://collector.example.com:4318/v1/traces"}, + {"Remote endpoint - insecure", "collector.example.com:4318", true, "http://collector.example.com:4318/v1/traces"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeHTTPEndpoint(tt.endpoint, tt.insecure) + assert.Equal(t, tt.expected, result, "HTTP endpoint normalization failed for: %s", tt.endpoint) + }) + } +} + +func TestParseHeaders(t *testing.T) { + tests := []struct { + name string + headers string + want map[string]string + }{ + {"Empty string", "", 
map[string]string{}}, + {"Single header", "key=value", map[string]string{"key": "value"}}, + {"Multiple headers", "key1=value1,key2=value2", map[string]string{"key1": "value1", "key2": "value2"}}, + {"Headers with spaces", " key1 = value1 , key2 = value2 ", map[string]string{"key1": "value1", "key2": "value2"}}, + {"Invalid header format", "key-value,key2", map[string]string{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseHeaders(tt.headers) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCreateExporterWithProtocol(t *testing.T) { + + ctx := context.Background() + + tests := []struct { + name string + config *Telemetry + shouldError bool + description string + }{ + { + "gRPC protocol", + &Telemetry{ + Environment: "development", + Endpoint: "localhost:4317", + Protocol: ProtocolGRPC, + }, + false, + "Should create gRPC exporter", + }, + { + "HTTP protocol", + &Telemetry{ + Environment: "development", + Endpoint: "localhost:4318", + Protocol: ProtocolHTTP, + }, + false, + "Should create HTTP exporter", + }, + { + "Auto protocol with gRPC port", + &Telemetry{ + Environment: "development", + Endpoint: "localhost:4317", + Protocol: ProtocolAuto, + }, + false, + "Should auto-detect gRPC", + }, + { + "Auto protocol with HTTP port", + &Telemetry{ + Environment: "development", + Endpoint: "localhost:4318", + Protocol: ProtocolAuto, + }, + false, + "Should auto-detect HTTP", + }, + { + "gRPC protocol with insecure", + &Telemetry{ + Environment: "production", + Endpoint: "localhost:4317", + Protocol: ProtocolGRPC, + Insecure: true, + }, + false, + "Should create gRPC exporter with insecure", + }, + { + "HTTP protocol with insecure", + &Telemetry{ + Environment: "production", + Endpoint: "localhost:4318", + Protocol: ProtocolHTTP, + Insecure: true, + }, + false, + "Should create HTTP exporter with insecure", + }, + { + "Invalid protocol", + &Telemetry{ + Environment: "development", + Endpoint: "localhost:1234", + Protocol: 
ProtocolInvalid, + }, + true, + "Should return error for invalid protocol", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetConfig() + exporter, err := createExporter(ctx, tt.config) + if tt.shouldError { + require.Error(t, err, tt.description) + assert.Nil(t, exporter, tt.description) + } else { + require.NoError(t, err, tt.description) + assert.NotNil(t, exporter, tt.description) + err = exporter.Shutdown(ctx) + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/argo/argo.go b/pkg/argo/argo.go index 566a4a0..a24e978 100644 --- a/pkg/argo/argo.go +++ b/pkg/argo/argo.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -20,8 +22,6 @@ import ( // Argo Rollouts tools -var kubeConfig = "" - func handleVerifyArgoRolloutsControllerInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { ns := mcp.ParseString(request, "namespace", "argo-rollouts") label := mcp.ParseString(request, "label", "app.kubernetes.io/component=rollouts-controller") @@ -76,10 +76,11 @@ func handleVerifyKubectlPluginInstall(ctx context.Context, request mcp.CallToolR } func runArgoRolloutCommand(ctx context.Context, args []string) (string, error) { - if kubeConfig != "" { - args = append(args, "--kubeconfig", kubeConfig) - } - return utils.RunCommandWithContext(ctx, "kubectl", args) + kubeconfigPath := utils.GetKubeconfig() + return commands.NewCommandBuilder("kubectl"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). 
+ Execute(ctx) } func handlePromoteRollout(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -198,9 +199,13 @@ func getSystemArchitecture() (string, error) { } } -func getLatestVersion() string { +func getLatestVersion(ctx context.Context) string { client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Get("https://api.github.com/repos/argoproj-labs/rollouts-plugin-trafficrouter-gatewayapi/releases/latest") + req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/argoproj-labs/rollouts-plugin-trafficrouter-gatewayapi/releases/latest", nil) + if err != nil { + return "0.5.0" // Default version + } + resp, err := client.Do(req) if err != nil { return "0.5.0" // Default version } @@ -220,7 +225,7 @@ func getLatestVersion() string { return "0.5.0" } -func configureGatewayPlugin(version, namespace string) GatewayPluginStatus { +func configureGatewayPlugin(ctx context.Context, version, namespace string) GatewayPluginStatus { arch, err := getSystemArchitecture() if err != nil { return GatewayPluginStatus{ @@ -230,7 +235,7 @@ func configureGatewayPlugin(version, namespace string) GatewayPluginStatus { } if version == "" { - version = getLatestVersion() + version = getLatestVersion(ctx) } configMap := fmt.Sprintf(`apiVersion: v1 @@ -263,11 +268,12 @@ data: tmpFile.Close() // Apply the ConfigMap - _, err = utils.RunCommandWithContext(context.Background(), "kubectl", []string{"apply", "-f", tmpFile.Name()}) + cmdArgs := []string{"apply", "-f", tmpFile.Name()} + output, err := runArgoRolloutCommand(ctx, cmdArgs) if err != nil { return GatewayPluginStatus{ Installed: false, - ErrorMessage: fmt.Sprintf("Failed to configure Gateway API plugin: %s", err.Error()), + ErrorMessage: fmt.Sprintf("Error applying Gateway API plugin config: %s. 
Output: %s", err.Error(), output), } } @@ -304,7 +310,7 @@ func handleVerifyGatewayPlugin(ctx context.Context, request mcp.CallToolRequest) } // Configure plugin - status := configureGatewayPlugin(version, namespace) + status := configureGatewayPlugin(ctx, version, namespace) return mcp.NewToolResultText(status.String()), nil } @@ -348,48 +354,47 @@ func handleCheckPluginLogs(ctx context.Context, request mcp.CallToolRequest) (*m return mcp.NewToolResultText(status.String()), nil } -func RegisterArgoTools(s *server.MCPServer, kubeconfig string) { - kubeConfig = kubeconfig +func RegisterTools(s *server.MCPServer) { s.AddTool(mcp.NewTool("argo_verify_argo_rollouts_controller_install", mcp.WithDescription("Verify that the Argo Rollouts controller is installed and running"), mcp.WithString("namespace", mcp.Description("The namespace where Argo Rollouts is installed")), mcp.WithString("label", mcp.Description("The label of the Argo Rollouts controller pods")), - ), handleVerifyArgoRolloutsControllerInstall) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_argo_rollouts_controller_install", handleVerifyArgoRolloutsControllerInstall))) s.AddTool(mcp.NewTool("argo_verify_kubectl_plugin_install", mcp.WithDescription("Verify that the kubectl Argo Rollouts plugin is installed"), - ), handleVerifyKubectlPluginInstall) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_kubectl_plugin_install", handleVerifyKubectlPluginInstall))) s.AddTool(mcp.NewTool("argo_promote_rollout", mcp.WithDescription("Promote a paused rollout to the next step"), mcp.WithString("rollout_name", mcp.Description("The name of the rollout to promote"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), mcp.WithString("full", mcp.Description("Promote the rollout to the final step")), - ), handlePromoteRollout) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_promote_rollout", handlePromoteRollout))) 
s.AddTool(mcp.NewTool("argo_pause_rollout", mcp.WithDescription("Pause a rollout"), mcp.WithString("rollout_name", mcp.Description("The name of the rollout to pause"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - ), handlePauseRollout) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_pause_rollout", handlePauseRollout))) s.AddTool(mcp.NewTool("argo_set_rollout_image", mcp.WithDescription("Set the image of a rollout"), mcp.WithString("rollout_name", mcp.Description("The name of the rollout to set the image for"), mcp.Required()), mcp.WithString("container_image", mcp.Description("The container image to set for the rollout"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - ), handleSetRolloutImage) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_set_rollout_image", handleSetRolloutImage))) s.AddTool(mcp.NewTool("argo_verify_gateway_plugin", mcp.WithDescription("Verify the installation status of the Argo Rollouts Gateway API plugin"), mcp.WithString("version", mcp.Description("The version of the plugin to check")), mcp.WithString("namespace", mcp.Description("The namespace for the plugin resources")), mcp.WithString("should_install", mcp.Description("Whether to install the plugin if not found")), - ), handleVerifyGatewayPlugin) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_gateway_plugin", handleVerifyGatewayPlugin))) s.AddTool(mcp.NewTool("argo_check_plugin_logs", mcp.WithDescription("Check the logs of the Argo Rollouts Gateway API plugin"), mcp.WithString("namespace", mcp.Description("The namespace of the plugin resources")), mcp.WithString("timeout", mcp.Description("Timeout for log collection in seconds")), - ), handleCheckPluginLogs) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_check_plugin_logs", handleCheckPluginLogs))) } diff --git a/pkg/argo/argo_test.go b/pkg/argo/argo_test.go index 4a80823..0f90c39 
100644 --- a/pkg/argo/argo_test.go +++ b/pkg/argo/argo_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/kagent-dev/tools/pkg/utils" + "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -27,11 +27,11 @@ func getResultText(result *mcp.CallToolResult) string { // Test Argo Rollouts Promote func TestHandlePromoteRollout(t *testing.T) { t.Run("promote rollout basic", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" promoted` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -52,15 +52,15 @@ func TestHandlePromoteRollout(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, callLog[0].Args) }) t.Run("promote rollout with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" promoted` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "-n", "production", "myapp"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "-n", "production", "myapp", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), 
mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -77,15 +77,15 @@ func TestHandlePromoteRollout(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "promote", "-n", "production", "myapp"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "promote", "-n", "production", "myapp", "--timeout", "30s"}, callLog[0].Args) }) t.Run("promote rollout with full flag", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" fully promoted` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--full"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--full", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -102,12 +102,12 @@ func TestHandlePromoteRollout(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--full"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--full", "--timeout", "30s"}, callLog[0].Args) }) t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -125,9 +125,9 @@ func TestHandlePromoteRollout(t *testing.T) { }) t.Run("kubectl command failure", func(t *testing.T) { - mock := 
utils.NewMockShellExecutor() - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -145,11 +145,11 @@ func TestHandlePromoteRollout(t *testing.T) { // Test Argo Rollouts Pause func TestHandlePauseRollout(t *testing.T) { t.Run("pause rollout basic", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" paused` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "myapp"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "myapp", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -170,15 +170,15 @@ func TestHandlePauseRollout(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "pause", "myapp"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "pause", "myapp", "--timeout", "30s"}, callLog[0].Args) }) t.Run("pause rollout with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" paused` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "-n", "production", "myapp"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", 
[]string{"argo", "rollouts", "pause", "-n", "production", "myapp", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -195,12 +195,12 @@ func TestHandlePauseRollout(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "pause", "-n", "production", "myapp"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "pause", "-n", "production", "myapp", "--timeout", "30s"}, callLog[0].Args) }) t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -221,11 +221,11 @@ func TestHandlePauseRollout(t *testing.T) { // Test Argo Rollouts Set Image func TestHandleSetRolloutImage(t *testing.T) { t.Run("set rollout image basic", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" image updated` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -247,15 +247,15 @@ func TestHandleSetRolloutImage(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", 
"rollouts", "set", "image", "myapp", "nginx:latest"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest", "--timeout", "30s"}, callLog[0].Args) }) t.Run("set rollout image with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `rollout "myapp" image updated` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -273,12 +273,12 @@ func TestHandleSetRolloutImage(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production", "--timeout", "30s"}, callLog[0].Args) }) t.Run("missing rollout_name parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -297,8 +297,8 @@ func TestHandleSetRolloutImage(t *testing.T) { }) t.Run("missing container_image parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := 
cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -334,7 +334,7 @@ func TestGetSystemArchitecture(t *testing.T) { } func TestGetLatestVersion(t *testing.T) { - version := getLatestVersion() + version := getLatestVersion(context.Background()) if version == "" { t.Error("Expected non-empty version") } @@ -367,11 +367,11 @@ func TestGatewayPluginStatus(t *testing.T) { // Test Verify Gateway Plugin func TestHandleVerifyGatewayPlugin(t *testing.T) { t.Run("verify gateway plugin without install", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `gateway-api-plugin not found` - mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -394,11 +394,11 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) { }) t.Run("verify gateway plugin with custom namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `gateway-api-plugin-abc123` - mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "custom-namespace", "-o", "yaml"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "custom-namespace", "-o", "yaml", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} 
request.Params.Arguments = map[string]interface{}{ @@ -423,11 +423,11 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) { // Test Verify Argo Rollouts Controller Install func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { t.Run("verify controller install", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `argo-rollouts-controller-manager-abc123` - mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleVerifyArgoRolloutsControllerInstall(ctx, request) @@ -444,11 +444,11 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { }) t.Run("verify controller install with custom namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `argo-rollouts-controller-manager-abc123` - mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "custom-argo", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "custom-argo", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -469,11 +469,11 @@ 
func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { }) t.Run("verify controller install with custom label", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `argo-rollouts-controller-manager-abc123` - mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app=custom-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app=custom-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -497,11 +497,11 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { // Test Verify Kubectl Plugin Install func TestHandleVerifyKubectlPluginInstall(t *testing.T) { t.Run("verify kubectl plugin install", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `kubectl-argo-rollouts` - mock.AddCommandString("kubectl", []string{"argo", "rollouts", "version"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "version", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleVerifyKubectlPluginInstall(ctx, request) @@ -513,13 +513,13 @@ func TestHandleVerifyKubectlPluginInstall(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"argo", "rollouts", "version"}, callLog[0].Args) + assert.Equal(t, []string{"argo", "rollouts", "version", "--timeout", "30s"}, 
callLog[0].Args) }) t.Run("kubectl plugin command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("kubectl", []string{"plugin", "list"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"plugin", "list", "--timeout", "30s"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleVerifyKubectlPluginInstall(ctx, request) diff --git a/pkg/cilium/cilium.go b/pkg/cilium/cilium.go index a84cae3..6ad576c 100644 --- a/pkg/cilium/cilium.go +++ b/pkg/cilium/cilium.go @@ -3,21 +3,21 @@ package cilium import ( "context" "fmt" - "strings" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) -var kubeConfig = "" - func runCiliumCliWithContext(ctx context.Context, args ...string) (string, error) { - if kubeConfig != "" { - args = append([]string{"--kubeconfig", kubeConfig}, args...) - } - return utils.RunCommandWithContext(ctx, "cilium", args) + kubeconfigPath := utils.GetKubeconfig() + return commands.NewCommandBuilder("cilium"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). 
+ Execute(ctx) } func handleCiliumStatusAndVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -200,70 +200,66 @@ func handleToggleClusterMesh(ctx context.Context, request mcp.CallToolRequest) ( return mcp.NewToolResultText(output), nil } -func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) { - kubeConfig = kubeconfig +func RegisterTools(s *server.MCPServer) { - // Register debug tools - RegisterCiliumDbgTools(s) - - // Register main Cilium tools + // Register all Cilium tools (main and debug) s.AddTool(mcp.NewTool("cilium_status_and_version", mcp.WithDescription("Get the status and version of Cilium installation"), - ), handleCiliumStatusAndVersion) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_status_and_version", handleCiliumStatusAndVersion))) s.AddTool(mcp.NewTool("cilium_upgrade_cilium", mcp.WithDescription("Upgrade Cilium on the cluster"), mcp.WithString("cluster_name", mcp.Description("The name of the cluster to upgrade Cilium on")), mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")), - ), handleUpgradeCilium) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_upgrade_cilium", handleUpgradeCilium))) s.AddTool(mcp.NewTool("cilium_install_cilium", mcp.WithDescription("Install Cilium on the cluster"), mcp.WithString("cluster_name", mcp.Description("The name of the cluster to install Cilium on")), mcp.WithString("cluster_id", mcp.Description("The ID of the cluster to install Cilium on")), mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")), - ), handleInstallCilium) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_install_cilium", handleInstallCilium))) s.AddTool(mcp.NewTool("cilium_uninstall_cilium", mcp.WithDescription("Uninstall Cilium from the cluster"), - ), handleUninstallCilium) + ), 
telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_uninstall_cilium", handleUninstallCilium))) s.AddTool(mcp.NewTool("cilium_connect_to_remote_cluster", mcp.WithDescription("Connect to a remote cluster for cluster mesh"), mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()), mcp.WithString("context", mcp.Description("The kubectl context for the destination cluster")), - ), handleConnectToRemoteCluster) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_connect_to_remote_cluster", handleConnectToRemoteCluster))) s.AddTool(mcp.NewTool("cilium_disconnect_remote_cluster", mcp.WithDescription("Disconnect from a remote cluster"), mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()), - ), handleDisconnectRemoteCluster) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_remote_cluster", handleDisconnectRemoteCluster))) s.AddTool(mcp.NewTool("cilium_list_bgp_peers", mcp.WithDescription("List BGP peers"), - ), handleListBGPPeers) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_peers", handleListBGPPeers))) s.AddTool(mcp.NewTool("cilium_list_bgp_routes", mcp.WithDescription("List BGP routes"), - ), handleListBGPRoutes) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_routes", handleListBGPRoutes))) s.AddTool(mcp.NewTool("cilium_show_cluster_mesh_status", mcp.WithDescription("Show cluster mesh status"), - ), handleShowClusterMeshStatus) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_cluster_mesh_status", handleShowClusterMeshStatus))) s.AddTool(mcp.NewTool("cilium_show_features_status", mcp.WithDescription("Show Cilium features status"), - ), handleShowFeaturesStatus) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_features_status", handleShowFeaturesStatus))) s.AddTool(mcp.NewTool("cilium_toggle_hubble", mcp.WithDescription("Enable or disable Hubble"), 
mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")), - ), handleToggleHubble) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_hubble", handleToggleHubble))) s.AddTool(mcp.NewTool("cilium_toggle_cluster_mesh", mcp.WithDescription("Enable or disable cluster mesh"), mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")), - ), handleToggleClusterMesh) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_cluster_mesh", handleToggleClusterMesh))) // Add tools that are also needed by cilium-manager agent s.AddTool(mcp.NewTool("cilium_get_daemon_status", @@ -276,12 +272,12 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("show_all_redirects", mcp.Description("Whether to show all redirects")), mcp.WithString("brief", mcp.Description("Whether to show a brief status")), mcp.WithString("node_name", mcp.Description("The name of the node to get the daemon status for")), - ), handleGetDaemonStatus) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_daemon_status", handleGetDaemonStatus))) s.AddTool(mcp.NewTool("cilium_get_endpoints_list", mcp.WithDescription("Get the list of all endpoints in the cluster"), mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoints list for")), - ), handleGetEndpointsList) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoints_list", handleGetEndpointsList))) s.AddTool(mcp.NewTool("cilium_get_endpoint_details", mcp.WithDescription("List the details of an endpoint in the cluster"), @@ -289,7 +285,7 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")), mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")), mcp.WithString("node_name", mcp.Description("The name of the node to get 
the endpoint details for")), - ), handleGetEndpointDetails) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails))) s.AddTool(mcp.NewTool("cilium_show_configuration_options", mcp.WithDescription("Show Cilium configuration options"), @@ -297,26 +293,26 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("list_read_only", mcp.Description("Whether to list read-only configuration options")), mcp.WithString("list_options", mcp.Description("Whether to list options")), mcp.WithString("node_name", mcp.Description("The name of the node to show the configuration options for")), - ), handleShowConfigurationOptions) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_configuration_options", handleShowConfigurationOptions))) s.AddTool(mcp.NewTool("cilium_toggle_configuration_option", mcp.WithDescription("Toggle a Cilium configuration option"), mcp.WithString("option", mcp.Description("The option to toggle"), mcp.Required()), mcp.WithString("value", mcp.Description("The value to set the option to (true/false)"), mcp.Required()), mcp.WithString("node_name", mcp.Description("The name of the node to toggle the configuration option for")), - ), handleToggleConfigurationOption) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_configuration_option", handleToggleConfigurationOption))) s.AddTool(mcp.NewTool("cilium_list_services", mcp.WithDescription("List services for the cluster"), mcp.WithString("show_cluster_mesh_affinity", mcp.Description("Whether to show cluster mesh affinity")), mcp.WithString("node_name", mcp.Description("The name of the node to get the services for")), - ), handleListServices) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_services", handleListServices))) s.AddTool(mcp.NewTool("cilium_get_service_information", mcp.WithDescription("Get information about a service in the cluster"), mcp.WithString("service_id", 
mcp.Description("The ID of the service to get information about"), mcp.Required()), mcp.WithString("node_name", mcp.Description("The name of the node to get the service information for")), - ), handleGetServiceInformation) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_service_information", handleGetServiceInformation))) s.AddTool(mcp.NewTool("cilium_update_service", mcp.WithDescription("Update a service in the cluster"), @@ -335,35 +331,258 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("protocol", mcp.Description("The protocol to update the service with")), mcp.WithString("states", mcp.Description("The states to update the service with")), mcp.WithString("node_name", mcp.Description("The name of the node to update the service on")), - ), handleUpdateService) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_service", handleUpdateService))) s.AddTool(mcp.NewTool("cilium_delete_service", mcp.WithDescription("Delete a service from the cluster"), mcp.WithString("service_id", mcp.Description("The ID of the service to delete")), mcp.WithString("all", mcp.Description("Whether to delete all services (true/false)")), mcp.WithString("node_name", mcp.Description("The name of the node to delete the service from")), - ), handleDeleteService) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_service", handleDeleteService))) + + // Debug tools (previously in RegisterCiliumDbgTools) + s.AddTool(mcp.NewTool("cilium_get_endpoint_details", + mcp.WithDescription("List the details of an endpoint in the cluster"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")), + mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")), + mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")), + mcp.WithString("node_name", mcp.Description("The name of the node to get 
the endpoint details for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails))) + + s.AddTool(mcp.NewTool("cilium_get_endpoint_logs", + mcp.WithDescription("Get the logs of an endpoint in the cluster"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get logs for"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint logs for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_logs", handleGetEndpointLogs))) + + s.AddTool(mcp.NewTool("cilium_get_endpoint_health", + mcp.WithDescription("Get the health of an endpoint in the cluster"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get health for"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint health for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_health", handleGetEndpointHealth))) + + s.AddTool(mcp.NewTool("cilium_manage_endpoint_labels", + mcp.WithDescription("Manage the labels (add or delete) of an endpoint in the cluster"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage labels for"), mcp.Required()), + mcp.WithString("labels", mcp.Description("Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"), mcp.Required()), + mcp.WithString("action", mcp.Description("The action to perform on the labels (add or delete)"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint labels on")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_labels", handleManageEndpointLabels))) + + s.AddTool(mcp.NewTool("cilium_manage_endpoint_config", + mcp.WithDescription("Manage the configuration of an endpoint in the cluster"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage configuration 
for"), mcp.Required()), + mcp.WithString("config", mcp.Description("The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint configuration on")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_config", handleManageEndpointConfiguration))) + + s.AddTool(mcp.NewTool("cilium_disconnect_endpoint", + mcp.WithDescription("Disconnect an endpoint from the network"), + mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to disconnect"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to disconnect the endpoint from")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_endpoint", handleDisconnectEndpoint))) + + s.AddTool(mcp.NewTool("cilium_list_identities", + mcp.WithDescription("List all identities in the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to list the identities for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_identities", handleListIdentities))) + + s.AddTool(mcp.NewTool("cilium_get_identity_details", + mcp.WithDescription("Get the details of an identity in the cluster"), + mcp.WithString("identity_id", mcp.Description("The ID of the identity to get details for"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the identity details for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_identity_details", handleGetIdentityDetails))) + + s.AddTool(mcp.NewTool("cilium_request_debugging_information", + mcp.WithDescription("Request debugging information for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the debugging information for")), + ), 
telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_request_debugging_information", handleRequestDebuggingInformation))) + + s.AddTool(mcp.NewTool("cilium_display_encryption_state", + mcp.WithDescription("Display the encryption state for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the encryption state for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_encryption_state", handleDisplayEncryptionState))) + + s.AddTool(mcp.NewTool("cilium_flush_ipsec_state", + mcp.WithDescription("Flush the IPsec state for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to flush the IPsec state for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_flush_ipsec_state", handleFlushIPsecState))) + + s.AddTool(mcp.NewTool("cilium_list_envoy_config", + mcp.WithDescription("List the Envoy configuration for a resource in the cluster"), + mcp.WithString("resource_name", mcp.Description("The name of the resource to get the Envoy configuration for"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the Envoy configuration for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_envoy_config", handleListEnvoyConfig))) + + s.AddTool(mcp.NewTool("cilium_fqdn_cache", + mcp.WithDescription("Manage the FQDN cache for the cluster"), + mcp.WithString("command", mcp.Description("The command to perform on the FQDN cache (list, clean, or a specific command)"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to manage the FQDN cache for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_fqdn_cache", handleFQDNCache))) + + s.AddTool(mcp.NewTool("cilium_show_dns_names", + mcp.WithDescription("Show the DNS names for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the DNS names for")), + ), 
telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_dns_names", handleShowDNSNames))) + + s.AddTool(mcp.NewTool("cilium_list_ip_addresses", + mcp.WithDescription("List the IP addresses for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the IP addresses for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_ip_addresses", handleListIPAddresses))) + + s.AddTool(mcp.NewTool("cilium_show_ip_cache_information", + mcp.WithDescription("Show the IP cache information for the cluster"), + mcp.WithString("cidr", mcp.Description("The CIDR of the IP to get cache information for")), + mcp.WithString("labels", mcp.Description("The labels of the IP to get cache information for")), + mcp.WithString("node_name", mcp.Description("The name of the node to get the IP cache information for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_ip_cache_information", handleShowIPCacheInformation))) + + s.AddTool(mcp.NewTool("cilium_delete_key_from_kv_store", + mcp.WithDescription("Delete a key from the kvstore for the cluster"), + mcp.WithString("key", mcp.Description("The key to delete from the kvstore"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to delete the key from")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_key_from_kv_store", handleDeleteKeyFromKVStore))) + + s.AddTool(mcp.NewTool("cilium_get_kv_store_key", + mcp.WithDescription("Get a key from the kvstore for the cluster"), + mcp.WithString("key", mcp.Description("The key to get from the kvstore"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the key from")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_kv_store_key", handleGetKVStoreKey))) + + s.AddTool(mcp.NewTool("cilium_set_kv_store_key", + mcp.WithDescription("Set a key in the kvstore for the cluster"), + mcp.WithString("key", mcp.Description("The key 
to set in the kvstore"), mcp.Required()), + mcp.WithString("value", mcp.Description("The value to set in the kvstore"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to set the key in")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_set_kv_store_key", handleSetKVStoreKey))) + + s.AddTool(mcp.NewTool("cilium_show_load_information", + mcp.WithDescription("Show load information for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the load information for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_load_information", handleShowLoadInformation))) + + s.AddTool(mcp.NewTool("cilium_list_local_redirect_policies", + mcp.WithDescription("List local redirect policies for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the local redirect policies for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_local_redirect_policies", handleListLocalRedirectPolicies))) + + s.AddTool(mcp.NewTool("cilium_list_bpf_map_events", + mcp.WithDescription("List BPF map events for the cluster"), + mcp.WithString("map_name", mcp.Description("The name of the BPF map to get events for"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map events for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_map_events", handleListBPFMapEvents))) + + s.AddTool(mcp.NewTool("cilium_get_bpf_map", + mcp.WithDescription("Get BPF map for the cluster"), + mcp.WithString("map_name", mcp.Description("The name of the BPF map to get"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_bpf_map", handleGetBPFMap))) + + s.AddTool(mcp.NewTool("cilium_list_bpf_maps", + mcp.WithDescription("List BPF maps for the cluster"), + 
mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF maps for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_maps", handleListBPFMaps))) + + s.AddTool(mcp.NewTool("cilium_list_metrics", + mcp.WithDescription("List metrics for the cluster"), + mcp.WithString("match_pattern", mcp.Description("The match pattern to filter metrics by")), + mcp.WithString("node_name", mcp.Description("The name of the node to get the metrics for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_metrics", handleListMetrics))) + + s.AddTool(mcp.NewTool("cilium_list_cluster_nodes", + mcp.WithDescription("List cluster nodes for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the cluster nodes for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_cluster_nodes", handleListClusterNodes))) + + s.AddTool(mcp.NewTool("cilium_list_node_ids", + mcp.WithDescription("List node IDs for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the node IDs for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_node_ids", handleListNodeIds))) + + s.AddTool(mcp.NewTool("cilium_display_policy_node_information", + mcp.WithDescription("Display policy node information for the cluster"), + mcp.WithString("labels", mcp.Description("The labels to get policy node information for")), + mcp.WithString("node_name", mcp.Description("The name of the node to get policy node information for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_policy_node_information", handleDisplayPolicyNodeInformation))) + + s.AddTool(mcp.NewTool("cilium_delete_policy_rules", + mcp.WithDescription("Delete policy rules for the cluster"), + mcp.WithString("labels", mcp.Description("The labels to delete policy rules for")), + mcp.WithString("all", mcp.Description("Whether to delete all policy rules")), + 
mcp.WithString("node_name", mcp.Description("The name of the node to delete policy rules for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_policy_rules", handleDeletePolicyRules))) + + s.AddTool(mcp.NewTool("cilium_display_selectors", + mcp.WithDescription("Display selectors for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get selectors for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_selectors", handleDisplaySelectors))) + + s.AddTool(mcp.NewTool("cilium_list_xdp_cidr_filters", + mcp.WithDescription("List XDP CIDR filters for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the XDP CIDR filters for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_xdp_cidr_filters", handleListXDPCIDRFilters))) + + s.AddTool(mcp.NewTool("cilium_update_xdp_cidr_filters", + mcp.WithDescription("Update XDP CIDR filters for the cluster"), + mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to update the XDP filters for"), mcp.Required()), + mcp.WithString("revision", mcp.Description("The revision of the XDP filters to update")), + mcp.WithString("node_name", mcp.Description("The name of the node to update the XDP filters for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_xdp_cidr_filters", handleUpdateXDPCIDRFilters))) + + s.AddTool(mcp.NewTool("cilium_delete_xdp_cidr_filters", + mcp.WithDescription("Delete XDP CIDR filters for the cluster"), + mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to delete the XDP filters for"), mcp.Required()), + mcp.WithString("revision", mcp.Description("The revision of the XDP filters to delete")), + mcp.WithString("node_name", mcp.Description("The name of the node to delete the XDP filters for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_xdp_cidr_filters", handleDeleteXDPCIDRFilters))) + + 
s.AddTool(mcp.NewTool("cilium_validate_cilium_network_policies", + mcp.WithDescription("Validate Cilium network policies for the cluster"), + mcp.WithString("enable_k8s", mcp.Description("Whether to enable k8s API discovery")), + mcp.WithString("enable_k8s_api_discovery", mcp.Description("Whether to enable k8s API discovery")), + mcp.WithString("node_name", mcp.Description("The name of the node to validate the Cilium network policies for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_validate_cilium_network_policies", handleValidateCiliumNetworkPolicies))) + + s.AddTool(mcp.NewTool("cilium_list_pcap_recorders", + mcp.WithDescription("List PCAP recorders for the cluster"), + mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorders for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_pcap_recorders", handleListPCAPRecorders))) + + s.AddTool(mcp.NewTool("cilium_get_pcap_recorder", + mcp.WithDescription("Get a PCAP recorder for the cluster"), + mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to get"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorder for")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_pcap_recorder", handleGetPCAPRecorder))) + + s.AddTool(mcp.NewTool("cilium_delete_pcap_recorder", + mcp.WithDescription("Delete a PCAP recorder for the cluster"), + mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to delete"), mcp.Required()), + mcp.WithString("node_name", mcp.Description("The name of the node to delete the PCAP recorder from")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_pcap_recorder", handleDeletePCAPRecorder))) + + s.AddTool(mcp.NewTool("cilium_update_pcap_recorder", + mcp.WithDescription("Update a PCAP recorder for the cluster"), + mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to update"), 
mcp.Required()), + mcp.WithString("filters", mcp.Description("The filters to update the PCAP recorder with"), mcp.Required()), + mcp.WithString("caplen", mcp.Description("The caplen to update the PCAP recorder with")), + mcp.WithString("id", mcp.Description("The id to update the PCAP recorder with")), + mcp.WithString("node_name", mcp.Description("The name of the node to update the PCAP recorder on")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_pcap_recorder", handleUpdatePCAPRecorder))) } // -- Debug Tools -- func getCiliumPodNameWithContext(ctx context.Context, nodeName string) (string, error) { - args := []string{"get", "pod", "-l", "k8s-app=cilium", "-o", "name", "-n", "kube-system"} - if nodeName != "" { - args = append(args, "--field-selector", "spec.nodeName="+nodeName) - } - podName, err := utils.RunCommandWithContext(ctx, "kubectl", args) - if err != nil { - return "", fmt.Errorf("failed to get cilium pod name: %v", err) - } - if podName == "" { - return "", fmt.Errorf("no cilium pod found") - } - return strings.TrimSpace(podName), nil + args := []string{"get", "pods", "-n", "kube-system", "--selector=k8s-app=cilium", fmt.Sprintf("--field-selector=spec.nodeName=%s", nodeName), "-o", "jsonpath={.items[0].metadata.name}"} + kubeconfigPath := utils.GetKubeconfig() + return commands.NewCommandBuilder("kubectl"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). 
+ Execute(ctx) } -func runCiliumDbgCommand(command, nodeName string) (string, error) { - return runCiliumDbgCommandWithContext(context.Background(), command, nodeName) +func runCiliumDbgCommand(ctx context.Context, command, nodeName string) (string, error) { + return runCiliumDbgCommandWithContext(ctx, command, nodeName) } func runCiliumDbgCommandWithContext(ctx context.Context, command, nodeName string) (string, error) { @@ -371,10 +590,12 @@ func runCiliumDbgCommandWithContext(ctx context.Context, command, nodeName strin if err != nil { return "", err } - cmdParts := strings.Fields(command) - args := []string{"exec", "-it", podName, "-n", "kube-system", "--", "cilium-dbg"} - args = append(args, cmdParts...) - return utils.RunCommandWithContext(ctx, "kubectl", args) + args := []string{"exec", "-it", podName, "--", "cilium-dbg", command} + kubeconfigPath := utils.GetKubeconfig() + return commands.NewCommandBuilder("kubectl"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). + Execute(ctx) } func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -392,7 +613,7 @@ func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError("either endpoint_id or labels must be provided"), nil } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint details: %v", err)), nil } @@ -408,7 +629,7 @@ func handleGetEndpointLogs(ctx context.Context, request mcp.CallToolRequest) (*m } cmd := fmt.Sprintf("endpoint logs %s", endpointID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint logs: %v", err)), nil } @@ -424,7 +645,7 @@ func handleGetEndpointHealth(ctx context.Context, request mcp.CallToolRequest) ( } 
cmd := fmt.Sprintf("endpoint health %s", endpointID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint health: %v", err)), nil } @@ -442,7 +663,7 @@ func handleManageEndpointLabels(ctx context.Context, request mcp.CallToolRequest } cmd := fmt.Sprintf("endpoint labels %s --%s %s", endpointID, action, labels) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to manage endpoint labels: %v", err)), nil } @@ -462,7 +683,7 @@ func handleManageEndpointConfiguration(ctx context.Context, request mcp.CallTool } command := fmt.Sprintf("endpoint config %s %s", endpointID, config) - output, err := runCiliumDbgCommand(command, nodeName) + output, err := runCiliumDbgCommand(ctx, command, nodeName) if err != nil { return mcp.NewToolResultError("Error managing endpoint configuration: " + err.Error()), nil } @@ -479,7 +700,7 @@ func handleDisconnectEndpoint(ctx context.Context, request mcp.CallToolRequest) } cmd := fmt.Sprintf("endpoint disconnect %s", endpointID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to disconnect endpoint: %v", err)), nil } @@ -489,7 +710,7 @@ func handleDisconnectEndpoint(ctx context.Context, request mcp.CallToolRequest) func handleGetEndpointsList(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("endpoint list", nodeName) + output, err := runCiliumDbgCommand(ctx, "endpoint list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoints list: %v", err)), nil } @@ -499,7 +720,7 @@ func handleGetEndpointsList(ctx 
context.Context, request mcp.CallToolRequest) (* func handleListIdentities(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("identity list", nodeName) + output, err := runCiliumDbgCommand(ctx, "identity list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list identities: %v", err)), nil } @@ -515,7 +736,7 @@ func handleGetIdentityDetails(ctx context.Context, request mcp.CallToolRequest) } cmd := fmt.Sprintf("identity get %s", identityID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get identity details: %v", err)), nil } @@ -539,7 +760,7 @@ func handleShowConfigurationOptions(ctx context.Context, request mcp.CallToolReq cmd = "endpoint config" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to show configuration options: %v", err)), nil } @@ -561,7 +782,7 @@ func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRe } cmd := fmt.Sprintf("endpoint config %s=%s", option, valueStr) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to toggle configuration option: %v", err)), nil } @@ -571,7 +792,7 @@ func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRe func handleRequestDebuggingInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("debuginfo", nodeName) + output, err := runCiliumDbgCommand(ctx, "debuginfo", nodeName) if err != nil { return 
mcp.NewToolResultError(fmt.Sprintf("Failed to request debugging information: %v", err)), nil } @@ -581,7 +802,7 @@ func handleRequestDebuggingInformation(ctx context.Context, request mcp.CallTool func handleDisplayEncryptionState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("encrypt status", nodeName) + output, err := runCiliumDbgCommand(ctx, "encrypt status", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to display encryption state: %v", err)), nil } @@ -591,7 +812,7 @@ func handleDisplayEncryptionState(ctx context.Context, request mcp.CallToolReque func handleFlushIPsecState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("encrypt flush -f", nodeName) + output, err := runCiliumDbgCommand(ctx, "encrypt flush -f", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to flush IPsec state: %v", err)), nil } @@ -607,7 +828,7 @@ func handleListEnvoyConfig(ctx context.Context, request mcp.CallToolRequest) (*m } cmd := fmt.Sprintf("envoy admin %s", resourceName) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list Envoy config: %v", err)), nil } @@ -620,12 +841,12 @@ func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal var cmd string if command == "clean" { - cmd = "fqdn cache clean -f" + cmd = "fqdn cache clean" } else { - cmd = fmt.Sprintf("fqdn cache %s", command) + cmd = "fqdn cache list" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to manage FQDN cache: %v", err)), nil } @@ -635,7 
+856,7 @@ func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal func handleShowDNSNames(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("dns names", nodeName) + output, err := runCiliumDbgCommand(ctx, "dns names", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to show DNS names: %v", err)), nil } @@ -645,7 +866,7 @@ func handleShowDNSNames(ctx context.Context, request mcp.CallToolRequest) (*mcp. func handleListIPAddresses(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("ip list", nodeName) + output, err := runCiliumDbgCommand(ctx, "ip list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list IP addresses: %v", err)), nil } @@ -666,7 +887,7 @@ func handleShowIPCacheInformation(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultError("either cidr or labels must be provided"), nil } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to show IP cache information: %v", err)), nil } @@ -682,7 +903,7 @@ func handleDeleteKeyFromKVStore(ctx context.Context, request mcp.CallToolRequest } cmd := fmt.Sprintf("kvstore delete %s", key) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to delete key from kvstore: %v", err)), nil } @@ -698,7 +919,7 @@ func handleGetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp } cmd := fmt.Sprintf("kvstore get %s", key) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if 
err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get key from kvstore: %v", err)), nil } @@ -715,7 +936,7 @@ func handleSetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp } cmd := fmt.Sprintf("kvstore set %s=%s", key, value) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to set key in kvstore: %v", err)), nil } @@ -725,7 +946,7 @@ func handleSetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp func handleShowLoadInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("loadinfo", nodeName) + output, err := runCiliumDbgCommand(ctx, "loadinfo", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to show load information: %v", err)), nil } @@ -735,7 +956,7 @@ func handleShowLoadInformation(ctx context.Context, request mcp.CallToolRequest) func handleListLocalRedirectPolicies(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("lrp list", nodeName) + output, err := runCiliumDbgCommand(ctx, "lrp list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list local redirect policies: %v", err)), nil } @@ -751,7 +972,7 @@ func handleListBPFMapEvents(ctx context.Context, request mcp.CallToolRequest) (* } cmd := fmt.Sprintf("bpf map events %s", mapName) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF map events: %v", err)), nil } @@ -767,7 +988,7 @@ func handleGetBPFMap(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal } cmd := fmt.Sprintf("bpf map get %s", 
mapName) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get BPF map: %v", err)), nil } @@ -777,7 +998,7 @@ func handleGetBPFMap(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal func handleListBPFMaps(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("bpf map list", nodeName) + output, err := runCiliumDbgCommand(ctx, "bpf map list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF maps: %v", err)), nil } @@ -795,7 +1016,7 @@ func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.C cmd = "metrics list" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list metrics: %v", err)), nil } @@ -805,7 +1026,7 @@ func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.C func handleListClusterNodes(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("nodes list", nodeName) + output, err := runCiliumDbgCommand(ctx, "nodes list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list cluster nodes: %v", err)), nil } @@ -815,7 +1036,7 @@ func handleListClusterNodes(ctx context.Context, request mcp.CallToolRequest) (* func handleListNodeIds(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("nodeid list", nodeName) + output, err := runCiliumDbgCommand(ctx, "nodeid list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed 
to list node IDs: %v", err)), nil } @@ -833,7 +1054,7 @@ func handleDisplayPolicyNodeInformation(ctx context.Context, request mcp.CallToo cmd = "policy get" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to display policy node information: %v", err)), nil } @@ -854,7 +1075,7 @@ func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) ( return mcp.NewToolResultError("either labels or all=true must be provided"), nil } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to delete policy rules: %v", err)), nil } @@ -864,7 +1085,7 @@ func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) ( func handleDisplaySelectors(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("policy selectors", nodeName) + output, err := runCiliumDbgCommand(ctx, "policy selectors", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to display selectors: %v", err)), nil } @@ -874,7 +1095,7 @@ func handleDisplaySelectors(ctx context.Context, request mcp.CallToolRequest) (* func handleListXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("prefilter list", nodeName) + output, err := runCiliumDbgCommand(ctx, "prefilter list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list XDP CIDR filters: %v", err)), nil } @@ -897,7 +1118,7 @@ func handleUpdateXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest cmd = fmt.Sprintf("prefilter update --cidr %s", cidrPrefixes) } - output, err := 
runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to update XDP CIDR filters: %v", err)), nil } @@ -920,7 +1141,7 @@ func handleDeleteXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest cmd = fmt.Sprintf("prefilter delete --cidr %s", cidrPrefixes) } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to delete XDP CIDR filters: %v", err)), nil } @@ -940,7 +1161,7 @@ func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallTo cmd += " --enable-k8s-api-discovery" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to validate Cilium network policies: %v", err)), nil } @@ -950,7 +1171,7 @@ func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallTo func handleListPCAPRecorders(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { nodeName := mcp.ParseString(request, "node_name", "") - output, err := runCiliumDbgCommand("recorder list", nodeName) + output, err := runCiliumDbgCommand(ctx, "recorder list", nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list PCAP recorders: %v", err)), nil } @@ -966,7 +1187,7 @@ func handleGetPCAPRecorder(ctx context.Context, request mcp.CallToolRequest) (*m } cmd := fmt.Sprintf("recorder get %s", recorderID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get PCAP recorder: %v", err)), nil } @@ -982,7 +1203,7 @@ func handleDeletePCAPRecorder(ctx context.Context, request mcp.CallToolRequest) } cmd := fmt.Sprintf("recorder delete %s", recorderID) - 
output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to delete PCAP recorder: %v", err)), nil } @@ -1001,7 +1222,7 @@ func handleUpdatePCAPRecorder(ctx context.Context, request mcp.CallToolRequest) } cmd := fmt.Sprintf("recorder update %s --filters %s --caplen %s --id %s", recorderID, filters, caplen, id) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to update PCAP recorder: %v", err)), nil } @@ -1019,7 +1240,7 @@ func handleListServices(ctx context.Context, request mcp.CallToolRequest) (*mcp. cmd = "service list" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to list services: %v", err)), nil } @@ -1035,7 +1256,7 @@ func handleGetServiceInformation(ctx context.Context, request mcp.CallToolReques } cmd := fmt.Sprintf("service get %s", serviceID) - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get service information: %v", err)), nil } @@ -1056,7 +1277,7 @@ func handleDeleteService(ctx context.Context, request mcp.CallToolRequest) (*mcp return mcp.NewToolResultError("either service_id or all=true must be provided"), nil } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to delete service: %v", err)), nil } @@ -1115,7 +1336,7 @@ func handleUpdateService(ctx context.Context, request mcp.CallToolRequest) (*mcp cmd += " --local-redirect" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) 
if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to update service: %v", err)), nil } @@ -1155,239 +1376,9 @@ func handleGetDaemonStatus(ctx context.Context, request mcp.CallToolRequest) (*m cmd += " --brief" } - output, err := runCiliumDbgCommand(cmd, nodeName) + output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get daemon status: %v", err)), nil } return mcp.NewToolResultText(output), nil } - -func RegisterCiliumDbgTools(s *server.MCPServer) { - s.AddTool(mcp.NewTool("cilium_get_endpoint_details", - mcp.WithDescription("List the details of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")), - mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")), - mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")), - ), handleGetEndpointDetails) - - s.AddTool(mcp.NewTool("cilium_get_endpoint_logs", - mcp.WithDescription("Get the logs of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get logs for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint logs for")), - ), handleGetEndpointLogs) - - s.AddTool(mcp.NewTool("cilium_get_endpoint_health", - mcp.WithDescription("Get the health of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get health for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint health for")), - ), handleGetEndpointHealth) - - s.AddTool(mcp.NewTool("cilium_manage_endpoint_labels", - mcp.WithDescription("Manage the labels (add or delete) of an endpoint in the cluster"), - 
mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage labels for"), mcp.Required()), - mcp.WithString("labels", mcp.Description("Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"), mcp.Required()), - mcp.WithString("action", mcp.Description("The action to perform on the labels (add or delete)"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint labels on")), - ), handleManageEndpointLabels) - - s.AddTool(mcp.NewTool("cilium_manage_endpoint_config", - mcp.WithDescription("Manage the configuration of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage configuration for"), mcp.Required()), - mcp.WithString("config", mcp.Description("The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint configuration on")), - ), handleManageEndpointConfiguration) - - s.AddTool(mcp.NewTool("cilium_disconnect_endpoint", - mcp.WithDescription("Disconnect an endpoint from the network"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to disconnect"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to disconnect the endpoint from")), - ), handleDisconnectEndpoint) - - s.AddTool(mcp.NewTool("cilium_list_identities", - mcp.WithDescription("List all identities in the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to list the identities for")), - ), handleListIdentities) - - s.AddTool(mcp.NewTool("cilium_get_identity_details", - mcp.WithDescription("Get the details of an identity in the cluster"), - mcp.WithString("identity_id", mcp.Description("The ID of the identity to get details for"), mcp.Required()), - mcp.WithString("node_name", 
mcp.Description("The name of the node to get the identity details for")), - ), handleGetIdentityDetails) - - s.AddTool(mcp.NewTool("cilium_request_debugging_information", - mcp.WithDescription("Request debugging information for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the debugging information for")), - ), handleRequestDebuggingInformation) - - s.AddTool(mcp.NewTool("cilium_display_encryption_state", - mcp.WithDescription("Display the encryption state for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the encryption state for")), - ), handleDisplayEncryptionState) - - s.AddTool(mcp.NewTool("cilium_flush_ipsec_state", - mcp.WithDescription("Flush the IPsec state for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to flush the IPsec state for")), - ), handleFlushIPsecState) - - s.AddTool(mcp.NewTool("cilium_list_envoy_config", - mcp.WithDescription("List the Envoy configuration for a resource in the cluster"), - mcp.WithString("resource_name", mcp.Description("The name of the resource to get the Envoy configuration for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the Envoy configuration for")), - ), handleListEnvoyConfig) - - s.AddTool(mcp.NewTool("cilium_fqdn_cache", - mcp.WithDescription("Manage the FQDN cache for the cluster"), - mcp.WithString("command", mcp.Description("The command to perform on the FQDN cache (list, clean, or a specific command)"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the FQDN cache for")), - ), handleFQDNCache) - - s.AddTool(mcp.NewTool("cilium_show_dns_names", - mcp.WithDescription("Show the DNS names for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the DNS names for")), - ), handleShowDNSNames) - - s.AddTool(mcp.NewTool("cilium_list_ip_addresses", - 
mcp.WithDescription("List the IP addresses for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the IP addresses for")), - ), handleListIPAddresses) - - s.AddTool(mcp.NewTool("cilium_show_ip_cache_information", - mcp.WithDescription("Show the IP cache information for the cluster"), - mcp.WithString("cidr", mcp.Description("The CIDR of the IP to get cache information for")), - mcp.WithString("labels", mcp.Description("The labels of the IP to get cache information for")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the IP cache information for")), - ), handleShowIPCacheInformation) - - s.AddTool(mcp.NewTool("cilium_delete_key_from_kv_store", - mcp.WithDescription("Delete a key from the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to delete from the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the key from")), - ), handleDeleteKeyFromKVStore) - - s.AddTool(mcp.NewTool("cilium_get_kv_store_key", - mcp.WithDescription("Get a key from the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to get from the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the key from")), - ), handleGetKVStoreKey) - - s.AddTool(mcp.NewTool("cilium_set_kv_store_key", - mcp.WithDescription("Set a key in the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to set in the kvstore"), mcp.Required()), - mcp.WithString("value", mcp.Description("The value to set in the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to set the key in")), - ), handleSetKVStoreKey) - - s.AddTool(mcp.NewTool("cilium_show_load_information", - mcp.WithDescription("Show load information for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the load information for")), - 
), handleShowLoadInformation) - - s.AddTool(mcp.NewTool("cilium_list_local_redirect_policies", - mcp.WithDescription("List local redirect policies for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the local redirect policies for")), - ), handleListLocalRedirectPolicies) - - s.AddTool(mcp.NewTool("cilium_list_bpf_map_events", - mcp.WithDescription("List BPF map events for the cluster"), - mcp.WithString("map_name", mcp.Description("The name of the BPF map to get events for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map events for")), - ), handleListBPFMapEvents) - - s.AddTool(mcp.NewTool("cilium_get_bpf_map", - mcp.WithDescription("Get BPF map for the cluster"), - mcp.WithString("map_name", mcp.Description("The name of the BPF map to get"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map for")), - ), handleGetBPFMap) - - s.AddTool(mcp.NewTool("cilium_list_bpf_maps", - mcp.WithDescription("List BPF maps for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF maps for")), - ), handleListBPFMaps) - - s.AddTool(mcp.NewTool("cilium_list_metrics", - mcp.WithDescription("List metrics for the cluster"), - mcp.WithString("match_pattern", mcp.Description("The match pattern to filter metrics by")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the metrics for")), - ), handleListMetrics) - - s.AddTool(mcp.NewTool("cilium_list_cluster_nodes", - mcp.WithDescription("List cluster nodes for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the cluster nodes for")), - ), handleListClusterNodes) - - s.AddTool(mcp.NewTool("cilium_list_node_ids", - mcp.WithDescription("List node IDs for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the node IDs for")), - ), 
handleListNodeIds) - - s.AddTool(mcp.NewTool("cilium_display_policy_node_information", - mcp.WithDescription("Display policy node information for the cluster"), - mcp.WithString("labels", mcp.Description("The labels to get policy node information for")), - mcp.WithString("node_name", mcp.Description("The name of the node to get policy node information for")), - ), handleDisplayPolicyNodeInformation) - - s.AddTool(mcp.NewTool("cilium_delete_policy_rules", - mcp.WithDescription("Delete policy rules for the cluster"), - mcp.WithString("labels", mcp.Description("The labels to delete policy rules for")), - mcp.WithString("all", mcp.Description("Whether to delete all policy rules")), - mcp.WithString("node_name", mcp.Description("The name of the node to delete policy rules for")), - ), handleDeletePolicyRules) - - s.AddTool(mcp.NewTool("cilium_display_selectors", - mcp.WithDescription("Display selectors for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get selectors for")), - ), handleDisplaySelectors) - - s.AddTool(mcp.NewTool("cilium_list_xdp_cidr_filters", - mcp.WithDescription("List XDP CIDR filters for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the XDP CIDR filters for")), - ), handleListXDPCIDRFilters) - - s.AddTool(mcp.NewTool("cilium_update_xdp_cidr_filters", - mcp.WithDescription("Update XDP CIDR filters for the cluster"), - mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to update the XDP filters for"), mcp.Required()), - mcp.WithString("revision", mcp.Description("The revision of the XDP filters to update")), - mcp.WithString("node_name", mcp.Description("The name of the node to update the XDP filters for")), - ), handleUpdateXDPCIDRFilters) - - s.AddTool(mcp.NewTool("cilium_delete_xdp_cidr_filters", - mcp.WithDescription("Delete XDP CIDR filters for the cluster"), - mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to delete the XDP 
filters for"), mcp.Required()), - mcp.WithString("revision", mcp.Description("The revision of the XDP filters to delete")), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the XDP filters for")), - ), handleDeleteXDPCIDRFilters) - - s.AddTool(mcp.NewTool("cilium_validate_cilium_network_policies", - mcp.WithDescription("Validate Cilium network policies for the cluster"), - mcp.WithString("enable_k8s", mcp.Description("Whether to enable k8s API discovery")), - mcp.WithString("enable_k8s_api_discovery", mcp.Description("Whether to enable k8s API discovery")), - mcp.WithString("node_name", mcp.Description("The name of the node to validate the Cilium network policies for")), - ), handleValidateCiliumNetworkPolicies) - - s.AddTool(mcp.NewTool("cilium_list_pcap_recorders", - mcp.WithDescription("List PCAP recorders for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorders for")), - ), handleListPCAPRecorders) - - s.AddTool(mcp.NewTool("cilium_get_pcap_recorder", - mcp.WithDescription("Get a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to get"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorder for")), - ), handleGetPCAPRecorder) - - s.AddTool(mcp.NewTool("cilium_delete_pcap_recorder", - mcp.WithDescription("Delete a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to delete"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the PCAP recorder from")), - ), handleDeletePCAPRecorder) - - s.AddTool(mcp.NewTool("cilium_update_pcap_recorder", - mcp.WithDescription("Update a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to update"), mcp.Required()), - mcp.WithString("filters", mcp.Description("The 
filters to update the PCAP recorder with"), mcp.Required()), - mcp.WithString("caplen", mcp.Description("The caplen to update the PCAP recorder with")), - mcp.WithString("id", mcp.Description("The id to update the PCAP recorder with")), - mcp.WithString("node_name", mcp.Description("The name of the node to update the PCAP recorder on")), - ), handleUpdatePCAPRecorder) -} diff --git a/pkg/cilium/cilium_test.go b/pkg/cilium/cilium_test.go index 5b01846..866de5d 100644 --- a/pkg/cilium/cilium_test.go +++ b/pkg/cilium/cilium_test.go @@ -1,99 +1,269 @@ package cilium import ( + "context" + "errors" + "fmt" + "strings" "testing" + "github.com/kagent-dev/tools/internal/cmd" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -// Basic command construction tests for Cilium CLI commands -// Note: MCP handler tests are in cilium_mcp_test.go +func TestRegisterCiliumTools(t *testing.T) { + s := server.NewMCPServer("test-server", "v0.0.1") + RegisterTools(s) + // We can't directly check the tools, but we can ensure the call doesn't panic +} -func TestCiliumCommandConstruction(t *testing.T) { - t.Run("basic command construction patterns", func(t *testing.T) { - // Test that we can construct basic cilium commands - args := []string{"status"} - assert.Equal(t, "status", args[0]) +func TestHandleCiliumStatusAndVersion(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"status", "--timeout", "30s"}, "Cilium status: OK", nil) + mock.AddCommandString("cilium", []string{"version", "--timeout", "30s"}, "cilium version 1.14.0", nil) - // Test upgrade command with parameters - upgradeArgs := []string{"upgrade"} - if clusterName := "test-cluster"; clusterName != "" { - upgradeArgs = append(upgradeArgs, "--cluster-name", clusterName) - } - if datapathMode := "tunnel"; datapathMode != "" { - upgradeArgs = 
append(upgradeArgs, "--datapath-mode", datapathMode) - } + ctx = cmd.WithShellExecutor(ctx, mock) - expected := []string{"upgrade", "--cluster-name", "test-cluster", "--datapath-mode", "tunnel"} - assert.Equal(t, expected, upgradeArgs) - }) + result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) - t.Run("install command with parameters", func(t *testing.T) { - args := []string{"install"} - if clusterName := "test-cluster"; clusterName != "" { - args = append(args, "--set", "cluster.name="+clusterName) - } - if clusterID := "123"; clusterID != "" { - args = append(args, "--set", "cluster.id="+clusterID) - } - if datapathMode := "tunnel"; datapathMode != "" { - args = append(args, "--datapath-mode", datapathMode) + var textContent mcp.TextContent + var ok bool + for _, content := range result.Content { + if textContent, ok = content.(mcp.TextContent); ok { + break } + } + require.True(t, ok, "no text content in result") - expected := []string{"install", "--set", "cluster.name=test-cluster", "--set", "cluster.id=123", "--datapath-mode", "tunnel"} - assert.Equal(t, expected, args) - }) + assert.Contains(t, textContent.Text, "Cilium status: OK") + assert.Contains(t, textContent.Text, "cilium version 1.14.0") +} + +func TestHandleCiliumStatusAndVersionError(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"status", "--timeout", "30s"}, "", errors.New("command failed")) + mock.AddCommandString("cilium", []string{"version", "--timeout", "30s"}, "cilium version 1.14.0", nil) + + ctx = cmd.WithShellExecutor(ctx, mock) + + result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "Error getting Cilium status") +} + +func TestHandleInstallCilium(t 
*testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"install", "--timeout", "30s"}, "✓ Cilium was successfully installed!", nil) + + ctx = cmd.WithShellExecutor(ctx, mock) + + result, err := handleInstallCilium(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Cilium was successfully installed!") +} + +func TestHandleUninstallCilium(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"uninstall", "--timeout", "30s"}, "✓ Cilium was successfully uninstalled!", nil) + + ctx = cmd.WithShellExecutor(ctx, mock) + + result, err := handleUninstallCilium(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Cilium was successfully uninstalled!") +} + +func TestHandleUpgradeCilium(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"upgrade", "--timeout", "30s"}, "✓ Cilium was successfully upgraded!", nil) - t.Run("clustermesh connect command", func(t *testing.T) { - clusterName := "remote-cluster" - context := "remote-context" + ctx = cmd.WithShellExecutor(ctx, mock) - args := []string{"clustermesh", "connect", "--destination-cluster", clusterName} - if context != "" { - args = append(args, "--destination-context", context) + result, err := handleUpgradeCilium(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Cilium was successfully upgraded!") +} + +func TestHandleConnectToRemoteCluster(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + 
mock.AddCommandString("cilium", []string{"clustermesh", "connect", "--destination-cluster", "my-cluster", "--timeout", "30s"}, "✓ Connected to cluster my-cluster!", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "cluster_name": "my-cluster", + }, + }, } - expected := []string{"clustermesh", "connect", "--destination-cluster", "remote-cluster", "--destination-context", "remote-context"} - assert.Equal(t, expected, args) + result, err := handleConnectToRemoteCluster(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Connected to cluster my-cluster!") }) - t.Run("bgp commands", func(t *testing.T) { - peersArgs := []string{"bgp", "peers"} - routesArgs := []string{"bgp", "routes"} - - assert.Equal(t, []string{"bgp", "peers"}, peersArgs) - assert.Equal(t, []string{"bgp", "routes"}, routesArgs) + t.Run("missing cluster_name", func(t *testing.T) { + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{}, + }, + } + result, err := handleConnectToRemoteCluster(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "cluster_name parameter is required") }) } -func TestCiliumParameterValidation(t *testing.T) { - t.Run("cluster name validation", func(t *testing.T) { - clusterName := "" - if clusterName == "" { - assert.True(t, true, "cluster_name parameter should be required for connect operations") +func TestHandleDisconnectFromRemoteCluster(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"clustermesh", "disconnect", "--destination-cluster", "my-cluster", "--timeout", "30s"}, "✓ Disconnected from cluster my-cluster!", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + 
req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "cluster_name": "my-cluster", + }, + }, } - clusterName = "valid-cluster" - if clusterName != "" { - assert.True(t, true, "valid cluster name should be accepted") + result, err := handleDisconnectRemoteCluster(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Disconnected from cluster my-cluster!") + }) + + t.Run("missing cluster_name", func(t *testing.T) { + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{}, + }, } + result, err := handleDisconnectRemoteCluster(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "cluster_name parameter is required") }) +} + +func TestHandleEnableHubble(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"hubble", "enable", "--timeout", "30s"}, "✓ Hubble was successfully enabled!", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "enable": true, + }, + }, + } - t.Run("boolean parameter handling", func(t *testing.T) { - enableStr := "true" - enable := enableStr == "true" - assert.True(t, enable) + result, err := handleToggleHubble(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Hubble was successfully enabled!") +} - enableStr = "false" - enable = enableStr == "true" - assert.False(t, enable) +func TestHandleDisableHubble(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"hubble", "disable", "--timeout", "30s"}, "✓ Hubble was successfully disabled!", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + req 
:= mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "enable": false, + }, + }, + } + result, err := handleToggleHubble(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "✓ Hubble was successfully disabled!") +} - // Default value handling - enableStr = "" - if enableStr == "" { - enableStr = "true" // default - } - enable = enableStr == "true" - assert.True(t, enable) +func TestHandleListBGPPeers(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"bgp", "peers", "--timeout", "30s"}, "listing BGP peers", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + result, err := handleListBGPPeers(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "listing BGP peers") +} + +func TestHandleListBGPRoutes(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"bgp", "routes", "--timeout", "30s"}, "listing BGP routes", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + result, err := handleListBGPRoutes(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "listing BGP routes") +} + +func TestRunCiliumCliWithContext(t *testing.T) { + ctx := context.Background() + t.Run("success", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"test", "--timeout", "30s"}, "success", nil) + ctx = cmd.WithShellExecutor(ctx, mock) + result, err := runCiliumCliWithContext(ctx, "test") + require.NoError(t, err) + assert.Equal(t, "success", result) + }) + t.Run("error", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"test", 
"--timeout", "30s"}, "", fmt.Errorf("test error")) + ctx = cmd.WithShellExecutor(ctx, mock) + _, err := runCiliumCliWithContext(ctx, "test") + require.Error(t, err) + assert.Contains(t, err.Error(), "test error") }) } + +func getResultText(r *mcp.CallToolResult) string { + if r == nil || len(r.Content) == 0 { + return "" + } + if textContent, ok := r.Content[0].(mcp.TextContent); ok { + return strings.TrimSpace(textContent.Text) + } + return "" +} diff --git a/pkg/helm/helm.go b/pkg/helm/helm.go index b3a65e3..06a9ac8 100644 --- a/pkg/helm/helm.go +++ b/pkg/helm/helm.go @@ -5,13 +5,15 @@ import ( "fmt" "strings" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/errors" + "github.com/kagent-dev/tools/internal/security" + "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) -var kubeConfig = "" // Global variable to hold kubeconfig path - // Helm list releases func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { namespace := mcp.ParseString(request, "namespace", "") @@ -69,6 +71,15 @@ func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (* result, err := runHelmCommand(ctx, args) if err != nil { + // Check if it's a structured error + if toolErr, ok := err.(*errors.ToolError); ok { + // Add namespace context if provided + if namespace != "" { + toolErr = toolErr.WithContext("namespace", namespace) + } + return toolErr.ToMCPResult(), nil + } + // Fallback for non-structured errors return mcp.NewToolResultError(fmt.Sprintf("Helm list command failed: %v", err)), nil } @@ -76,10 +87,24 @@ func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (* } func runHelmCommand(ctx context.Context, args []string) (string, error) { - if kubeConfig != "" { - args = append(args, "--kubeconfig", kubeConfig) + kubeconfigPath := 
utils.GetKubeconfig() + result, err := commands.NewCommandBuilder("helm"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). + Execute(ctx) + + if err != nil { + if toolErr, ok := err.(*errors.ToolError); ok { + if len(args) > 0 { + toolErr = toolErr.WithContext("helm_operation", args[0]) + } + toolErr = toolErr.WithContext("helm_args", args) + return "", toolErr + } + return "", err } - return utils.RunCommandWithContext(ctx, "helm", args) + + return result, nil } // Helm get release @@ -98,7 +123,7 @@ func handleHelmGetRelease(ctx context.Context, request mcp.CallToolRequest) (*mc args := []string{"get", resource, name, "-n", namespace} - result, err := utils.RunCommandWithContext(ctx, "helm", args) + result, err := runHelmCommand(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Helm get command failed: %v", err)), nil } @@ -122,6 +147,25 @@ func handleHelmUpgradeRelease(ctx context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError("name and chart parameters are required"), nil } + // Validate release name + if err := security.ValidateHelmReleaseName(name); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid release name: %v", err)), nil + } + + // Validate namespace if provided + if namespace != "" { + if err := security.ValidateNamespace(namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + } + } + + // Validate values file path if provided + if values != "" { + if err := security.ValidateFilePath(values); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid values file path: %v", err)), nil + } + } + args := []string{"upgrade", name, chart} if namespace != "" { @@ -156,7 +200,7 @@ func handleHelmUpgradeRelease(ctx context.Context, request mcp.CallToolRequest) args = append(args, "--wait") } - result, err := utils.RunCommandWithContext(ctx, "helm", args) + result, err := runHelmCommand(ctx, args) if err != nil { return 
mcp.NewToolResultError(fmt.Sprintf("Helm upgrade command failed: %v", err)), nil } @@ -185,7 +229,7 @@ func handleHelmUninstall(ctx context.Context, request mcp.CallToolRequest) (*mcp args = append(args, "--wait") } - result, err := utils.RunCommandWithContext(ctx, "helm", args) + result, err := runHelmCommand(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Helm uninstall command failed: %v", err)), nil } @@ -202,9 +246,19 @@ func handleHelmRepoAdd(ctx context.Context, request mcp.CallToolRequest) (*mcp.C return mcp.NewToolResultError("name and url parameters are required"), nil } + // Validate repository name + if err := security.ValidateHelmReleaseName(name); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid repository name: %v", err)), nil + } + + // Validate repository URL + if err := security.ValidateURL(url); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid repository URL: %v", err)), nil + } + args := []string{"repo", "add", name, url} - result, err := utils.RunCommandWithContext(ctx, "helm", args) + result, err := runHelmCommand(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Helm repo add command failed: %v", err)), nil } @@ -216,7 +270,7 @@ func handleHelmRepoAdd(ctx context.Context, request mcp.CallToolRequest) (*mcp.C func handleHelmRepoUpdate(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := []string{"repo", "update"} - result, err := utils.RunCommandWithContext(ctx, "helm", args) + result, err := runHelmCommand(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Helm repo update command failed: %v", err)), nil } @@ -225,8 +279,7 @@ func handleHelmRepoUpdate(ctx context.Context, request mcp.CallToolRequest) (*mc } // Register Helm tools -func RegisterHelmTools(s *server.MCPServer, kubeconfig string) { - kubeConfig = kubeconfig +func RegisterTools(s *server.MCPServer) { s.AddTool(mcp.NewTool("helm_list_releases", 
mcp.WithDescription("List Helm releases in a namespace"), @@ -240,14 +293,14 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("pending", mcp.Description("List pending releases")), mcp.WithString("filter", mcp.Description("A regular expression to filter releases by")), mcp.WithString("output", mcp.Description("The output format (e.g., 'json', 'yaml', 'table')")), - ), handleHelmListReleases) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_list_releases", handleHelmListReleases))) s.AddTool(mcp.NewTool("helm_get_release", mcp.WithDescription("Get extended information about a Helm release"), mcp.WithString("name", mcp.Description("The name of the release"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()), mcp.WithString("resource", mcp.Description("The resource to get (all, hooks, manifest, notes, values)")), - ), handleHelmGetRelease) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_get_release", handleHelmGetRelease))) s.AddTool(mcp.NewTool("helm_upgrade", mcp.WithDescription("Upgrade or install a Helm release"), @@ -260,7 +313,7 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("install", mcp.Description("Run an install if the release is not present")), mcp.WithString("dry_run", mcp.Description("Simulate an upgrade")), mcp.WithString("wait", mcp.Description("Wait for the upgrade to complete")), - ), handleHelmUpgradeRelease) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_upgrade", handleHelmUpgradeRelease))) s.AddTool(mcp.NewTool("helm_uninstall", mcp.WithDescription("Uninstall a Helm release"), @@ -268,15 +321,15 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()), mcp.WithString("dry_run", mcp.Description("Simulate an uninstall")), mcp.WithString("wait", mcp.Description("Wait for the uninstall 
to complete")), - ), handleHelmUninstall) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_uninstall", handleHelmUninstall))) s.AddTool(mcp.NewTool("helm_repo_add", mcp.WithDescription("Add a Helm repository"), mcp.WithString("name", mcp.Description("The name of the repository"), mcp.Required()), mcp.WithString("url", mcp.Description("The URL of the repository"), mcp.Required()), - ), handleHelmRepoAdd) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_add", handleHelmRepoAdd))) s.AddTool(mcp.NewTool("helm_repo_update", mcp.WithDescription("Update information of available charts locally from chart repositories"), - ), handleHelmRepoUpdate) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_update", handleHelmRepoUpdate))) } diff --git a/pkg/helm/helm_test.go b/pkg/helm/helm_test.go index 4a99165..3848de2 100644 --- a/pkg/helm/helm_test.go +++ b/pkg/helm/helm_test.go @@ -4,22 +4,28 @@ import ( "context" "testing" - "github.com/kagent-dev/tools/pkg/utils" + "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestRegisterTools(t *testing.T) { + s := server.NewMCPServer("test-server", "v0.0.1") + RegisterTools(s) +} + // Test Helm List Releases func TestHandleHelmListReleases(t *testing.T) { t.Run("basic list releases", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `NAME NAMESPACE REVISION UPDATED STATUS CHART APP VERSION app1 default 1 2023-01-01 12:00:00.000000000 +0000 UTC deployed myapp-1.0.0 1.0.0 app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deployed system-2.0.0 2.0.0` - mock.AddCommandString("helm", []string{"list"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"list", "--timeout", "30s"}, expectedOutput, nil) + ctx := 
cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleHelmListReleases(ctx, request) @@ -37,13 +43,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"list"}, callLog[0].Args) + assert.Equal(t, []string{"list", "--timeout", "30s"}, callLog[0].Args) }) t.Run("list releases with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("helm", []string{"list", "-n", "production"}, "production releases", nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("helm", []string{"list", "-n", "production", "--timeout", "30s"}, "production releases", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -59,13 +65,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"list", "-n", "production"}, callLog[0].Args) + assert.Equal(t, []string{"list", "-n", "production", "--timeout", "30s"}, callLog[0].Args) }) t.Run("list releases with all namespaces", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("helm", []string{"list", "-A"}, "all namespaces releases", nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("helm", []string{"list", "-A", "--timeout", "30s"}, "all namespaces releases", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -81,13 +87,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo callLog := 
mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"list", "-A"}, callLog[0].Args) + assert.Equal(t, []string{"list", "-A", "--timeout", "30s"}, callLog[0].Args) }) t.Run("list releases with multiple flags", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("helm", []string{"list", "-A", "-a", "--failed", "-o", "json"}, `[{"name":"failed-app","status":"failed"}]`, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("helm", []string{"list", "-A", "-a", "--failed", "-o", "json", "--timeout", "30s"}, `[{"name":"failed-app","status":"failed"}]`, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -106,35 +112,35 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"list", "-A", "-a", "--failed", "-o", "json"}, callLog[0].Args) + assert.Equal(t, []string{"list", "-A", "-a", "--failed", "-o", "json", "--timeout", "30s"}, callLog[0].Args) }) t.Run("helm command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("helm", []string{"list"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("helm", []string{"list", "--timeout", "30s"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleHelmListReleases(ctx, request) assert.NoError(t, err) // MCP handlers should not return Go errors assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "Helm list command failed") + assert.Contains(t, getResultText(result), "**Helm Error**") }) } // Test 
Helm Get Release func TestHandleHelmGetRelease(t *testing.T) { t.Run("get release all resources", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `REVISION: 1 RELEASED: Mon Jan 01 12:00:00 UTC 2023 CHART: myapp-1.0.0 VALUES: replicaCount: 3` - mock.AddCommandString("helm", []string{"get", "all", "myapp", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"get", "all", "myapp", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -152,13 +158,13 @@ replicaCount: 3` callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"get", "all", "myapp", "-n", "default"}, callLog[0].Args) + assert.Equal(t, []string{"get", "all", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args) }) t.Run("get release values only", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("helm", []string{"get", "values", "myapp", "-n", "default"}, "replicaCount: 3", nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("helm", []string{"get", "values", "myapp", "-n", "default", "--timeout", "30s"}, "replicaCount: 3", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -176,12 +182,12 @@ replicaCount: 3` callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"get", "values", "myapp", "-n", "default"}, callLog[0].Args) + assert.Equal(t, []string{"get", "values", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args) }) t.Run("missing 
required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) // Test missing name request := mcp.CallToolRequest{} @@ -213,7 +219,7 @@ replicaCount: 3` // Test Helm Upgrade Release func TestHandleHelmUpgradeRelease(t *testing.T) { t.Run("basic upgrade", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `Release "myapp" has been upgraded. Happy Helming! NAME: myapp LAST DEPLOYED: Mon Jan 01 12:00:00 UTC 2023 @@ -221,8 +227,8 @@ NAMESPACE: default STATUS: deployed REVISION: 2` - mock.AddCommandString("helm", []string{"upgrade", "myapp", "stable/myapp"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"upgrade", "myapp", "stable/myapp", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -240,11 +246,11 @@ REVISION: 2` callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"upgrade", "myapp", "stable/myapp"}, callLog[0].Args) + assert.Equal(t, []string{"upgrade", "myapp", "stable/myapp", "--timeout", "30s"}, callLog[0].Args) }) t.Run("upgrade with all options", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedArgs := []string{ "upgrade", "myapp", "stable/myapp", "-n", "production", @@ -255,9 +261,10 @@ REVISION: 2` "--install", "--dry-run", "--wait", + "--timeout", "30s", } - mock.AddCommandString("helm", expectedArgs, "dry run output", nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", expectedArgs, "Upgraded with options", nil) + ctx := 
cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -284,14 +291,14 @@ REVISION: 2` assert.Equal(t, expectedArgs, callLog[0].Args) }) - t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + t.Run("missing required parameters for upgrade", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + // Test missing chart request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ "name": "myapp", - // Missing chart } result, err := handleHelmUpgradeRelease(ctx, request) @@ -308,11 +315,11 @@ REVISION: 2` // Test Helm Uninstall func TestHandleHelmUninstall(t *testing.T) { t.Run("basic uninstall", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `release "myapp" uninstalled` - mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -330,14 +337,14 @@ func TestHandleHelmUninstall(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"uninstall", "myapp", "-n", "default"}, callLog[0].Args) + assert.Equal(t, []string{"uninstall", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args) }) t.Run("uninstall with options", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedArgs := []string{"uninstall", "myapp", "-n", "production", "--dry-run", "--wait"} + mock := 
cmd.NewMockShellExecutor() + expectedArgs := []string{"uninstall", "myapp", "-n", "production", "--dry-run", "--wait", "--timeout", "30s"} mock.AddCommandString("helm", expectedArgs, "dry run uninstall", nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -358,21 +365,51 @@ func TestHandleHelmUninstall(t *testing.T) { assert.Equal(t, "helm", callLog[0].Command) assert.Equal(t, expectedArgs, callLog[0].Args) }) + + t.Run("missing required parameters for uninstall", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + + // Test missing name + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]interface{}{ + "namespace": "default", + } + + result, err := handleHelmUninstall(ctx, request) + assert.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "name and namespace parameters are required") + + // Test missing namespace + request.Params.Arguments = map[string]interface{}{ + "name": "myapp", + } + + result, err = handleHelmUninstall(ctx, request) + assert.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "name and namespace parameters are required") + + // Verify no commands were executed + callLog := mock.GetCallLog() + assert.Len(t, callLog, 0) + }) } // Test Helm Repo Add func TestHandleHelmRepoAdd(t *testing.T) { - t.Run("add repository", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `"stable" has been added to your repositories` + t.Run("basic repo add", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + expectedOutput := `"my-repo" has been added to your repositories` - mock.AddCommandString("helm", []string{"repo", "add", "stable", "https://charts.helm.sh/stable"}, expectedOutput, nil) - ctx 
:= utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"repo", "add", "my-repo", "https://charts.example.com/", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "name": "stable", - "url": "https://charts.helm.sh/stable", + "name": "my-repo", + "url": "https://charts.example.com/", } result, err := handleHelmRepoAdd(ctx, request) @@ -385,17 +422,17 @@ func TestHandleHelmRepoAdd(t *testing.T) { callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"repo", "add", "stable", "https://charts.helm.sh/stable"}, callLog[0].Args) + assert.Equal(t, []string{"repo", "add", "my-repo", "https://charts.example.com/", "--timeout", "30s"}, callLog[0].Args) }) - t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + t.Run("missing required parameters for repo add", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + // Test missing name request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "name": "stable", - // Missing url + "url": "https://charts.example.com/", } result, err := handleHelmRepoAdd(ctx, request) @@ -411,27 +448,26 @@ func TestHandleHelmRepoAdd(t *testing.T) { // Test Helm Repo Update func TestHandleHelmRepoUpdate(t *testing.T) { - t.Run("update repositories", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + t.Run("basic repo update", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() expectedOutput := `Hang tight while we grab the latest from your chart repositories... -...Successfully got an update from the "stable" chart repository -Update Complete. 
⎈Happy Helming!⎈` +...Successfully got an update from the "my-repo" chart repository` - mock.AddCommandString("helm", []string{"repo", "update"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("helm", []string{"repo", "update", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) request := mcp.CallToolRequest{} result, err := handleHelmRepoUpdate(ctx, request) assert.NoError(t, err) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "Update Complete") + assert.Contains(t, getResultText(result), "Successfully got an update") // Verify the correct command was called callLog := mock.GetCallLog() require.Len(t, callLog, 1) assert.Equal(t, "helm", callLog[0].Command) - assert.Equal(t, []string{"repo", "update"}, callLog[0].Args) + assert.Equal(t, []string{"repo", "update", "--timeout", "30s"}, callLog[0].Args) }) } diff --git a/pkg/istio/istio.go b/pkg/istio/istio.go index 2f198aa..680d83c 100644 --- a/pkg/istio/istio.go +++ b/pkg/istio/istio.go @@ -5,13 +5,13 @@ import ( "fmt" "strings" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) -var kubeConfig = "" // Global variable to hold kubeconfig path - // Istio proxy status func handleIstioProxyStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { podName := mcp.ParseString(request, "pod_name", "") @@ -36,11 +36,11 @@ func handleIstioProxyStatus(ctx context.Context, request mcp.CallToolRequest) (* } func runIstioCtl(ctx context.Context, args []string) (string, error) { - if kubeConfig != "" { - args = append(args, "--kubeconfig", kubeConfig) - } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) - return result, err + kubeconfigPath := utils.GetKubeconfig() + return 
commands.NewCommandBuilder("istioctl"). + WithArgs(args...). + WithKubeconfig(kubeconfigPath). + Execute(ctx) } // Istio proxy config @@ -61,7 +61,7 @@ func handleIstioProxyConfig(ctx context.Context, request mcp.CallToolRequest) (* args = append(args, podName) } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-config failed: %v", err)), nil } @@ -75,7 +75,7 @@ func handleIstioInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp. args := []string{"install", "--set", fmt.Sprintf("profile=%s", profile), "-y"} - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl install failed: %v", err)), nil } @@ -89,7 +89,7 @@ func handleIstioGenerateManifest(ctx context.Context, request mcp.CallToolReques args := []string{"manifest", "generate", "--set", fmt.Sprintf("profile=%s", profile)} - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl manifest generate failed: %v", err)), nil } @@ -110,7 +110,7 @@ func handleIstioAnalyzeClusterConfiguration(ctx context.Context, request mcp.Cal args = append(args, "-n", namespace) } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl analyze failed: %v", err)), nil } @@ -128,7 +128,7 @@ func handleIstioVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp. 
args = append(args, "--short") } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl version failed: %v", err)), nil } @@ -140,7 +140,7 @@ func handleIstioVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp. func handleIstioRemoteClusters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := []string{"remote-clusters"} - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl remote-clusters failed: %v", err)), nil } @@ -161,7 +161,7 @@ func handleWaypointList(ctx context.Context, request mcp.CallToolRequest) (*mcp. args = append(args, "-n", namespace) } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint list failed: %v", err)), nil } @@ -191,7 +191,7 @@ func handleWaypointGenerate(ctx context.Context, request mcp.CallToolRequest) (* args = append(args, "--for", trafficType) } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint generate failed: %v", err)), nil } @@ -214,7 +214,7 @@ func handleWaypointApply(ctx context.Context, request mcp.CallToolRequest) (*mcp args = append(args, "--enroll-namespace") } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint apply failed: %v", err)), nil } @@ -245,7 +245,7 @@ func handleWaypointDelete(ctx context.Context, request mcp.CallToolRequest) (*mc args = append(args, "-n", namespace) - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) 
+ result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint delete failed: %v", err)), nil } @@ -270,7 +270,7 @@ func handleWaypointStatus(ctx context.Context, request mcp.CallToolRequest) (*mc args = append(args, "-n", namespace) - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint status failed: %v", err)), nil } @@ -283,30 +283,29 @@ func handleZtunnelConfig(ctx context.Context, request mcp.CallToolRequest) (*mcp namespace := mcp.ParseString(request, "namespace", "") configType := mcp.ParseString(request, "config_type", "all") - args := []string{"ztunnel-config", configType} + args := []string{"ztunnel", "config", configType} if namespace != "" { args = append(args, "-n", namespace) } - result, err := utils.RunCommandWithContext(ctx, "istioctl", args) + result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel-config failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel config failed: %v", err)), nil } return mcp.NewToolResultText(result), nil } // Register Istio tools -func RegisterIstioTools(s *server.MCPServer, kubeconfig string) { - kubeConfig = kubeconfig +func RegisterTools(s *server.MCPServer) { // Istio proxy status s.AddTool(mcp.NewTool("istio_proxy_status", mcp.WithDescription("Get Envoy proxy status for pods, retrieves last sent and acknowledged xDS sync from Istiod to each Envoy in the mesh"), mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy status for")), mcp.WithString("namespace", mcp.Description("Namespace of the pod")), - ), handleIstioProxyStatus) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_status", handleIstioProxyStatus))) // Istio proxy config s.AddTool(mcp.NewTool("istio_proxy_config", @@ -314,79 +313,62 @@ func 
RegisterIstioTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy configuration for"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the pod")), mcp.WithString("config_type", mcp.Description("Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)")), - ), handleIstioProxyConfig) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_config", handleIstioProxyConfig))) // Istio install s.AddTool(mcp.NewTool("istio_install_istio", mcp.WithDescription("Install Istio with a specified configuration profile"), mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")), - ), handleIstioInstall) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_install_istio", handleIstioInstall))) // Istio generate manifest s.AddTool(mcp.NewTool("istio_generate_manifest", - mcp.WithDescription("Generate an Istio install manifest"), + mcp.WithDescription("Generate Istio manifest for a given profile"), mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")), - ), handleIstioGenerateManifest) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_manifest", handleIstioGenerateManifest))) // Istio analyze s.AddTool(mcp.NewTool("istio_analyze_cluster_configuration", - mcp.WithDescription("Analyze live cluster configuration for potential issues"), - mcp.WithString("namespace", mcp.Description("Namespace to analyze")), - mcp.WithString("all_namespaces", mcp.Description("Analyze all namespaces (true/false)")), - ), handleIstioAnalyzeClusterConfiguration) + mcp.WithDescription("Analyze Istio cluster configuration for issues"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_analyze_cluster_configuration", handleIstioAnalyzeClusterConfiguration))) // Istio version s.AddTool(mcp.NewTool("istio_version", - 
mcp.WithDescription("Get Istio CLI client version, control plane and data plane versions"), - mcp.WithString("short", mcp.Description("Show short version format (true/false)")), - ), handleIstioVersion) + mcp.WithDescription("Get Istio version information"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_version", handleIstioVersion))) // Istio remote clusters s.AddTool(mcp.NewTool("istio_remote_clusters", - mcp.WithDescription("List remote clusters each istiod instance is connected to"), - ), handleIstioRemoteClusters) + mcp.WithDescription("List remote clusters registered with Istio"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_remote_clusters", handleIstioRemoteClusters))) // Waypoint list s.AddTool(mcp.NewTool("istio_list_waypoints", - mcp.WithDescription("List managed waypoint configurations in the cluster"), - mcp.WithString("namespace", mcp.Description("Namespace to list waypoints for")), - mcp.WithString("all_namespaces", mcp.Description("List waypoints for all namespaces (true/false)")), - ), handleWaypointList) + mcp.WithDescription("List all waypoints in the mesh"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_list_waypoints", handleWaypointList))) // Waypoint generate s.AddTool(mcp.NewTool("istio_generate_waypoint", - mcp.WithDescription("Generate a waypoint configuration as YAML"), - mcp.WithString("namespace", mcp.Description("Namespace to generate the waypoint for"), mcp.Required()), - mcp.WithString("name", mcp.Description("Name of the waypoint to generate")), - mcp.WithString("traffic_type", mcp.Description("Traffic type for the waypoint (all, inbound, outbound)")), - ), handleWaypointGenerate) + mcp.WithDescription("Generate a waypoint resource YAML"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_waypoint", handleWaypointGenerate))) // Waypoint apply s.AddTool(mcp.NewTool("istio_apply_waypoint", - mcp.WithDescription("Apply a waypoint configuration to a cluster"), - 
mcp.WithString("namespace", mcp.Description("Namespace to apply the waypoint to"), mcp.Required()), - mcp.WithString("enroll_namespace", mcp.Description("Label the namespace with the waypoint name (true/false)")), - ), handleWaypointApply) + mcp.WithDescription("Apply a waypoint resource to the cluster"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_apply_waypoint", handleWaypointApply))) // Waypoint delete s.AddTool(mcp.NewTool("istio_delete_waypoint", - mcp.WithDescription("Delete waypoint configurations from a cluster"), - mcp.WithString("namespace", mcp.Description("Namespace to delete waypoints from"), mcp.Required()), - mcp.WithString("names", mcp.Description("Comma-separated list of waypoint names to delete")), - mcp.WithString("all", mcp.Description("Delete all waypoints in the namespace (true/false)")), - ), handleWaypointDelete) + mcp.WithDescription("Delete a waypoint resource from the cluster"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_delete_waypoint", handleWaypointDelete))) // Waypoint status s.AddTool(mcp.NewTool("istio_waypoint_status", - mcp.WithDescription("Get status of a waypoint"), - mcp.WithString("namespace", mcp.Description("Namespace of the waypoint"), mcp.Required()), - mcp.WithString("name", mcp.Description("Name of the waypoint to get status for")), - ), handleWaypointStatus) + mcp.WithDescription("Get the status of a waypoint resource"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_waypoint_status", handleWaypointStatus))) // Ztunnel config s.AddTool(mcp.NewTool("istio_ztunnel_config", - mcp.WithDescription("Get ztunnel configuration"), - mcp.WithString("namespace", mcp.Description("Namespace of the pod")), - mcp.WithString("config_type", mcp.Description("Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)")), - ), handleZtunnelConfig) + mcp.WithDescription("Get the ztunnel configuration for a namespace"), + ), 
telemetry.AdaptToolHandler(telemetry.WithTracing("istio_ztunnel_config", handleZtunnelConfig))) } diff --git a/pkg/istio/istio_test.go b/pkg/istio/istio_test.go index 2adaf99..02efbee 100644 --- a/pkg/istio/istio_test.go +++ b/pkg/istio/istio_test.go @@ -4,299 +4,185 @@ import ( "context" "testing" - "github.com/kagent-dev/tools/pkg/utils" + "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Helper function to extract text content from MCP result -func getResultText(result *mcp.CallToolResult) string { - if result == nil || len(result.Content) == 0 { - return "" - } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { - return textContent.Text - } - return "" +func TestRegisterTools(t *testing.T) { + s := server.NewMCPServer("test-server", "v0.0.1") + RegisterTools(s) } -// Test Istio Proxy Status func TestHandleIstioProxyStatus(t *testing.T) { + ctx := context.Background() + t.Run("basic proxy status", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION -app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0 -app-2 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0` + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-status", "--timeout", "30s"}, "Proxy status output", nil) - mock.AddCommandString("istioctl", []string{"proxy-status"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - result, err := handleIstioProxyStatus(ctx, request) + result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{}) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the expected output - content := getResultText(result) - 
assert.Contains(t, content, "app-1") - assert.Contains(t, content, "SYNCED") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"proxy-status"}, callLog[0].Args) }) t.Run("proxy status with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION -app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0` + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "istio-system", "--timeout", "30s"}, "Proxy status output", nil) - mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "production"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "namespace": "production", + "namespace": "istio-system", } result, err := handleIstioProxyStatus(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the correct command was called with namespace - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"proxy-status", "-n", "production"}, callLog[0].Args) }) - t.Run("proxy status with pod name and namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION -app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0` + t.Run("proxy status with pod name", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "default", "test-pod", "--timeout", "30s"}, "Proxy status output", nil) - mock.AddCommandString("istioctl", []string{"proxy-status", "-n", 
"production", "app-1"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "namespace": "production", - "pod_name": "app-1", + "pod_name": "test-pod", + "namespace": "default", } result, err := handleIstioProxyStatus(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"proxy-status", "-n", "production", "app-1"}, callLog[0].Args) }) +} - t.Run("istioctl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"proxy-status"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) +func TestHandleIstioProxyConfig(t *testing.T) { + ctx := context.Background() - request := mcp.CallToolRequest{} - result, err := handleIstioProxyStatus(ctx, request) + t.Run("missing pod_name parameter", func(t *testing.T) { + result, err := handleIstioProxyConfig(ctx, mcp.CallToolRequest{}) - assert.NoError(t, err) // MCP handlers should not return Go errors + require.NoError(t, err) + assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "istioctl proxy-status failed") }) -} -// Test Istio Proxy Config -func TestHandleIstioProxyConfig(t *testing.T) { - t.Run("proxy config all", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `CLUSTER NAME DIRECTION TYPE DESTINATION RULE -outbound|80||kubernetes.default.svc.cluster.local outbound EDS -inbound|80|| inbound EDS` + t.Run("proxy config with pod name", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-config", 
"all", "test-pod", "--timeout", "30s"}, "Proxy config output", nil) - mock.AddCommandString("istioctl", []string{"proxy-config", "all", "app-1"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "pod_name": "app-1", + "pod_name": "test-pod", } result, err := handleIstioProxyConfig(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "CLUSTER NAME") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"proxy-config", "all", "app-1"}, callLog[0].Args) }) - t.Run("proxy config with namespace and config type", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `CLUSTER NAME DIRECTION TYPE DESTINATION RULE -outbound|80||kubernetes.default.svc.cluster.local outbound EDS` + t.Run("proxy config with namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-config", "cluster", "test-pod.default", "--timeout", "30s"}, "Proxy config output", nil) - mock.AddCommandString("istioctl", []string{"proxy-config", "cluster", "app-1.production"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "pod_name": "app-1", - "namespace": "production", + "pod_name": "test-pod", + "namespace": "default", "config_type": "cluster", } result, err := handleIstioProxyConfig(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the correct command was called - callLog := 
mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"proxy-config", "cluster", "app-1.production"}, callLog[0].Args) - }) - - t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing pod_name - } - - result, err := handleIstioProxyConfig(ctx, request) - assert.NoError(t, err) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "pod_name parameter is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) }) } -// Test Istio Install func TestHandleIstioInstall(t *testing.T) { - t.Run("basic install", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `✔ Istio core installed -✔ Istiod installed -✔ Ingress gateways installed -✔ Installation complete` + ctx := context.Background() - mock.AddCommandString("istioctl", []string{"install", "--set", "profile=default", "-y"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + t.Run("install with default profile", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"install", "--set", "profile=default", "-y", "--timeout", "30s"}, "Install completed", nil) - request := mcp.CallToolRequest{} - result, err := handleIstioInstall(ctx, request) + ctx = cmd.WithShellExecutor(ctx, mock) - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "Installation complete") + result, err := handleIstioInstall(ctx, mcp.CallToolRequest{}) - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"install", "--set", 
"profile=default", "-y"}, callLog[0].Args) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) }) - t.Run("install with profile", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `✔ Istio core installed -✔ Installation complete` + t.Run("install with custom profile", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"install", "--set", "profile=demo", "-y", "--timeout", "30s"}, "Install completed", nil) - mock.AddCommandString("istioctl", []string{"install", "--set", "profile=minimal", "-y"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "profile": "minimal", + "profile": "demo", } result, err := handleIstioInstall(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the correct command was called with profile - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"install", "--set", "profile=minimal", "-y"}, callLog[0].Args) }) } -// Test Istio Analyze -func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) { - t.Run("basic analyze", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `✔ No validation issues found when analyzing namespace: default.` - - mock.AddCommandString("istioctl", []string{"analyze"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) +func TestHandleIstioGenerateManifest(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, 
getResultText(result), "No validation issues found") + mock.AddCommandString("istioctl", []string{"manifest", "generate", "--set", "profile=minimal", "--timeout", "30s"}, "Generated manifest", nil) - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"analyze"}, callLog[0].Args) - }) + ctx = cmd.WithShellExecutor(ctx, mock) - t.Run("analyze with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `✔ No validation issues found when analyzing namespace: production.` - - mock.AddCommandString("istioctl", []string{"analyze", "-n", "production"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "production", - } + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]interface{}{ + "profile": "minimal", + } - result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) + result, err := handleIstioGenerateManifest(ctx, request) - assert.NoError(t, err) - assert.False(t, result.IsError) + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) +} - // Verify the correct command was called with namespace - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"analyze", "-n", "production"}, callLog[0].Args) - }) +func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) { + ctx := context.Background() t.Run("analyze all namespaces", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `✔ No validation issues found when analyzing all namespaces.` + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"analyze", "-A", "--timeout", "30s"}, "Analysis output", nil) - 
mock.AddCommandString("istioctl", []string{"analyze", "-A"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ @@ -305,565 +191,167 @@ func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) { result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) - assert.NoError(t, err) - assert.False(t, result.IsError) - - // Verify the correct command was called with -A flag - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"analyze", "-A"}, callLog[0].Args) - }) -} - -// Test Istio Version -func TestHandleIstioVersion(t *testing.T) { - t.Run("version detailed output", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `client version: 1.18.0 -control plane version: 1.18.0 -data plane version: 1.18.0 (2 proxies)` - - mock.AddCommandString("istioctl", []string{"version"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - result, err := handleIstioVersion(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "client version: 1.18.0") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"version"}, callLog[0].Args) - }) - - t.Run("version short output", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `1.18.0` - - mock.AddCommandString("istioctl", []string{"version", "--short"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "short": "true", - } - - result, err := 
handleIstioVersion(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "1.18.0") - - // Verify the correct command was called with --short flag - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"version", "--short"}, callLog[0].Args) - }) -} - -// Test Waypoint List -func TestHandleWaypointList(t *testing.T) { - t.Run("list waypoints", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `NAMESPACE NAME TRAFFIC TYPE -default waypoint ALL -production waypoint INBOUND` - - mock.AddCommandString("istioctl", []string{"waypoint", "list"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - result, err := handleWaypointList(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "NAMESPACE") - assert.Contains(t, getResultText(result), "waypoint") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "list"}, callLog[0].Args) - }) - - t.Run("list waypoints in namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `NAMESPACE NAME TRAFFIC TYPE -production waypoint INBOUND` - - mock.AddCommandString("istioctl", []string{"waypoint", "list", "-n", "production"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "production", - } - - result, err := handleWaypointList(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - - // Verify the correct command was called with namespace - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - 
assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "list", "-n", "production"}, callLog[0].Args) - }) -} - -// Test Waypoint Generate -func TestHandleWaypointGenerate(t *testing.T) { - t.Run("generate waypoint", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `apiVersion: gateway.networking.k8s.io/v1beta1 -kind: Gateway -metadata: - name: waypoint - namespace: production -spec: - gatewayClassName: istio-waypoint` - - mock.AddCommandString("istioctl", []string{"waypoint", "generate", "waypoint", "-n", "production", "--for", "all"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "production", - } - - result, err := handleWaypointGenerate(ctx, request) - - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "apiVersion: gateway.networking.k8s.io/v1beta1") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "generate", "waypoint", "-n", "production", "--for", "all"}, callLog[0].Args) - }) - - t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing namespace - } - - result, err := handleWaypointGenerate(ctx, request) - assert.NoError(t, err) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "namespace parameter is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) }) -} -// Test Waypoint Apply -func TestHandleWaypointApply(t *testing.T) { - 
t.Run("basic waypoint apply", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/waypoint applied` + t.Run("analyze specific namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"analyze", "-n", "default", "--timeout", "30s"}, "Analysis output", nil) - mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ "namespace": "default", } - result, err := handleWaypointApply(ctx, request) + result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "applied") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "apply", "-n", "default"}, callLog[0].Args) - }) - - t.Run("waypoint apply with enroll namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/waypoint applied -namespace/default labeled with istio.io/use-waypoint=waypoint` - - mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "enroll_namespace": "true", - } - - result, err := handleWaypointApply(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "applied") - - // Verify the correct command was called with --enroll-namespace flag 
- callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, callLog[0].Args) - }) - - t.Run("missing namespace parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing namespace - } - - result, err := handleWaypointApply(ctx, request) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "namespace parameter is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) - - t.Run("istioctl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } - - result, err := handleWaypointApply(ctx, request) - - assert.NoError(t, err) // MCP handlers should not return Go errors - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "istioctl waypoint apply failed") }) } -// Test Waypoint Delete -func TestHandleWaypointDelete(t *testing.T) { - t.Run("delete all waypoints", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/waypoint deleted` +func TestHandleIstioVersion(t *testing.T) { + ctx := context.Background() - mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + t.Run("version full", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + 
mock.AddCommandString("istioctl", []string{"version", "--timeout", "30s"}, "Version output", nil) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "all": "true", - } + ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleWaypointDelete(ctx, request) + result, err := handleIstioVersion(ctx, mcp.CallToolRequest{}) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "deleted") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "delete", "--all", "-n", "default"}, callLog[0].Args) }) - t.Run("delete specific waypoints", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/waypoint1 deleted -waypoint/waypoint2 deleted` + t.Run("version short", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"version", "--short", "--timeout", "30s"}, "1.18.0", nil) - mock.AddCommandString("istioctl", []string{"waypoint", "delete", "waypoint1", "waypoint2", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "names": "waypoint1,waypoint2", + "short": "true", } - result, err := handleWaypointDelete(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "deleted") - - // Verify the correct command was called with specific names - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "delete", "waypoint1", "waypoint2", "-n", "default"}, 
callLog[0].Args) - }) - - t.Run("missing namespace parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing namespace - } + result, err := handleIstioVersion(ctx, request) - result, err := handleWaypointDelete(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "namespace parameter is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) - - t.Run("istioctl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "all": "true", - } - - result, err := handleWaypointDelete(ctx, request) - - assert.NoError(t, err) // MCP handlers should not return Go errors - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "istioctl waypoint delete failed") + assert.False(t, result.IsError) }) } -// Test Waypoint Status -func TestHandleWaypointStatus(t *testing.T) { - t.Run("waypoint status", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/waypoint is deployed and ready` +func TestHandleIstioRemoteClusters(t *testing.T) { + ctx := context.Background() + mock := cmd.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddCommandString("istioctl", []string{"remote-clusters", "--timeout", "30s"}, "Remote clusters output", 
nil) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } + ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleWaypointStatus(ctx, request) + result, err := handleIstioRemoteClusters(ctx, mcp.CallToolRequest{}) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "waypoint") + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) +} - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "status", "-n", "default"}, callLog[0].Args) - }) +func TestHandleWaypointList(t *testing.T) { + ctx := context.Background() - t.Run("waypoint status with specific name", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `waypoint/test-waypoint is deployed and ready` + t.Run("list waypoints in all namespaces", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "list", "-A", "--timeout", "30s"}, "Waypoints list", nil) - mock.AddCommandString("istioctl", []string{"waypoint", "status", "test-waypoint", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "name": "test-waypoint", + "all_namespaces": "true", } - result, err := handleWaypointStatus(ctx, request) + result, err := handleWaypointList(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "test-waypoint") - - // Verify the correct command was called with specific name - callLog := mock.GetCallLog() - 
require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"waypoint", "status", "test-waypoint", "-n", "default"}, callLog[0].Args) }) - t.Run("missing namespace parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + t.Run("list waypoints in a specific namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "list", "-n", "default", "--timeout", "30s"}, "Waypoints list", nil) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing namespace - } - - result, err := handleWaypointStatus(ctx, request) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "namespace parameter is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) - - t.Run("istioctl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ "namespace": "default", } - result, err := handleWaypointStatus(ctx, request) - - assert.NoError(t, err) // MCP handlers should not return Go errors - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "istioctl waypoint status failed") - }) -} - -// Test Ztunnel Config -func TestHandleZtunnelConfig(t *testing.T) { - t.Run("default ztunnel config", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS -cluster1 EDS 10.0.0.1:15010 -cluster2 STATIC 10.0.0.2:15010` - - mock.AddCommandString("istioctl", 
[]string{"ztunnel-config", "all"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - result, err := handleZtunnelConfig(ctx, request) + result, err := handleWaypointList(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) - assert.Contains(t, getResultText(result), "CLUSTER_NAME") - - // Verify the correct command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"ztunnel-config", "all"}, callLog[0].Args) }) +} - t.Run("ztunnel config with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS -cluster1 EDS 10.0.0.1:15010` - - mock.AddCommandString("istioctl", []string{"ztunnel-config", "all", "-n", "istio-system"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "istio-system", - } - - result, err := handleZtunnelConfig(ctx, request) - - assert.NoError(t, err) - assert.False(t, result.IsError) - - // Verify the correct command was called with namespace - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"ztunnel-config", "all", "-n", "istio-system"}, callLog[0].Args) - }) +func TestHandleWaypointGenerate(t *testing.T) { + ctx := context.Background() - t.Run("ztunnel config with specific type", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS -cluster1 EDS 10.0.0.1:15010` + t.Run("generate waypoint with namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "generate", "waypoint", "-n", "default", 
"--for", "all", "--timeout", "30s"}, "Generated waypoint", nil) - mock.AddCommandString("istioctl", []string{"ztunnel-config", "cluster"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx = cmd.WithShellExecutor(ctx, mock) request := mcp.CallToolRequest{} request.Params.Arguments = map[string]interface{}{ - "config_type": "cluster", + "namespace": "default", + "name": "waypoint", + "traffic_type": "all", } - result, err := handleZtunnelConfig(ctx, request) + result, err := handleWaypointGenerate(ctx, request) - assert.NoError(t, err) + require.NoError(t, err) + assert.NotNil(t, result) assert.False(t, result.IsError) - - // Verify the correct command was called with specific config type - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"ztunnel-config", "cluster"}, callLog[0].Args) }) +} - t.Run("ztunnel config with namespace and config type", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `LISTENER_NAME ADDRESS PORT TYPE -listener1 0.0.0.0 15006 TCP` - - mock.AddCommandString("istioctl", []string{"ztunnel-config", "listener", "-n", "istio-system"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "istio-system", - "config_type": "listener", - } - - result, err := handleZtunnelConfig(ctx, request) +func TestRunIstioCtl(t *testing.T) { + t.Run("run istioctl with context", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"version", "--timeout", "30s"}, "1.18.0", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) - assert.NoError(t, err) - assert.False(t, result.IsError) + result, err := runIstioCtl(ctx, []string{"version"}) - // Verify the correct command was called with both namespace and config type - 
callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "istioctl", callLog[0].Command) - assert.Equal(t, []string{"ztunnel-config", "listener", "-n", "istio-system"}, callLog[0].Args) + require.NoError(t, err) + assert.Equal(t, "1.18.0", result) }) +} +func TestIstioErrorHandling(t *testing.T) { t.Run("istioctl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - mock.AddCommandString("istioctl", []string{"ztunnel-config", "all"}, "", assert.AnError) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"proxy-status"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleZtunnelConfig(ctx, request) + result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{}) - assert.NoError(t, err) // MCP handlers should not return Go errors + require.NoError(t, err) + assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "istioctl ztunnel-config failed") }) } diff --git a/pkg/k8s/k8s.go b/pkg/k8s/k8s.go index 6bd73ec..6c29c9c 100644 --- a/pkg/k8s/k8s.go +++ b/pkg/k8s/k8s.go @@ -10,12 +10,15 @@ import ( "slices" "strings" - "github.com/kagent-dev/tools/pkg/logger" - "github.com/kagent-dev/tools/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/openai" + + "github.com/kagent-dev/tools/internal/cache" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/logger" + "github.com/kagent-dev/tools/internal/security" + "github.com/kagent-dev/tools/internal/telemetry" ) // K8sTool struct to hold the LLM model @@ -32,6 +35,22 @@ func NewK8sToolWithConfig(kubeconfig string, llmModel llms.Model) *K8sTool { return &K8sTool{kubeconfig: kubeconfig, llmModel: llmModel} } +// 
runKubectlCommandWithCacheInvalidation runs a kubectl command and invalidates cache if it's a modification operation +func (k *K8sTool) runKubectlCommandWithCacheInvalidation(ctx context.Context, args ...string) (*mcp.CallToolResult, error) { + result, err := k.runKubectlCommand(ctx, args...) + + // If command succeeded and it's a modification command, invalidate cache + if err == nil && len(args) > 0 { + subcommand := args[0] + switch subcommand { + case "apply", "delete", "patch", "scale", "annotate", "label", "create", "run", "rollout": + cache.InvalidateKubernetesCache() + } + } + + return result, err +} + // Enhanced kubectl get func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { resourceType := mcp.ParseString(request, "resource_type", "") @@ -62,7 +81,7 @@ func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request mcp.Call args = append(args, "-o", "json") } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Get pod logs @@ -86,7 +105,7 @@ func (k *K8sTool) handleKubectlLogsEnhanced(ctx context.Context, request mcp.Cal args = append(args, "--tail", fmt.Sprintf("%d", tailLines)) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Scale deployment @@ -101,7 +120,7 @@ func (k *K8sTool) handleScaleDeployment(ctx context.Context, request mcp.CallToo args := []string{"scale", "deployment", deploymentName, "--replicas", fmt.Sprintf("%d", replicas), "-n", namespace} - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommandWithCacheInvalidation(ctx, args...) 
} // Patch resource @@ -115,9 +134,24 @@ func (k *K8sTool) handlePatchResource(ctx context.Context, request mcp.CallToolR return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil } + // Validate resource name for security + if err := security.ValidateK8sResourceName(resourceName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil + } + + // Validate namespace for security + if err := security.ValidateNamespace(namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + } + + // Validate patch content as JSON/YAML + if err := security.ValidateYAMLContent(patch); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil + } + args := []string{"patch", resourceType, resourceName, "-p", patch, "-n", namespace} - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommandWithCacheInvalidation(ctx, args...) } // Apply manifest from content @@ -128,18 +162,41 @@ func (k *K8sTool) handleApplyManifest(ctx context.Context, request mcp.CallToolR return mcp.NewToolResultError("manifest parameter is required"), nil } - tmpFile, err := os.CreateTemp("", "manifest-*.yaml") + // Validate YAML content for security + if err := security.ValidateYAMLContent(manifest); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid manifest content: %v", err)), nil + } + + // Create temporary file with secure permissions + tmpFile, err := os.CreateTemp("", "k8s-manifest-*.yaml") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil } - defer os.Remove(tmpFile.Name()) + // Ensure file is removed regardless of execution path + defer func() { + if removeErr := os.Remove(tmpFile.Name()); removeErr != nil { + logger.Get().Error("Failed to remove temporary file", "error", removeErr, "file", tmpFile.Name()) + } + }() + + // Set secure file permissions 
(readable/writable by owner only) + if err := os.Chmod(tmpFile.Name(), 0600); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to set file permissions: %v", err)), nil + } + + // Write manifest content to temporary file if _, err := tmpFile.WriteString(manifest); err != nil { + tmpFile.Close() return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil } - tmpFile.Close() - return k.runKubectlCommand(ctx, []string{"apply", "-f", tmpFile.Name()}) + // Close the file before passing to kubectl + if err := tmpFile.Close(); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to close temp file: %v", err)), nil + } + + return k.runKubectlCommandWithCacheInvalidation(ctx, "apply", "-f", tmpFile.Name()) } // Delete resource @@ -154,7 +211,7 @@ func (k *K8sTool) handleDeleteResource(ctx context.Context, request mcp.CallTool args := []string{"delete", resourceType, resourceName, "-n", namespace} - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommandWithCacheInvalidation(ctx, args...) 
} // Check service connectivity @@ -169,23 +226,23 @@ func (k *K8sTool) handleCheckServiceConnectivity(ctx context.Context, request mc // Create a temporary curl pod for connectivity check podName := fmt.Sprintf("curl-test-%d", rand.Intn(10000)) defer func() { - _, _ = k.runKubectlCommand(ctx, []string{"delete", "pod", podName, "-n", namespace, "--ignore-not-found"}) + _, _ = k.runKubectlCommand(ctx, "delete", "pod", podName, "-n", namespace, "--ignore-not-found") }() // Create the curl pod - _, err := k.runKubectlCommand(ctx, []string{"run", podName, "--image=curlimages/curl", "-n", namespace, "--restart=Never", "--", "sleep", "3600"}) + _, err := k.runKubectlCommand(ctx, "run", podName, "--image=curlimages/curl", "-n", namespace, "--restart=Never", "--", "sleep", "3600") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to create curl pod: %v", err)), nil } // Wait for pod to be ready - _, err = k.runKubectlCommand(ctx, []string{"wait", "--for=condition=ready", "pod/" + podName, "-n", namespace, "--timeout=60s"}) + _, err = k.runKubectlCommand(ctx, "wait", "--for=condition=ready", "pod/"+podName, "-n", namespace, "--timeout=60s") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to wait for curl pod: %v", err)), nil } - // Execute curl command - return k.runKubectlCommand(ctx, []string{"exec", podName, "-n", namespace, "--", "curl", "-s", serviceName}) + // Execute curl command in the pod + return k.runKubectlCommand(ctx, "exec", podName, "-n", namespace, "--", "curl", "-s", serviceName) } // Get cluster events @@ -199,7 +256,7 @@ func (k *K8sTool) handleGetEvents(ctx context.Context, request mcp.CallToolReque args = append(args, "--all-namespaces") } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) 
} // Execute command in pod @@ -212,14 +269,29 @@ func (k *K8sTool) handleExecCommand(ctx context.Context, request mcp.CallToolReq return mcp.NewToolResultError("pod_name and command parameters are required"), nil } + // Validate pod name for security + if err := security.ValidateK8sResourceName(podName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid pod name: %v", err)), nil + } + + // Validate namespace for security + if err := security.ValidateNamespace(namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + } + + // Validate command input for security + if err := security.ValidateCommandInput(command); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid command: %v", err)), nil + } + args := []string{"exec", podName, "-n", namespace, "--", command} - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Get available API resources func (k *K8sTool) handleGetAvailableAPIResources(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.runKubectlCommand(ctx, []string{"api-resources", "-o", "json"}) + return k.runKubectlCommand(ctx, "api-resources", "-o", "json") } // Kubectl describe tool @@ -237,7 +309,7 @@ func (k *K8sTool) handleKubectlDescribeTool(ctx context.Context, request mcp.Cal args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Rollout operations @@ -256,12 +328,12 @@ func (k *K8sTool) handleRollout(ctx context.Context, request mcp.CallToolRequest args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) 
} // Get cluster configuration func (k *K8sTool) handleGetClusterConfiguration(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.runKubectlCommand(ctx, []string{"config", "view"}) + return k.runKubectlCommand(ctx, "config", "view", "-o", "json") } // Remove annotation @@ -280,7 +352,7 @@ func (k *K8sTool) handleRemoveAnnotation(ctx context.Context, request mcp.CallTo args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Remove label @@ -299,7 +371,7 @@ func (k *K8sTool) handleRemoveLabel(ctx context.Context, request mcp.CallToolReq args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Annotate resource @@ -320,7 +392,7 @@ func (k *K8sTool) handleAnnotateResource(ctx context.Context, request mcp.CallTo args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Label resource @@ -341,7 +413,7 @@ func (k *K8sTool) handleLabelResource(ctx context.Context, request mcp.CallToolR args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) } // Create resource from URL @@ -358,7 +430,7 @@ func (k *K8sTool) handleCreateResourceFromURL(ctx context.Context, request mcp.C args = append(args, "-n", namespace) } - return k.runKubectlCommand(ctx, args) + return k.runKubectlCommand(ctx, args...) 
} // Resource generation embeddings @@ -450,30 +522,27 @@ func (k *K8sTool) handleGenerateResource(ctx context.Context, request mcp.CallTo return mcp.NewToolResultError("empty response from model"), nil } c1 := choices[0] - return mcp.NewToolResultText(c1.Content), nil + responseText := c1.Content + + return mcp.NewToolResultText(responseText), nil } -// Helper function to run kubectl commands -func (k *K8sTool) runKubectlCommand(ctx context.Context, args []string) (*mcp.CallToolResult, error) { - if k.kubeconfig != "" { - args = append([]string{"--kubeconfig", k.kubeconfig}, args...) - } - result, err := utils.RunCommandWithContext(ctx, "kubectl", args) +// runKubectlCommand is a helper function to execute kubectl commands +func (k *K8sTool) runKubectlCommand(ctx context.Context, args ...string) (*mcp.CallToolResult, error) { + output, err := commands.NewCommandBuilder("kubectl"). + WithArgs(args...). + WithKubeconfig(k.kubeconfig). + Execute(ctx) + if err != nil { return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(result), nil + + return mcp.NewToolResultText(output), nil } // RegisterK8sTools registers all k8s tools with the MCP server -func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { - var llm llms.Model - if openAiClient, err := openai.New(); err == nil { - llm = openAiClient - } else { - logger.Get().Error(err, "Failed to initialize OpenAI LLM, k8s_generate_resource tool will not be available") - } - +func RegisterTools(s *server.MCPServer, llm llms.Model, kubeconfig string) { k8sTool := NewK8sToolWithConfig(kubeconfig, llm) s.AddTool(mcp.NewTool("k8s_get_resources", @@ -483,7 +552,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("namespace", mcp.Description("Namespace to query (optional)")), mcp.WithString("all_namespaces", mcp.Description("Query all namespaces (true/false)")), mcp.WithString("output", mcp.Description("Output format (json, yaml, wide, etc.)")), - ), 
k8sTool.handleKubectlGetEnhanced) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resources", k8sTool.handleKubectlGetEnhanced))) s.AddTool(mcp.NewTool("k8s_get_pod_logs", mcp.WithDescription("Get logs from a Kubernetes pod"), @@ -491,14 +560,14 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")), mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")), mcp.WithNumber("tail_lines", mcp.Description("Number of lines to show from the end (default: 50)")), - ), k8sTool.handleKubectlLogsEnhanced) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_pod_logs", k8sTool.handleKubectlLogsEnhanced))) s.AddTool(mcp.NewTool("k8s_scale", mcp.WithDescription("Scale a Kubernetes deployment"), mcp.WithString("name", mcp.Description("Name of the deployment"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the deployment (default: default)")), mcp.WithNumber("replicas", mcp.Description("Number of replicas"), mcp.Required()), - ), k8sTool.handleScaleDeployment) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_scale", k8sTool.handleScaleDeployment))) s.AddTool(mcp.NewTool("k8s_patch_resource", mcp.WithDescription("Patch a Kubernetes resource using strategic merge patch"), @@ -506,45 +575,46 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), mcp.WithString("patch", mcp.Description("JSON patch to apply"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")), - ), k8sTool.handlePatchResource) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_patch_resource", k8sTool.handlePatchResource))) s.AddTool(mcp.NewTool("k8s_apply_manifest", mcp.WithDescription("Apply a YAML manifest to the Kubernetes cluster"), 
mcp.WithString("manifest", mcp.Description("YAML manifest content"), mcp.Required()), - ), k8sTool.handleApplyManifest) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_apply_manifest", k8sTool.handleApplyManifest))) s.AddTool(mcp.NewTool("k8s_delete_resource", mcp.WithDescription("Delete a Kubernetes resource"), mcp.WithString("resource_type", mcp.Description("Type of resource (pod, service, deployment, etc.)"), mcp.Required()), mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")), - ), k8sTool.handleDeleteResource) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_delete_resource", k8sTool.handleDeleteResource))) s.AddTool(mcp.NewTool("k8s_check_service_connectivity", mcp.WithDescription("Check connectivity to a service using a temporary curl pod"), mcp.WithString("service_name", mcp.Description("Service name to test (e.g., my-service.my-namespace.svc.cluster.local:80)"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace to run the check from (default: default)")), - ), k8sTool.handleCheckServiceConnectivity) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_check_service_connectivity", k8sTool.handleCheckServiceConnectivity))) s.AddTool(mcp.NewTool("k8s_get_events", - mcp.WithDescription("Get Kubernetes cluster events"), - mcp.WithString("namespace", mcp.Description("Namespace to query events from (optional, default: all namespaces)")), - ), k8sTool.handleGetEvents) + mcp.WithDescription("Get events from a Kubernetes namespace"), + mcp.WithString("namespace", mcp.Description("Namespace to get events from (default: default)")), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_events", k8sTool.handleGetEvents))) s.AddTool(mcp.NewTool("k8s_execute_command", - mcp.WithDescription("Execute a command inside a Kubernetes pod"), - mcp.WithString("pod_name", mcp.Description("Name 
of the pod"), mcp.Required()), + mcp.WithDescription("Execute a command in a Kubernetes pod"), + mcp.WithString("pod_name", mcp.Description("Name of the pod to execute in"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")), + mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")), mcp.WithString("command", mcp.Description("Command to execute"), mcp.Required()), - ), k8sTool.handleExecCommand) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_execute_command", k8sTool.handleExecCommand))) s.AddTool(mcp.NewTool("k8s_get_available_api_resources", - mcp.WithDescription("Get all available API resources from the Kubernetes cluster"), - ), k8sTool.handleGetAvailableAPIResources) + mcp.WithDescription("Get available Kubernetes API resources"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_available_api_resources", k8sTool.handleGetAvailableAPIResources))) s.AddTool(mcp.NewTool("k8s_get_cluster_configuration", - mcp.WithDescription("Get the current kubectl cluster configuration"), - ), k8sTool.handleGetClusterConfiguration) + mcp.WithDescription("Get cluster configuration details"), + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_cluster_configuration", k8sTool.handleGetClusterConfiguration))) s.AddTool(mcp.NewTool("k8s_rollout", mcp.WithDescription("Perform rollout operations on Kubernetes resources (history, pause, restart, resume, status, undo)"), @@ -552,7 +622,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_type", mcp.Description("The type of resource to rollout (e.g., deployment)"), mcp.Required()), mcp.WithString("resource_name", mcp.Description("The name of the resource to rollout"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), k8sTool.handleRollout) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_rollout", 
k8sTool.handleRollout))) s.AddTool(mcp.NewTool("k8s_label_resource", mcp.WithDescription("Add or update labels on a Kubernetes resource"), @@ -560,7 +630,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), mcp.WithString("labels", mcp.Description("Space-separated key=value pairs for labels"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), k8sTool.handleLabelResource) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_label_resource", k8sTool.handleLabelResource))) s.AddTool(mcp.NewTool("k8s_annotate_resource", mcp.WithDescription("Add or update annotations on a Kubernetes resource"), @@ -568,7 +638,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), mcp.WithString("annotations", mcp.Description("Space-separated key=value pairs for annotations"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), k8sTool.handleAnnotateResource) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_annotate_resource", k8sTool.handleAnnotateResource))) s.AddTool(mcp.NewTool("k8s_remove_annotation", mcp.WithDescription("Remove an annotation from a Kubernetes resource"), @@ -576,7 +646,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), mcp.WithString("annotation_key", mcp.Description("The key of the annotation to remove"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), k8sTool.handleRemoveAnnotation) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_annotation", k8sTool.handleRemoveAnnotation))) s.AddTool(mcp.NewTool("k8s_remove_label", mcp.WithDescription("Remove a label 
from a Kubernetes resource"), @@ -584,12 +654,12 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), mcp.WithString("label_key", mcp.Description("The key of the label to remove"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), k8sTool.handleRemoveLabel) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_label", k8sTool.handleRemoveLabel))) s.AddTool(mcp.NewTool("k8s_create_resource", mcp.WithDescription("Create a Kubernetes resource from YAML content"), mcp.WithString("yaml_content", mcp.Description("YAML content of the resource"), mcp.Required()), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { yamlContent := mcp.ParseString(request, "yaml_content", "") if yamlContent == "" { @@ -608,26 +678,26 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { } tmpFile.Close() - result, err := utils.RunCommandWithContext(ctx, "kubectl", []string{"create", "-f", tmpFile.Name()}) + result, err := k8sTool.runKubectlCommand(ctx, "create", "-f", tmpFile.Name()) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil } - return mcp.NewToolResultText(result), nil - }) + return result, nil + }))) s.AddTool(mcp.NewTool("k8s_create_resource_from_url", mcp.WithDescription("Create a Kubernetes resource from a URL pointing to a YAML manifest"), mcp.WithString("url", mcp.Description("The URL of the manifest"), mcp.Required()), mcp.WithString("namespace", mcp.Description("The namespace to create the resource in")), - ), k8sTool.handleCreateResourceFromURL) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource_from_url", 
k8sTool.handleCreateResourceFromURL))) s.AddTool(mcp.NewTool("k8s_get_resource_yaml", mcp.WithDescription("Get the YAML representation of a Kubernetes resource"), mcp.WithString("resource_type", mcp.Description("Type of resource"), mcp.Required()), mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resource_yaml", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { resourceType := mcp.ParseString(request, "resource_type", "") resourceName := mcp.ParseString(request, "resource_name", "") namespace := mcp.ParseString(request, "namespace", "") @@ -641,24 +711,24 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) { args = append(args, "-n", namespace) } - result, err := utils.RunCommandWithContext(ctx, "kubectl", args) + result, err := k8sTool.runKubectlCommand(ctx, args...) 
if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil } - return mcp.NewToolResultText(result), nil - }) + return result, nil + }))) s.AddTool(mcp.NewTool("k8s_describe_resource", mcp.WithDescription("Describe a Kubernetes resource in detail"), mcp.WithString("resource_type", mcp.Description("Type of resource (deployment, service, pod, node, etc.)"), mcp.Required()), mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")), - ), k8sTool.handleKubectlDescribeTool) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_describe_resource", k8sTool.handleKubectlDescribeTool))) s.AddTool(mcp.NewTool("k8s_generate_resource", mcp.WithDescription("Generate a Kubernetes resource YAML from a description"), mcp.WithString("resource_description", mcp.Description("Detailed description of the resource to generate"), mcp.Required()), mcp.WithString("resource_type", mcp.Description(fmt.Sprintf("Type of resource to generate (%s)", strings.Join(slices.Collect(resourceTypes), ", "))), mcp.Required()), - ), k8sTool.handleGenerateResource) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_generate_resource", k8sTool.handleGenerateResource))) } diff --git a/pkg/k8s/k8s_test.go b/pkg/k8s/k8s_test.go index 240e7ac..a71e10f 100644 --- a/pkg/k8s/k8s_test.go +++ b/pkg/k8s/k8s_test.go @@ -2,11 +2,9 @@ package k8s import ( "context" - "fmt" - "os" "testing" - "github.com/kagent-dev/tools/pkg/utils" + "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,10 +36,10 @@ func TestHandleGetAvailableAPIResources(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `[{"name": "pods", "singularName": 
"pod", "namespaced": true, "kind": "Pod"}]` mock.AddCommandString("kubectl", []string{"api-resources", "-o", "json"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -57,9 +55,9 @@ func TestHandleGetAvailableAPIResources(t *testing.T) { }) t.Run("kubectl command failure", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() mock.AddCommandString("kubectl", []string{"api-resources", "-o", "json"}, "", assert.AnError) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -75,10 +73,10 @@ func TestHandleScaleDeployment(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment scaled` mock.AddCommandString("kubectl", []string{"scale", "deployment", "test-deployment", "--replicas", "5", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -99,8 +97,8 @@ func TestHandleScaleDeployment(t *testing.T) { }) t.Run("missing name parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -122,18 +120,16 @@ func TestHandleScaleDeployment(t *testing.T) { }) t.Run("missing replicas parameter uses default", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment scaled` - // Default replicas is 1 mock.AddCommandString("kubectl", []string{"scale", "deployment", "test-deployment", "--replicas", "1", "-n", "default"}, expectedOutput, nil) - ctx := 
utils.WithShellExecutor(context.Background(), mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() req := mcp.CallToolRequest{} req.Params.Arguments = map[string]interface{}{ "name": "test-deployment", - // Missing replicas parameter - should use default value of 1 } result, err := k8sTool.handleScaleDeployment(ctx, req) @@ -156,10 +152,10 @@ func TestHandleGetEvents(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `{"items": [{"metadata": {"name": "test-event"}, "message": "Test event message"}]}` mock.AddCommandString("kubectl", []string{"get", "events", "-o", "json", "--all-namespaces"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -174,10 +170,10 @@ func TestHandleGetEvents(t *testing.T) { }) t.Run("with namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `{"items": []}` mock.AddCommandString("kubectl", []string{"get", "events", "-o", "json", "-n", "custom-namespace"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -197,8 +193,8 @@ func TestHandlePatchResource(t *testing.T) { ctx := context.Background() t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -219,10 +215,10 @@ func TestHandlePatchResource(t *testing.T) { }) t.Run("valid parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment patched` mock.AddCommandString("kubectl", 
[]string{"patch", "deployment", "test-deployment", "-p", `{"spec":{"replicas":5}}`, "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -247,8 +243,8 @@ func TestHandleDeleteResource(t *testing.T) { ctx := context.Background() t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -269,17 +265,17 @@ func TestHandleDeleteResource(t *testing.T) { }) t.Run("valid parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `pod "test-pod" deleted` - mock.AddCommandString("kubectl", []string{"delete", "pod", "test-pod", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock := cmd.NewMockShellExecutor() + expectedOutput := `deployment.apps/test-deployment deleted` + mock.AddCommandString("kubectl", []string{"delete", "deployment", "test-deployment", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() req := mcp.CallToolRequest{} req.Params.Arguments = map[string]interface{}{ - "resource_type": "pod", - "resource_name": "test-pod", + "resource_type": "deployment", + "resource_name": "test-deployment", } result, err := k8sTool.handleDeleteResource(ctx, req) @@ -296,8 +292,8 @@ func TestHandleCheckServiceConnectivity(t *testing.T) { ctx := context.Background() t.Run("missing service_name", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -315,15 +311,15 @@ func TestHandleCheckServiceConnectivity(t *testing.T) { }) 
t.Run("valid service_name", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() // Mock the pod creation, wait, and exec commands using partial matchers mock.AddPartialMatcherString("kubectl", []string{"run", "*", "--image=curlimages/curl", "-n", "default", "--restart=Never", "--", "sleep", "3600"}, "pod/curl-test-123 created", nil) - mock.AddPartialMatcherString("kubectl", []string{"wait", "--for=condition=ready", "*", "-n", "default", "--timeout=60s"}, "pod/curl-test-123 condition met", nil) + mock.AddPartialMatcherString("kubectl", []string{"wait", "--for=condition=ready", "*", "-n", "default", "--timeout=60s", "--timeout", "30s"}, "pod/curl-test-123 condition met", nil) mock.AddPartialMatcherString("kubectl", []string{"exec", "*", "-n", "default", "--", "curl", "-s", "test-service.default.svc.cluster.local:80"}, "Connection successful", nil) - mock.AddPartialMatcherString("kubectl", []string{"delete", "pod", "*", "-n", "default", "--ignore-not-found"}, "pod deleted", nil) + mock.AddPartialMatcherString("kubectl", []string{"delete", "pod", "*", "-n", "default", "--ignore-not-found", "--timeout", "30s"}, "pod deleted", nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -343,8 +339,8 @@ func TestHandleKubectlDescribeTool(t *testing.T) { ctx := context.Background() t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -365,12 +361,12 @@ func TestHandleKubectlDescribeTool(t *testing.T) { }) t.Run("valid parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `Name: test-deployment Namespace: default Labels: app=test` mock.AddCommandString("kubectl", 
[]string{"describe", "deployment", "test-deployment", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -395,8 +391,8 @@ func TestHandleKubectlGetEnhanced(t *testing.T) { ctx := context.Background() t.Run("missing resource_type", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() req := mcp.CallToolRequest{} @@ -411,10 +407,10 @@ func TestHandleKubectlGetEnhanced(t *testing.T) { }) t.Run("valid resource_type", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `{"items": [{"metadata": {"name": "pod1"}}]}` mock.AddCommandString("kubectl", []string{"get", "pods", "-o", "json"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() req := mcp.CallToolRequest{} @@ -430,8 +426,8 @@ func TestHandleKubectlLogsEnhanced(t *testing.T) { ctx := context.Background() t.Run("missing pod_name", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() req := mcp.CallToolRequest{} @@ -446,11 +442,11 @@ func TestHandleKubectlLogsEnhanced(t *testing.T) { }) t.Run("valid pod_name", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `log line 1 log line 2` mock.AddCommandString("kubectl", []string{"logs", "test-pod", "-n", "default", "--tail", "50"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() req := 
mcp.CallToolRequest{} @@ -463,8 +459,9 @@ log line 2` } func TestHandleApplyManifest(t *testing.T) { + ctx := context.Background() t.Run("apply manifest from string", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() manifest := `apiVersion: v1 kind: Pod metadata: @@ -476,8 +473,8 @@ spec: expectedOutput := `pod/test-pod created` // Use partial matcher to handle dynamic temp file names - mock.AddPartialMatcherString("kubectl", []string{"apply", "-f", "*"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + mock.AddPartialMatcherString("kubectl", []string{"apply", "-f"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -507,8 +504,8 @@ spec: }) t.Run("missing manifest parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -530,15 +527,16 @@ spec: } func TestHandleExecCommand(t *testing.T) { + ctx := context.Background() t.Run("exec command in pod", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `total 8 drwxr-xr-x 1 root root 4096 Jan 1 12:00 . 
drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..` // The implementation passes the command as a single string after -- mock.AddCommandString("kubectl", []string{"exec", "mypod", "-n", "default", "--", "ls -la"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -566,8 +564,8 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..` }) t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -590,12 +588,13 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..` } func TestHandleRollout(t *testing.T) { + ctx := context.Background() t.Run("rollout restart deployment", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/myapp restarted` mock.AddCommandString("kubectl", []string{"rollout", "restart", "deployment/myapp", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(context.Background(), mock) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -624,8 +623,8 @@ func TestHandleRollout(t *testing.T) { }) t.Run("missing required parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -773,10 +772,10 @@ func TestHandleAnnotateResource(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment annotated` - mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", 
"key1=value1", "key2=value2", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1=value1", "key2=value2", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -798,8 +797,8 @@ func TestHandleAnnotateResource(t *testing.T) { }) t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -825,10 +824,10 @@ func TestHandleLabelResource(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment labeled` - mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env=prod", "version=1.0", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env=prod", "version=1.0", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -850,8 +849,8 @@ func TestHandleLabelResource(t *testing.T) { }) t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -877,10 +876,10 @@ func TestHandleRemoveAnnotation(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := 
`deployment.apps/test-deployment annotated` - mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1-", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1-", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -902,8 +901,8 @@ func TestHandleRemoveAnnotation(t *testing.T) { }) t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -929,10 +928,10 @@ func TestHandleRemoveLabel(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment labeled` - mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env-", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env-", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -954,8 +953,8 @@ func TestHandleRemoveLabel(t *testing.T) { }) t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -981,10 +980,10 @@ func TestHandleCreateResourceFromURL(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := 
cmd.NewMockShellExecutor() expectedOutput := `deployment.apps/test-deployment created` - mock.AddCommandString("kubectl", []string{"create", "-f", "https://example.com/manifest.yaml", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"create", "-f", "https://example.com/manifest.yaml", "-n", "default", "--timeout", "30s"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -1004,8 +1003,8 @@ func TestHandleCreateResourceFromURL(t *testing.T) { }) t.Run("missing url parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() @@ -1030,7 +1029,7 @@ func TestHandleGetClusterConfiguration(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() + mock := cmd.NewMockShellExecutor() expectedOutput := `apiVersion: v1 clusters: - cluster: @@ -1046,8 +1045,8 @@ kind: Config preferences: {} users: - name: default` - mock.AddCommandString("kubectl", []string{"config", "view"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) + mock.AddCommandString("kubectl", []string{"config", "view", "-o", "json"}, expectedOutput, nil) + ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() @@ -1062,258 +1061,3 @@ users: assert.Contains(t, resultText, "clusters") }) } - -// Test the k8s_create_resource handler (inline function in RegisterK8sTools) -func TestHandleCreateResource(t *testing.T) { - ctx := context.Background() - - t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - yamlContent := `apiVersion: v1 -kind: Pod -metadata: - name: test-pod -spec: - containers: - - name: test - image: nginx` - - expectedOutput := `pod/test-pod created` - // Use partial 
matcher to handle dynamic temp file names - mock.AddPartialMatcherString("kubectl", []string{"create", "-f", "*"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) - - // We need to test the inline function from RegisterK8sTools - // Let's create a test handler that mimics the inline function - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - yamlContent := mcp.ParseString(request, "yaml_content", "") - - if yamlContent == "" { - return mcp.NewToolResultError("yaml_content is required"), nil - } - - // Create temporary file - tmpFile, err := os.CreateTemp("", "k8s-resource-*.yaml") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil - } - defer os.Remove(tmpFile.Name()) - - if _, err := tmpFile.WriteString(yamlContent); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil - } - tmpFile.Close() - - result, err := utils.RunCommandWithContext(ctx, "kubectl", []string{"create", "-f", tmpFile.Name()}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil - } - - return mcp.NewToolResultText(result), nil - } - - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "yaml_content": yamlContent, - } - - result, err := testHandler(ctx, req) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - - // Verify the expected output - content := getResultText(result) - assert.Contains(t, content, "created") - - // Verify kubectl create was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "kubectl", callLog[0].Command) - assert.Len(t, callLog[0].Args, 3) // create, -f, - assert.Equal(t, "create", callLog[0].Args[0]) - assert.Equal(t, "-f", callLog[0].Args[1]) - // Third argument should be the temporary file path - assert.Contains(t, callLog[0].Args[2], "k8s-resource-") 
- }) - - t.Run("missing yaml_content parameter", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - // Test handler for missing parameter - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - yamlContent := mcp.ParseString(request, "yaml_content", "") - - if yamlContent == "" { - return mcp.NewToolResultError("yaml_content is required"), nil - } - return mcp.NewToolResultText("should not reach here"), nil - } - - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - // Missing yaml_content parameter - } - - result, err := testHandler(ctx, req) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "yaml_content is required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) -} - -// Test the k8s_get_resource_yaml handler (inline function in RegisterK8sTools) -func TestHandleGetResourceYAML(t *testing.T) { - ctx := context.Background() - - t.Run("success", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `apiVersion: v1 -kind: Pod -metadata: - name: test-pod - namespace: default -spec: - containers: - - name: test - image: nginx` - mock.AddCommandString("kubectl", []string{"get", "pod", "test-pod", "-o", "yaml", "-n", "default"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) - - // Test handler that mimics the inline function - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") - - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name are required"), nil - 
} - - args := []string{"get", resourceType, resourceName, "-o", "yaml"} - if namespace != "" { - args = append(args, "-n", namespace) - } - - result, err := utils.RunCommandWithContext(ctx, "kubectl", args) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil - } - - return mcp.NewToolResultText(result), nil - } - - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "pod", - "resource_name": "test-pod", - "namespace": "default", - } - - result, err := testHandler(ctx, req) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - - resultText := getResultText(result) - assert.Contains(t, resultText, "test-pod") - assert.Contains(t, resultText, "apiVersion") - - // Verify the correct kubectl command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"get", "pod", "test-pod", "-o", "yaml", "-n", "default"}, callLog[0].Args) - }) - - t.Run("missing parameters", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - ctx := utils.WithShellExecutor(context.Background(), mock) - - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name are required"), nil - } - return mcp.NewToolResultText("should not reach here"), nil - } - - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "pod", - // Missing resource_name - } - - result, err := testHandler(ctx, req) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "resource_type and resource_name are 
required") - - // Verify no commands were executed - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) - - t.Run("without namespace", func(t *testing.T) { - mock := utils.NewMockShellExecutor() - expectedOutput := `apiVersion: v1 -kind: ClusterRole -metadata: - name: test-cluster-role` - mock.AddCommandString("kubectl", []string{"get", "clusterrole", "test-cluster-role", "-o", "yaml"}, expectedOutput, nil) - ctx := utils.WithShellExecutor(ctx, mock) - - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") - - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name are required"), nil - } - - args := []string{"get", resourceType, resourceName, "-o", "yaml"} - if namespace != "" { - args = append(args, "-n", namespace) - } - - result, err := utils.RunCommandWithContext(ctx, "kubectl", args) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil - } - - return mcp.NewToolResultText(result), nil - } - - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "clusterrole", - "resource_name": "test-cluster-role", - // No namespace for cluster-scoped resource - } - - result, err := testHandler(ctx, req) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - - resultText := getResultText(result) - assert.Contains(t, resultText, "test-cluster-role") - assert.Contains(t, resultText, "ClusterRole") - - // Verify the correct kubectl command was called (without namespace) - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"get", "clusterrole", "test-cluster-role", "-o", "yaml"}, 
callLog[0].Args) - }) -} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go deleted file mode 100644 index 062997d..0000000 --- a/pkg/logger/logger.go +++ /dev/null @@ -1,58 +0,0 @@ -package logger - -import ( - "github.com/go-logr/logr" - "github.com/go-logr/stdr" -) - -var globalLogger logr.Logger - -// Init initializes the global logger with appropriate configuration -func Init() { - // Set log level from environment variable (not directly supported by stdr, but can be extended) - // For now, just use stdr with default settings - globalLogger = stdr.New(nil) -} - -// Get returns the global logger instance -func Get() logr.Logger { - if globalLogger.GetSink() == nil { - Init() - } - return globalLogger -} - -// LogExecCommand logs information about an exec command being executed -func LogExecCommand(command string, args []string, caller string) { - logger := Get() - logger.Info("executing command", - "command", command, - "args", args, - "caller", caller, - ) -} - -// LogExecCommandResult logs the result of an exec command -func LogExecCommandResult(command string, args []string, output string, err error, duration float64, caller string) { - logger := Get() - - if err != nil { - logger.Error(err, "command execution failed", - "command", command, - "args", args, - "duration_seconds", duration, - "caller", caller, - ) - } else { - logger.Info("command execution successful", - "command", command, - "args", args, - "output", output, - "duration_seconds", duration, - "caller", caller, - ) - } -} - -// Sync is a no-op for logr/stdr -func Sync() {} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go deleted file mode 100644 index 24903d0..0000000 --- a/pkg/logger/logger_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package logger - -import ( - "os" - "testing" - - "github.com/go-logr/logr" - "github.com/stretchr/testify/assert" -) - -func TestInit(t *testing.T) { - // Test initialization - Init() - assert.NotNil(t, globalLogger) -} - -func TestGet(t 
*testing.T) { - // Reset global logger - globalLogger = logr.Logger{} - - // Test Get without Init - logger := Get() - assert.NotNil(t, logger) - assert.NotNil(t, globalLogger) -} - -func TestLogExecCommand(t *testing.T) { - // Just test that it does not panic and logs - assert.NotPanics(t, func() { - LogExecCommand("test-command", []string{"arg1", "arg2"}, "test.go:123") - }) -} - -func TestLogExecCommandResult(t *testing.T) { - // Test successful command - assert.NotPanics(t, func() { - LogExecCommandResult("test-command", []string{"arg1"}, "success output", nil, 1.5, "test.go:123") - }) - // Test failed command - assert.NotPanics(t, func() { - LogExecCommandResult("test-command", []string{"arg1"}, "error output", assert.AnError, 0.5, "test.go:123") - }) -} - -func TestEnvironmentVariables(t *testing.T) { - // Test log level from environment (no-op for stdr) - os.Setenv("KAGENT_LOG_LEVEL", "debug") - defer os.Unsetenv("KAGENT_LOG_LEVEL") - - // Reset global logger - globalLogger = logr.Logger{} - - // Initialize with environment variable - Init() - - // Just check logger is set - assert.NotNil(t, globalLogger) -} - -func TestDevelopmentMode(t *testing.T) { - // Test development mode (no-op for stdr) - os.Setenv("KAGENT_ENV", "development") - defer os.Unsetenv("KAGENT_ENV") - - // Reset global logger - globalLogger = logr.Logger{} - - // Initialize in development mode - Init() - - // In development mode, the logger should be configured (no panic) - assert.NotNil(t, globalLogger) -} - -func TestSync(t *testing.T) { - // Test Sync function - Init() - - // Sync should not panic - assert.NotPanics(t, func() { - Sync() - }) -} diff --git a/pkg/prometheus/prometheus.go b/pkg/prometheus/prometheus.go index dc4d6cb..1239305 100644 --- a/pkg/prometheus/prometheus.go +++ b/pkg/prometheus/prometheus.go @@ -9,6 +9,9 @@ import ( "net/url" "time" + "github.com/kagent-dev/tools/internal/errors" + "github.com/kagent-dev/tools/internal/security" + 
"github.com/kagent-dev/tools/internal/telemetry" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -33,6 +36,16 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError("query parameter is required"), nil } + // Validate prometheus URL + if err := security.ValidateURL(prometheusURL); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + } + + // Validate PromQL query + if err := security.ValidatePromQLQuery(query); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil + } + // Make request to Prometheus API apiURL := fmt.Sprintf("%s/api/v1/query", prometheusURL) params := url.Values{} @@ -42,19 +55,40 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode()) client := getHTTPClient(ctx) - resp, err := client.Get(fullURL) + req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) if err != nil { - return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil + toolErr := errors.NewPrometheusError("create_request", err). + WithContext("prometheus_url", prometheusURL). + WithContext("query", query) + return toolErr.ToMCPResult(), nil + } + + resp, err := client.Do(req) + if err != nil { + toolErr := errors.NewPrometheusError("query_execution", err). + WithContext("prometheus_url", prometheusURL). + WithContext("query", query). + WithContext("api_url", apiURL) + return toolErr.ToMCPResult(), nil } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return mcp.NewToolResultError("failed to read response: " + err.Error()), nil + toolErr := errors.NewPrometheusError("read_response", err). + WithContext("prometheus_url", prometheusURL). + WithContext("query", query). 
+ WithContext("status_code", resp.StatusCode) + return toolErr.ToMCPResult(), nil } if resp.StatusCode != http.StatusOK { - return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil + toolErr := errors.NewPrometheusError("api_error", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))). + WithContext("prometheus_url", prometheusURL). + WithContext("query", query). + WithContext("status_code", resp.StatusCode). + WithContext("response_body", string(body)) + return toolErr.ToMCPResult(), nil } // Parse the JSON response to pretty-print it @@ -82,6 +116,33 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq return mcp.NewToolResultError("query parameter is required"), nil } + // Validate prometheus URL + if err := security.ValidateURL(prometheusURL); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + } + + // Validate PromQL query + if err := security.ValidatePromQLQuery(query); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil + } + + // Validate time parameters if provided + if start != "" { + if err := security.ValidateCommandInput(start); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid start time: %v", err)), nil + } + } + if end != "" { + if err := security.ValidateCommandInput(end); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid end time: %v", err)), nil + } + } + if step != "" { + if err := security.ValidateCommandInput(step); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid step parameter: %v", err)), nil + } + } + // Use default time range if not specified if start == "" { start = fmt.Sprintf("%d", time.Now().Add(-1*time.Hour).Unix()) @@ -101,7 +162,12 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode()) client := getHTTPClient(ctx) - 
resp, err := client.Get(fullURL) + req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) + if err != nil { + return mcp.NewToolResultError("failed to create request: " + err.Error()), nil + } + + resp, err := client.Do(req) if err != nil { return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil } @@ -133,23 +199,48 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") + // Validate prometheus URL + if err := security.ValidateURL(prometheusURL); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + } + // Make request to Prometheus API for labels apiURL := fmt.Sprintf("%s/api/v1/labels", prometheusURL) client := getHTTPClient(ctx) - resp, err := client.Get(apiURL) + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) if err != nil { - return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil + toolErr := errors.NewPrometheusError("create_request", err). + WithContext("prometheus_url", prometheusURL). + WithContext("api_url", apiURL) + return toolErr.ToMCPResult(), nil + } + + resp, err := client.Do(req) + if err != nil { + toolErr := errors.NewPrometheusError("query_execution", err). + WithContext("prometheus_url", prometheusURL). + WithContext("api_url", apiURL) + return toolErr.ToMCPResult(), nil } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return mcp.NewToolResultError("failed to read response: " + err.Error()), nil + toolErr := errors.NewPrometheusError("read_response", err). + WithContext("prometheus_url", prometheusURL). + WithContext("api_url", apiURL). 
+ WithContext("status_code", resp.StatusCode) + return toolErr.ToMCPResult(), nil } if resp.StatusCode != http.StatusOK { - return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil + toolErr := errors.NewPrometheusError("api_error", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))). + WithContext("prometheus_url", prometheusURL). + WithContext("api_url", apiURL). + WithContext("status_code", resp.StatusCode). + WithContext("response_body", string(body)) + return toolErr.ToMCPResult(), nil } // Parse the JSON response to pretty-print it @@ -169,11 +260,21 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") + // Validate prometheus URL + if err := security.ValidateURL(prometheusURL); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + } + // Make request to Prometheus API for targets apiURL := fmt.Sprintf("%s/api/v1/targets", prometheusURL) client := getHTTPClient(ctx) - resp, err := client.Get(apiURL) + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return mcp.NewToolResultError("failed to create request: " + err.Error()), nil + } + + resp, err := client.Do(req) if err != nil { return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil } @@ -202,12 +303,12 @@ func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolR return mcp.NewToolResultText(string(prettyJSON)), nil } -func RegisterPrometheusTools(s *server.MCPServer, kubeconfig string) { +func RegisterTools(s *server.MCPServer) { s.AddTool(mcp.NewTool("prometheus_query_tool", mcp.WithDescription("Execute a PromQL query against Prometheus"), mcp.WithString("query", mcp.Description("PromQL 
query to execute"), mcp.Required()), mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), handlePrometheusQueryTool) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_tool", handlePrometheusQueryTool))) s.AddTool(mcp.NewTool("prometheus_query_range_tool", mcp.WithDescription("Execute a PromQL range query against Prometheus"), @@ -216,20 +317,20 @@ func RegisterPrometheusTools(s *server.MCPServer, kubeconfig string) { mcp.WithString("end", mcp.Description("End time (Unix timestamp or relative time)")), mcp.WithString("step", mcp.Description("Query resolution step (default: 15s)")), mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), handlePrometheusRangeQueryTool) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_range_tool", handlePrometheusRangeQueryTool))) s.AddTool(mcp.NewTool("prometheus_label_names_tool", mcp.WithDescription("Get all available labels from Prometheus"), mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), handlePrometheusLabelsQueryTool) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_label_names_tool", handlePrometheusLabelsQueryTool))) s.AddTool(mcp.NewTool("prometheus_targets_tool", mcp.WithDescription("Get all Prometheus targets and their status"), mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), handlePrometheusTargetsQueryTool) + ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_targets_tool", handlePrometheusTargetsQueryTool))) s.AddTool(mcp.NewTool("prometheus_promql_tool", mcp.WithDescription("Generate a PromQL query"), mcp.WithString("query_description", mcp.Description("A string describing the query to generate"), mcp.Required()), - ), handlePromql) + ), 
telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_promql_tool", handlePromql))) } diff --git a/pkg/prometheus/prometheus_test.go b/pkg/prometheus/prometheus_test.go index d51b52e..647d1f3 100644 --- a/pkg/prometheus/prometheus_test.go +++ b/pkg/prometheus/prometheus_test.go @@ -122,7 +122,7 @@ func TestHandlePrometheusQueryTool(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "failed to query Prometheus") + assert.Contains(t, getResultText(result), "**Prometheus Error**") }) t.Run("HTTP 500 error", func(t *testing.T) { @@ -139,7 +139,7 @@ func TestHandlePrometheusQueryTool(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "Prometheus API error (500)") + assert.Contains(t, getResultText(result), "**Prometheus Error**") }) t.Run("malformed JSON response", func(t *testing.T) { @@ -283,7 +283,7 @@ func TestHandlePrometheusLabelsQueryTool(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, getResultText(result), "failed to query Prometheus") + assert.Contains(t, getResultText(result), "**Prometheus Error**") }) t.Run("custom prometheus URL", func(t *testing.T) { diff --git a/pkg/utils/common.go b/pkg/utils/common.go index d8be795..ce8b73b 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -3,314 +3,49 @@ package utils import ( "context" "fmt" - "os/exec" - "runtime" "strings" + "sync" "time" - "github.com/kagent-dev/tools/pkg/logger" + "github.com/kagent-dev/tools/internal/commands" + "github.com/kagent-dev/tools/internal/logger" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" ) -// ShellExecutor defines the interface for executing shell commands 
-type ShellExecutor interface { - Exec(ctx context.Context, command string, args ...string) (output []byte, err error) +// KubeConfigManager manages kubeconfig path with thread safety +type KubeConfigManager struct { + mu sync.RWMutex + kubeconfigPath string } -// DefaultShellExecutor implements ShellExecutor using os/exec -type DefaultShellExecutor struct{} +// globalKubeConfigManager is the singleton instance +var globalKubeConfigManager = &KubeConfigManager{} -// Exec executes a command using os/exec.CommandContext -func (e *DefaultShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) { - cmd := exec.CommandContext(ctx, command, args...) - return cmd.CombinedOutput() -} - -// MockShellExecutor implements ShellExecutor for testing -type MockShellExecutor struct { - // Commands maps command+args to expected output and error - Commands map[string]MockCommandResult - // CallLog keeps track of all executed commands for verification - CallLog []MockCommandCall - // PartialMatchers allows partial matching for dynamic arguments - PartialMatchers []PartialMatcher -} - -// PartialMatcher represents a partial command matcher for dynamic arguments -type PartialMatcher struct { - Command string - Args []string // Use "*" for wildcard matching - Result MockCommandResult -} - -// MockCommandResult represents the expected result of a mocked command -type MockCommandResult struct { - Output []byte - Error error -} - -// MockCommandCall represents a logged command execution -type MockCommandCall struct { - Command string - Args []string -} - -// Exec executes a mocked command -func (m *MockShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) { - // Log the call - m.CallLog = append(m.CallLog, MockCommandCall{ - Command: command, - Args: args, - }) - - // Try exact match first - key := m.commandKey(command, args...) 
- if result, exists := m.Commands[key]; exists { - return result.Output, result.Error - } - - // Try partial matchers - for _, matcher := range m.PartialMatchers { - if m.matchesPartial(command, args, matcher) { - return matcher.Result.Output, matcher.Result.Error - } - } - - // Default behavior for unmocked commands - return []byte(""), fmt.Errorf("unmocked command: %s %v", command, args) -} - -// matchesPartial checks if a command matches a partial matcher -func (m *MockShellExecutor) matchesPartial(command string, args []string, matcher PartialMatcher) bool { - if command != matcher.Command { - return false - } - - if len(args) != len(matcher.Args) { - return false - } +// SetKubeconfig sets the global kubeconfig path in a thread-safe manner +func SetKubeconfig(path string) { + globalKubeConfigManager.mu.Lock() + defer globalKubeConfigManager.mu.Unlock() - for i, expectedArg := range matcher.Args { - if expectedArg == "*" { - continue // Wildcard match - } - if args[i] != expectedArg { - return false - } - } - - return true + globalKubeConfigManager.kubeconfigPath = path + logger.Get().Info("Setting shared kubeconfig", "path", path) } -// AddCommand adds a command mock -func (m *MockShellExecutor) AddCommand(command string, args []string, output []byte, err error) { - if m.Commands == nil { - m.Commands = make(map[string]MockCommandResult) - } - key := m.commandKey(command, args...) 
- m.Commands[key] = MockCommandResult{ - Output: output, - Error: err, - } -} +// GetKubeconfig returns the global kubeconfig path in a thread-safe manner +func GetKubeconfig() string { + globalKubeConfigManager.mu.RLock() + defer globalKubeConfigManager.mu.RUnlock() -// AddCommandString is a convenience method for adding string output -func (m *MockShellExecutor) AddCommandString(command string, args []string, output string, err error) { - m.AddCommand(command, args, []byte(output), err) + return globalKubeConfigManager.kubeconfigPath } -// AddPartialMatcher adds a partial matcher for dynamic arguments -func (m *MockShellExecutor) AddPartialMatcher(command string, args []string, output []byte, err error) { - if m.PartialMatchers == nil { - m.PartialMatchers = []PartialMatcher{} +// AddKubeconfigArgs adds kubeconfig arguments to command args if configured +func AddKubeconfigArgs(args []string) []string { + kubeconfigPath := GetKubeconfig() + if kubeconfigPath != "" { + return append([]string{"--kubeconfig", kubeconfigPath}, args...) 
} - m.PartialMatchers = append(m.PartialMatchers, PartialMatcher{ - Command: command, - Args: args, - Result: MockCommandResult{ - Output: output, - Error: err, - }, - }) -} - -// AddPartialMatcherString is a convenience method for adding string output with partial matching -func (m *MockShellExecutor) AddPartialMatcherString(command string, args []string, output string, err error) { - m.AddPartialMatcher(command, args, []byte(output), err) -} - -// GetCallLog returns the log of all command calls -func (m *MockShellExecutor) GetCallLog() []MockCommandCall { - return m.CallLog -} - -// Reset clears the mock state -func (m *MockShellExecutor) Reset() { - m.Commands = make(map[string]MockCommandResult) - m.CallLog = []MockCommandCall{} - m.PartialMatchers = []PartialMatcher{} -} - -// commandKey creates a unique key for command+args combination -func (m *MockShellExecutor) commandKey(command string, args ...string) string { - return fmt.Sprintf("%s %s", command, strings.Join(args, " ")) -} - -// Context key for shell executor injection -type contextKey string - -const shellExecutorKey contextKey = "shellExecutor" - -// WithShellExecutor returns a context with the given shell executor -func WithShellExecutor(ctx context.Context, executor ShellExecutor) context.Context { - return context.WithValue(ctx, shellExecutorKey, executor) -} - -// GetShellExecutor retrieves the shell executor from context, or returns default -func GetShellExecutor(ctx context.Context) ShellExecutor { - if executor, ok := ctx.Value(shellExecutorKey).(ShellExecutor); ok { - return executor - } - return &DefaultShellExecutor{} -} - -// NewMockShellExecutor creates a new mock shell executor for testing -func NewMockShellExecutor() *MockShellExecutor { - return &MockShellExecutor{ - Commands: make(map[string]MockCommandResult), - CallLog: []MockCommandCall{}, - PartialMatchers: []PartialMatcher{}, - } -} - -var ( - tracer = otel.Tracer("kagent-tools") - meter = otel.Meter("kagent-tools") - - // 
Metrics - commandExecutionCounter metric.Int64Counter - commandExecutionDuration metric.Float64Histogram - commandExecutionErrors metric.Int64Counter -) - -func init() { - // Initialize metrics (these are safe to call even if OTEL is not configured) - var err error - - commandExecutionCounter, err = meter.Int64Counter( - "command_executions_total", - metric.WithDescription("Total number of command executions"), - ) - if err != nil { - logger.Get().Error(err, "Failed to create command execution counter") - } - - commandExecutionDuration, err = meter.Float64Histogram( - "command_execution_duration_seconds", - metric.WithDescription("Duration of command executions in seconds"), - metric.WithUnit("s"), - ) - if err != nil { - logger.Get().Error(err, "Failed to create command execution duration histogram") - } - - commandExecutionErrors, err = meter.Int64Counter( - "command_execution_errors_total", - metric.WithDescription("Total number of command execution errors"), - ) - if err != nil { - logger.Get().Error(err, "Failed to create command execution errors counter") - } -} - -// RunCommand executes a command and returns output or error with OTEL tracing -func RunCommand(command string, args []string) (string, error) { - return RunCommandWithContext(context.Background(), command, args) -} - -// RunCommandWithContext executes a command with context and returns output or error with OTEL tracing -func RunCommandWithContext(ctx context.Context, command string, args []string) (string, error) { - // Get caller information for tracing - _, file, line, _ := runtime.Caller(1) - caller := fmt.Sprintf("%s:%d", file, line) - - // Start OpenTelemetry span - spanName := fmt.Sprintf("exec.%s", command) - ctx, span := tracer.Start(ctx, spanName) - defer span.End() - - // Set span attributes - span.SetAttributes( - attribute.String("command", command), - attribute.StringSlice("args", args), - attribute.String("caller", caller), - ) - - // Record metrics - startTime := time.Now() - - // 
Use the shell executor from context (or default) - executor := GetShellExecutor(ctx) - output, err := executor.Exec(ctx, command, args...) - - duration := time.Since(startTime) - - // Set additional span attributes with results - span.SetAttributes( - attribute.Float64("duration_seconds", duration.Seconds()), - attribute.Int("output_size", len(output)), - ) - - // Record metrics - attributes := []attribute.KeyValue{ - attribute.String("command", command), - attribute.Bool("success", err == nil), - } - - if commandExecutionCounter != nil { - commandExecutionCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) - } - - if commandExecutionDuration != nil { - commandExecutionDuration.Record(ctx, duration.Seconds(), metric.WithAttributes(attributes...)) - } - - if err != nil { - // Set span status and record error - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - span.SetAttributes(attribute.String("error", err.Error())) - - if commandExecutionErrors != nil { - commandExecutionErrors.Add(ctx, 1, metric.WithAttributes(attributes...)) - } - - logger.Get().Error(err, "CommandExec failed", - "command", command, - "args", args, - "duration", duration, - "caller", caller, - ) - return "", fmt.Errorf("command %s failed: %v", command, err) - } - - // Set successful span status - span.SetStatus(codes.Ok, "CommandExec") - - logger.Get().Info("CommandExec", - "command", command, - "args", args, - "duration", duration, - "outputSize", len(output), - "caller", caller, - ) - - return strings.TrimSpace(string(output)), nil + return args } // shellTool provides shell command execution functionality @@ -328,10 +63,21 @@ func shellTool(ctx context.Context, params shellParams) (string, error) { cmd := parts[0] args := parts[1:] - return RunCommandWithContext(ctx, cmd, args) + return commands.NewCommandBuilder(cmd).WithArgs(args...).Execute(ctx) +} + +// handleGetCurrentDateTimeTool provides datetime functionality for both MCP and testing +func 
handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Returns the current date and time in ISO 8601 format (RFC3339) + // This matches the Python implementation: datetime.datetime.now().isoformat() + now := time.Now() + return mcp.NewToolResultText(now.Format(time.RFC3339)), nil } -func RegisterCommonTools(s *server.MCPServer) { +func RegisterTools(s *server.MCPServer) { + logger.Get().Info("RegisterTools initialized") + + // Register shell tool s.AddTool(mcp.NewTool("shell", mcp.WithDescription("Execute shell commands"), mcp.WithString("command", mcp.Description("The shell command to execute"), mcp.Required()), @@ -350,5 +96,10 @@ func RegisterCommonTools(s *server.MCPServer) { return mcp.NewToolResultText(result), nil }) + // Register datetime tool + s.AddTool(mcp.NewTool("datetime_get_current_time", + mcp.WithDescription("Returns the current date and time in ISO 8601 format."), + ), handleGetCurrentDateTimeTool) + // Note: LLM Tool implementation would go here if needed } diff --git a/pkg/utils/common_test.go b/pkg/utils/common_test.go deleted file mode 100644 index e21cf76..0000000 --- a/pkg/utils/common_test.go +++ /dev/null @@ -1,288 +0,0 @@ -package utils - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDefaultShellExecutor(t *testing.T) { - executor := &DefaultShellExecutor{} - - // Test successful command - output, err := executor.Exec(context.Background(), "echo", "hello") - assert.NoError(t, err) - assert.Equal(t, "hello\n", string(output)) - - // Test command with error - output, err = executor.Exec(context.Background(), "nonexistent-command") - assert.Error(t, err) - assert.Empty(t, output) -} - -func TestMockShellExecutor(t *testing.T) { - mock := NewMockShellExecutor() - - t.Run("unmocked command returns error", func(t *testing.T) { - output, err := mock.Exec(context.Background(), "unmocked", 
"command") - assert.Error(t, err) - assert.Contains(t, err.Error(), "unmocked command") - assert.Empty(t, output) - }) - - t.Run("mocked command returns expected result", func(t *testing.T) { - expectedOutput := "mocked output" - mock.AddCommandString("kubectl", []string{"get", "pods"}, expectedOutput, nil) - - output, err := mock.Exec(context.Background(), "kubectl", "get", "pods") - assert.NoError(t, err) - assert.Equal(t, expectedOutput, string(output)) - }) - - t.Run("mocked command with error", func(t *testing.T) { - expectedError := errors.New("mocked error") - mock.AddCommandString("helm", []string{"install", "app"}, "", expectedError) - - output, err := mock.Exec(context.Background(), "helm", "install", "app") - assert.Error(t, err) - assert.Equal(t, expectedError, err) - assert.Empty(t, output) - }) - - t.Run("call log tracking", func(t *testing.T) { - mock.Reset() - - // Execute some commands - mock.AddCommandString("cmd1", []string{"arg1"}, "output1", nil) - mock.AddCommandString("cmd2", []string{"arg2", "arg3"}, "output2", nil) - - _, _ = mock.Exec(context.Background(), "cmd1", "arg1") - _, _ = mock.Exec(context.Background(), "cmd2", "arg2", "arg3") - _, _ = mock.Exec(context.Background(), "unmocked", "command") - - callLog := mock.GetCallLog() - require.Len(t, callLog, 3) - - assert.Equal(t, "cmd1", callLog[0].Command) - assert.Equal(t, []string{"arg1"}, callLog[0].Args) - - assert.Equal(t, "cmd2", callLog[1].Command) - assert.Equal(t, []string{"arg2", "arg3"}, callLog[1].Args) - - assert.Equal(t, "unmocked", callLog[2].Command) - assert.Equal(t, []string{"command"}, callLog[2].Args) - }) - - t.Run("reset functionality", func(t *testing.T) { - // Create a fresh mock for this test - freshMock := NewMockShellExecutor() - freshMock.AddCommandString("test", []string{}, "output", nil) - _, _ = freshMock.Exec(context.Background(), "test") - - assert.Len(t, freshMock.Commands, 1) - assert.Len(t, freshMock.CallLog, 1) - - freshMock.Reset() - - assert.Len(t, 
freshMock.Commands, 0) - assert.Len(t, freshMock.CallLog, 0) - }) -} - -func TestContextShellExecutor(t *testing.T) { - t.Run("default executor when no context value", func(t *testing.T) { - ctx := context.Background() - executor := GetShellExecutor(ctx) - - _, ok := executor.(*DefaultShellExecutor) - assert.True(t, ok, "should return DefaultShellExecutor when no context value") - }) - - t.Run("mock executor from context", func(t *testing.T) { - mock := NewMockShellExecutor() - ctx := WithShellExecutor(context.Background(), mock) - - executor := GetShellExecutor(ctx) - assert.Equal(t, mock, executor, "should return the mock executor from context") - }) - - t.Run("context propagation", func(t *testing.T) { - mock := NewMockShellExecutor() - mock.AddCommandString("test", []string{"arg"}, "test output", nil) - - ctx := WithShellExecutor(context.Background(), mock) - - // Test that RunCommandWithContext uses the mock - output, err := RunCommandWithContext(ctx, "test", []string{"arg"}) - assert.NoError(t, err) - assert.Equal(t, "test output", output) - - // Verify the command was logged - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "test", callLog[0].Command) - assert.Equal(t, []string{"arg"}, callLog[0].Args) - }) -} - -func TestRunCommandWithMocking(t *testing.T) { - t.Run("successful command execution with mock", func(t *testing.T) { - mock := NewMockShellExecutor() - mock.AddCommandString("kubectl", []string{"get", "pods", "-n", "default"}, "pod1\npod2", nil) - - ctx := WithShellExecutor(context.Background(), mock) - - output, err := RunCommandWithContext(ctx, "kubectl", []string{"get", "pods", "-n", "default"}) - assert.NoError(t, err) - assert.Equal(t, "pod1\npod2", output) - - // Verify command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"get", "pods", "-n", "default"}, callLog[0].Args) - }) - - t.Run("command failure with mock", 
func(t *testing.T) { - mock := NewMockShellExecutor() - expectedError := errors.New("command failed") - mock.AddCommandString("helm", []string{"install", "app"}, "", expectedError) - - ctx := WithShellExecutor(context.Background(), mock) - - output, err := RunCommandWithContext(ctx, "helm", []string{"install", "app"}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "command helm failed") - assert.Empty(t, output) - }) - - t.Run("multiple commands with mock", func(t *testing.T) { - mock := NewMockShellExecutor() - mock.AddCommandString("kubectl", []string{"get", "pods"}, "pod-list", nil) - mock.AddCommandString("kubectl", []string{"get", "services"}, "service-list", nil) - mock.AddCommandString("helm", []string{"list"}, "helm-releases", nil) - - ctx := WithShellExecutor(context.Background(), mock) - - // Execute multiple commands - output1, err1 := RunCommandWithContext(ctx, "kubectl", []string{"get", "pods"}) - assert.NoError(t, err1) - assert.Equal(t, "pod-list", output1) - - output2, err2 := RunCommandWithContext(ctx, "kubectl", []string{"get", "services"}) - assert.NoError(t, err2) - assert.Equal(t, "service-list", output2) - - output3, err3 := RunCommandWithContext(ctx, "helm", []string{"list"}) - assert.NoError(t, err3) - assert.Equal(t, "helm-releases", output3) - - // Verify all commands were logged - callLog := mock.GetCallLog() - require.Len(t, callLog, 3) - - assert.Equal(t, "kubectl", callLog[0].Command) - assert.Equal(t, []string{"get", "pods"}, callLog[0].Args) - - assert.Equal(t, "kubectl", callLog[1].Command) - assert.Equal(t, []string{"get", "services"}, callLog[1].Args) - - assert.Equal(t, "helm", callLog[2].Command) - assert.Equal(t, []string{"list"}, callLog[2].Args) - }) -} - -func TestShellToolWithMocking(t *testing.T) { - t.Run("shell tool uses mock executor", func(t *testing.T) { - mock := NewMockShellExecutor() - mock.AddCommandString("echo", []string{"hello", "world"}, "hello world", nil) - - ctx := 
WithShellExecutor(context.Background(), mock) - - params := shellParams{Command: "echo hello world"} - output, err := shellTool(ctx, params) - assert.NoError(t, err) - assert.Equal(t, "hello world", output) - - // Verify command was called - callLog := mock.GetCallLog() - require.Len(t, callLog, 1) - assert.Equal(t, "echo", callLog[0].Command) - assert.Equal(t, []string{"hello", "world"}, callLog[0].Args) - }) - - t.Run("shell tool with empty command", func(t *testing.T) { - mock := NewMockShellExecutor() - ctx := WithShellExecutor(context.Background(), mock) - - params := shellParams{Command: ""} - output, err := shellTool(ctx, params) - assert.Error(t, err) - assert.Contains(t, err.Error(), "empty command") - assert.Empty(t, output) - - // No commands should be logged - callLog := mock.GetCallLog() - assert.Len(t, callLog, 0) - }) -} - -func TestMockShellExecutorCommandKey(t *testing.T) { - mock := NewMockShellExecutor() - - // Test that different argument combinations create different keys - mock.AddCommandString("kubectl", []string{"get", "pods"}, "pods", nil) - mock.AddCommandString("kubectl", []string{"get", "services"}, "services", nil) - mock.AddCommandString("kubectl", []string{}, "kubectl-help", nil) - - // Test first command - output, err := mock.Exec(context.Background(), "kubectl", "get", "pods") - assert.NoError(t, err) - assert.Equal(t, "pods", string(output)) - - // Test second command - output, err = mock.Exec(context.Background(), "kubectl", "get", "services") - assert.NoError(t, err) - assert.Equal(t, "services", string(output)) - - // Test third command (no args) - output, err = mock.Exec(context.Background(), "kubectl") - assert.NoError(t, err) - assert.Equal(t, "kubectl-help", string(output)) -} - -// Benchmark tests to ensure mocking doesn't add significant overhead -func BenchmarkDefaultShellExecutor(b *testing.B) { - executor := &DefaultShellExecutor{} - ctx := context.Background() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = 
executor.Exec(ctx, "echo", "test") - } -} - -func BenchmarkMockShellExecutor(b *testing.B) { - mock := NewMockShellExecutor() - mock.AddCommandString("echo", []string{"test"}, "test", nil) - ctx := context.Background() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = mock.Exec(ctx, "echo", "test") - } -} - -func BenchmarkRunCommandWithContext(b *testing.B) { - mock := NewMockShellExecutor() - mock.AddCommandString("echo", []string{"test"}, "test", nil) - ctx := WithShellExecutor(context.Background(), mock) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = RunCommandWithContext(ctx, "echo", []string{"test"}) - } -} diff --git a/pkg/utils/datetime.go b/pkg/utils/datetime.go index 165ea68..3bca950 100644 --- a/pkg/utils/datetime.go +++ b/pkg/utils/datetime.go @@ -1,31 +1,5 @@ package utils -import ( - "context" - "github.com/kagent-dev/tools/pkg/logger" - "time" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -var kubeConfig = "" - -// DateTime tools using direct Go time package -// This implementation matches the Python version exactly -func handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Returns the current date and time in ISO 8601 format (RFC3339) - // This matches the Python implementation: datetime.datetime.now().isoformat() - now := time.Now() - return mcp.NewToolResultText(now.Format(time.RFC3339)), nil -} - -func RegisterDateTimeTools(s *server.MCPServer, kubeconfig string) { - kubeConfig = kubeconfig - logger.Get().Info("kubeConfig", kubeConfig) - - // Register the GetCurrentDateTime tool to match Python implementation exactly - s.AddTool(mcp.NewTool("datetime_get_current_time", - mcp.WithDescription("Returns the current date and time in ISO 8601 format."), - ), handleGetCurrentDateTimeTool) -} +// DateTime tools implementation moved to RegisterTools function in common.go +// This file remains for backwards compatibility but the tools are now 
registered +// through the unified RegisterTools function.