diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 9b7a19c..22b0a48 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -48,6 +48,9 @@
//forward the following ports
"forwardPorts": [8084],
+ //network
+ "network": "host",
+
//mount docker directly on the host
"mounts": ["source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind"],
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index a857e6e..4d1c4d6 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -53,4 +53,21 @@ jobs:
- name: Run cmd/main.go tests
working-directory: .
run: |
- go test -v ./...
+ make test
+
+ go-e2e-tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v4
+ with:
+ go-version: "1.24"
+ cache: true
+
+ - name: Run cmd/main.go tests
+ working-directory: .
+ run: |
+ make e2e
diff --git a/.gitignore b/.gitignore
index d146fbf..2496f99 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,8 @@ bin/
.env.local
.env.development.local
.env.test.local
-.env.production.local
\ No newline at end of file
+.env.production.local
+/logs/
+/kagent-tools
+/*.out
+*.html
diff --git a/Makefile b/Makefile
index 90924db..1d9fdd1 100644
--- a/Makefile
+++ b/Makefile
@@ -13,6 +13,10 @@ LDFLAGS := -X github.com/kagent-dev/tools/internal/version.Version=$(VERSION) -X
## Location to install dependencies to
LOCALBIN ?= $(shell pwd)/bin
+.PHONY: clean
+clean:
+ rm -rf ./bin/kagent-tools-*
+
.PHONY: fmt
fmt: ## Run go fmt against code.
go fmt ./...
@@ -23,11 +27,11 @@ vet: ## Run go vet against code.
.PHONY: lint
lint: golangci-lint ## Run golangci-lint linter
- $(GOLANGCI_LINT) run
+ $(GOLANGCI_LINT) run --build-tags=test
.PHONY: lint-fix
lint-fix: golangci-lint ## Run golangci-lint linter and perform fixes
- $(GOLANGCI_LINT) run --fix
+ $(GOLANGCI_LINT) run --build-tags=test --fix
.PHONY: lint-config
lint-config: golangci-lint ## Verify golangci-lint linter configuration
@@ -43,8 +47,16 @@ tidy: ## Run go mod tidy to ensure dependencies are up to date.
go mod tidy
.PHONY: test
-test:
- go test -v -cover ./...
+test: build lint ## Run all tests with build, lint, and coverage
+ go test -tags=test -v -cover ./pkg/... ./internal/...
+
+.PHONY: test-only
+test-only: ## Run tests only (without build/lint for faster iteration)
+ go test -tags=test -v -cover ./pkg/... ./internal/...
+
+.PHONY: e2e
+e2e: test docker-build
+ go test -tags=test -v -cover ./e2e/...
bin/kagent-tools-linux-amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/kagent-tools-linux-amd64 ./cmd
@@ -143,6 +155,12 @@ docker-build-all: DOCKER_BUILD_ARGS = --progress=plain --builder $(BUILDX_BUILDE
docker-build-all:
$(DOCKER_BUILDER) build $(DOCKER_BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -f Dockerfile ./
+.PHONY: kind-update-kagent
+kind-update-kagent: docker-build
+ kind get clusters | grep -q $(KIND_CLUSTER_NAME) || kind create cluster --name $(KIND_CLUSTER_NAME)
+ kind load docker-image --name $(KIND_CLUSTER_NAME) $(TOOLS_IMG)
+ kubectl patch --namespace kagent deployment/kagent --type='json' -p='[{"op": "replace", "path": "/spec/template/spec/containers/3/image", "value": "$(TOOLS_IMG)"}]'
+
## Tool Binaries
## Location to install dependencies t
diff --git a/cmd/main.go b/cmd/main.go
index 5c44309..6ea40ae 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -7,24 +7,29 @@ import (
"net/http"
"os"
"os/signal"
+ "runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/joho/godotenv"
+ "github.com/kagent-dev/tools/internal/logger"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/kagent-dev/tools/internal/version"
- "github.com/kagent-dev/tools/pkg/logger"
- "github.com/kagent-dev/tools/pkg/utils"
-
"github.com/kagent-dev/tools/pkg/argo"
"github.com/kagent-dev/tools/pkg/cilium"
"github.com/kagent-dev/tools/pkg/helm"
"github.com/kagent-dev/tools/pkg/istio"
"github.com/kagent-dev/tools/pkg/k8s"
"github.com/kagent-dev/tools/pkg/prometheus"
- "github.com/mark3labs/mcp-go/server"
+ "github.com/kagent-dev/tools/pkg/utils"
"github.com/spf13/cobra"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+
+ "github.com/mark3labs/mcp-go/server"
)
var (
@@ -69,12 +74,36 @@ func run(cmd *cobra.Command, args []string) {
logger.Init()
defer logger.Sync()
- logger.Get().Info("Starting "+Name, "version", Version, "git_commit", GitCommit, "build_date", BuildDate)
-
// Setup context with cancellation for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ // Initialize OpenTelemetry tracing
+ cfg := telemetry.LoadOtelCfg()
+
+ err := telemetry.SetupOTelSDK(ctx)
+ if err != nil {
+ logger.Get().Error("Failed to setup OpenTelemetry SDK", "error", err)
+ os.Exit(1)
+ }
+
+ // Start root span for server lifecycle
+ tracer := otel.Tracer("kagent-tools/server")
+ ctx, rootSpan := tracer.Start(ctx, "server.lifecycle")
+ defer rootSpan.End()
+
+ rootSpan.SetAttributes(
+ attribute.String("server.name", Name),
+ attribute.String("server.version", cfg.Telemetry.ServiceVersion),
+ attribute.String("server.git_commit", GitCommit),
+ attribute.String("server.build_date", BuildDate),
+ attribute.Bool("server.stdio_mode", stdio),
+ attribute.Int("server.port", port),
+ attribute.StringSlice("server.tools", tools),
+ )
+
+ logger.Get().Info("Starting "+Name, "version", Version, "git_commit", GitCommit, "build_date", BuildDate)
+
mcp := server.NewMCPServer(
Name,
Version,
@@ -91,7 +120,7 @@ func run(cmd *cobra.Command, args []string) {
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
// HTTP server reference (only used when not in stdio mode)
- var sseServer *server.StreamableHTTPServer
+ var httpServer *http.Server
// Start server based on chosen mode
wg.Add(1)
@@ -101,16 +130,49 @@ func run(cmd *cobra.Command, args []string) {
runStdioServer(ctx, mcp)
}()
} else {
- sseServer = server.NewStreamableHTTPServer(mcp)
+ sseServer := server.NewStreamableHTTPServer(mcp)
+
+ // Create a mux to handle different routes
+ mux := http.NewServeMux()
+
+ // Add health endpoint
+ mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ if err := writeResponse(w, []byte("OK")); err != nil {
+ logger.Get().Error("Failed to write health response", "error", err)
+ }
+ })
+
+ // Add metrics endpoint (basic implementation for e2e tests)
+ mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ w.WriteHeader(http.StatusOK)
+
+ // Generate real runtime metrics instead of hardcoded values
+ metrics := generateRuntimeMetrics()
+ if err := writeResponse(w, []byte(metrics)); err != nil {
+ logger.Get().Error("Failed to write metrics response", "error", err)
+ }
+ })
+
+ // Handle all other routes with the MCP server wrapped in telemetry middleware
+ mux.Handle("/", telemetry.HTTPMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ sseServer.ServeHTTP(w, r)
+ })))
+
+ httpServer = &http.Server{
+ Addr: fmt.Sprintf(":%d", port),
+ Handler: mux,
+ }
+
go func() {
defer wg.Done()
- addr := fmt.Sprintf(":%d", port)
- logger.Get().Info("Running KAgent Tools Server", "port", addr, "tools", strings.Join(tools, ","))
- if err := sseServer.Start(addr); err != nil {
+ logger.Get().Info("Running KAgent Tools Server", "port", fmt.Sprintf(":%d", port), "tools", strings.Join(tools, ","))
+ if err := httpServer.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
- logger.Get().Error(err, "Failed to start SSE server")
+ logger.Get().Error("Failed to start HTTP server", "error", err)
} else {
- logger.Get().Info("SSE server closed gracefully.")
+ logger.Get().Info("HTTP server closed gracefully.")
}
}
}()
@@ -121,16 +183,23 @@ func run(cmd *cobra.Command, args []string) {
<-signalChan
logger.Get().Info("Received termination signal, shutting down server...")
+ // Mark root span as shutting down
+ rootSpan.AddEvent("server.shutdown.initiated")
+
// Cancel context to notify any context-aware operations
cancel()
// Gracefully shutdown HTTP server if running
- if !stdio && sseServer != nil {
+ if !stdio && httpServer != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
- if err := sseServer.Shutdown(shutdownCtx); err != nil {
- logger.Get().Error(err, "Failed to shutdown server gracefully")
+ if err := httpServer.Shutdown(shutdownCtx); err != nil {
+ logger.Get().Error("Failed to shutdown server gracefully", "error", err)
+ rootSpan.RecordError(err)
+ rootSpan.SetStatus(codes.Error, "Server shutdown failed")
+ } else {
+ rootSpan.AddEvent("server.shutdown.completed")
}
}
}()
@@ -140,6 +209,53 @@ func run(cmd *cobra.Command, args []string) {
logger.Get().Info("Server shutdown complete")
}
+// writeResponse writes data to an HTTP response writer with proper error handling
+func writeResponse(w http.ResponseWriter, data []byte) error {
+ _, err := w.Write(data)
+ return err
+}
+
+// generateRuntimeMetrics generates real runtime metrics for the /metrics endpoint
+func generateRuntimeMetrics() string {
+ var m runtime.MemStats
+ runtime.ReadMemStats(&m)
+
+ now := time.Now().Unix()
+
+ // Build metrics in Prometheus format
+ metrics := strings.Builder{}
+
+ // Go runtime info
+ metrics.WriteString("# HELP go_info Information about the Go environment.\n")
+ metrics.WriteString("# TYPE go_info gauge\n")
+ metrics.WriteString(fmt.Sprintf("go_info{version=\"%s\"} 1\n", runtime.Version()))
+
+ // Process start time
+ metrics.WriteString("# HELP process_start_time_seconds Start time of the process since unix epoch in seconds.\n")
+ metrics.WriteString("# TYPE process_start_time_seconds gauge\n")
+ metrics.WriteString(fmt.Sprintf("process_start_time_seconds %d\n", now))
+
+ // Memory metrics
+ metrics.WriteString("# HELP go_memstats_alloc_bytes Number of bytes allocated and still in use.\n")
+ metrics.WriteString("# TYPE go_memstats_alloc_bytes gauge\n")
+ metrics.WriteString(fmt.Sprintf("go_memstats_alloc_bytes %d\n", m.Alloc))
+
+ metrics.WriteString("# HELP go_memstats_total_alloc_bytes Total number of bytes allocated, even if freed.\n")
+ metrics.WriteString("# TYPE go_memstats_total_alloc_bytes counter\n")
+ metrics.WriteString(fmt.Sprintf("go_memstats_total_alloc_bytes %d\n", m.TotalAlloc))
+
+ metrics.WriteString("# HELP go_memstats_sys_bytes Number of bytes obtained from system.\n")
+ metrics.WriteString("# TYPE go_memstats_sys_bytes gauge\n")
+ metrics.WriteString(fmt.Sprintf("go_memstats_sys_bytes %d\n", m.Sys))
+
+ // Goroutine count
+ metrics.WriteString("# HELP go_goroutines Number of goroutines that currently exist.\n")
+ metrics.WriteString("# TYPE go_goroutines gauge\n")
+ metrics.WriteString(fmt.Sprintf("go_goroutines %d\n", runtime.NumGoroutine()))
+
+ return metrics.String()
+}
+
func runStdioServer(ctx context.Context, mcp *server.MCPServer) {
logger.Get().Info("Running KAgent Tools Server STDIO:", "tools", strings.Join(tools, ","))
stdioServer := server.NewStdioServer(mcp)
@@ -149,39 +265,28 @@ func runStdioServer(ctx context.Context, mcp *server.MCPServer) {
}
func registerMCP(mcp *server.MCPServer, enabledToolProviders []string, kubeconfig string) {
-
- var toolProviderMap = map[string]func(*server.MCPServer, string){
- "utils": utils.RegisterDateTimeTools,
- "k8s": k8s.RegisterK8sTools,
- "prometheus": prometheus.RegisterPrometheusTools,
- "helm": helm.RegisterHelmTools,
- "istio": istio.RegisterIstioTools,
- "argo": argo.RegisterArgoTools,
- "cilium": cilium.RegisterCiliumTools,
+ // A map to hold tool providers and their registration functions
+ toolProviderMap := map[string]func(*server.MCPServer){
+ "argo": argo.RegisterTools,
+ "cilium": cilium.RegisterTools,
+ "helm": helm.RegisterTools,
+ "istio": istio.RegisterTools,
+ "k8s": func(s *server.MCPServer) { k8s.RegisterTools(s, nil, kubeconfig) },
+ "prometheus": prometheus.RegisterTools,
+ "utils": utils.RegisterTools,
}
- if len(kubeconfig) > 0 {
- logger.Get().Info("Using kubeconfig file", "path", kubeconfig)
- }
-
- // If no tools specified, register all tools
+ // If no specific tools are specified, register all available tools.
if len(enabledToolProviders) == 0 {
- logger.Get().Info("No specific tools provided, registering all tools")
- for toolProvider, registerFunc := range toolProviderMap {
- logger.Get().Info("Registering tools", "provider", toolProvider)
- registerFunc(mcp, kubeconfig)
+ for name := range toolProviderMap {
+ enabledToolProviders = append(enabledToolProviders, name)
}
- return
}
-
- // Register only the specified tools
- logger.Get().Info("provider list", "tools", enabledToolProviders)
for _, toolProviderName := range enabledToolProviders {
- if registerFunc, ok := toolProviderMap[strings.ToLower(toolProviderName)]; ok {
- logger.Get().Info("Registering tool", "provider", toolProviderName)
- registerFunc(mcp, kubeconfig)
+ if registerFunc, ok := toolProviderMap[toolProviderName]; ok {
+ registerFunc(mcp)
} else {
- logger.Get().Error(nil, "Unknown tool specified", "provider", toolProviderName)
+ logger.Get().Error("Unknown tool specified", "provider", toolProviderName)
}
}
}
diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go
new file mode 100644
index 0000000..ec04751
--- /dev/null
+++ b/e2e/e2e_test.go
@@ -0,0 +1,1005 @@
+package e2e
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// getBinaryName returns the platform-specific binary name
+func getBinaryName() string {
+ osName := runtime.GOOS
+ archName := runtime.GOARCH
+ return fmt.Sprintf("kagent-tools-%s-%s", osName, archName)
+}
+
+// TestServerConfig holds configuration for server tests
+type TestServerConfig struct {
+ Port int
+ Tools []string
+ Kubeconfig string
+ Stdio bool
+ Timeout time.Duration
+}
+
+// ServerTestResult holds the result of a server test
+type ServerTestResult struct {
+ Output string
+ Error error
+ Duration time.Duration
+}
+
+// TestServer represents a test server instance
+type TestServer struct {
+ cmd *exec.Cmd
+ port int
+ stdio bool
+ cancel context.CancelFunc
+ done chan struct{}
+ output strings.Builder
+ mu sync.RWMutex
+}
+
+// NewTestServer creates a new test server instance
+func NewTestServer(config TestServerConfig) *TestServer {
+ return &TestServer{
+ port: config.Port,
+ stdio: config.Stdio,
+ done: make(chan struct{}),
+ }
+}
+
+// Start starts the test server
+func (ts *TestServer) Start(ctx context.Context, config TestServerConfig) error {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+
+ // Build command arguments
+ args := []string{}
+ if config.Stdio {
+ args = append(args, "--stdio")
+ } else {
+ args = append(args, "--port", fmt.Sprintf("%d", config.Port))
+ }
+
+ if len(config.Tools) > 0 {
+ args = append(args, "--tools", strings.Join(config.Tools, ","))
+ }
+
+ if config.Kubeconfig != "" {
+ args = append(args, "--kubeconfig", config.Kubeconfig)
+ }
+
+ // Create context with cancellation
+ ctx, cancel := context.WithCancel(ctx)
+ ts.cancel = cancel
+
+ // Start server process
+ binaryName := getBinaryName()
+ ts.cmd = exec.CommandContext(ctx, fmt.Sprintf("../bin/%s", binaryName), args...)
+ ts.cmd.Env = append(os.Environ(), "LOG_LEVEL=debug")
+
+ // Set up output capture
+ stdout, err := ts.cmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stdout pipe: %w", err)
+ }
+
+ stderr, err := ts.cmd.StderrPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stderr pipe: %w", err)
+ }
+
+ // Start the command
+ if err := ts.cmd.Start(); err != nil {
+ return fmt.Errorf("failed to start server: %w", err)
+ }
+
+ // Start goroutines to capture output
+ go ts.captureOutput(stdout, "STDOUT")
+ go ts.captureOutput(stderr, "STDERR")
+
+ // Wait for server to start
+ if !config.Stdio {
+ return ts.waitForHTTPServer(ctx, config.Timeout)
+ }
+
+ return nil
+}
+
+// Stop stops the test server
+func (ts *TestServer) Stop() error {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+
+ if ts.cancel != nil {
+ ts.cancel()
+ }
+
+ if ts.cmd != nil && ts.cmd.Process != nil {
+ // Send interrupt signal for graceful shutdown
+ if err := ts.cmd.Process.Signal(os.Interrupt); err != nil {
+ // If interrupt fails, kill the process
+ _ = ts.cmd.Process.Kill()
+ }
+
+ // Wait for process to exit with timeout
+ done := make(chan error, 1)
+ go func() {
+ done <- ts.cmd.Wait()
+ }()
+
+ select {
+ case <-done:
+ // Process exited
+ case <-time.After(5 * time.Second):
+ // Timeout, force kill
+ _ = ts.cmd.Process.Kill()
+ }
+ }
+
+ close(ts.done)
+ return nil
+}
+
+// GetOutput returns the captured output
+func (ts *TestServer) GetOutput() string {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ return ts.output.String()
+}
+
+// captureOutput captures output from the server
+func (ts *TestServer) captureOutput(reader io.Reader, prefix string) {
+ scanner := bufio.NewScanner(reader)
+ for scanner.Scan() {
+ line := scanner.Text()
+ ts.mu.Lock()
+ ts.output.WriteString(fmt.Sprintf("[%s] %s\n", prefix, line))
+ ts.mu.Unlock()
+ }
+}
+
+// waitForHTTPServer waits for the HTTP server to become available
+func (ts *TestServer) waitForHTTPServer(ctx context.Context, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ url := fmt.Sprintf("http://localhost:%d/health", ts.port)
+ ticker := time.NewTicker(100 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("timeout waiting for server to start")
+ case <-ticker.C:
+ resp, err := http.Get(url)
+ if err == nil {
+ resp.Body.Close()
+ if resp.StatusCode == http.StatusOK {
+ return nil
+ }
+ }
+ }
+ }
+}
+
+// TestHTTPServerStartup tests basic HTTP server startup and shutdown
+func TestHTTPServerStartup(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8085,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait a bit for server to be fully ready
+ time.Sleep(3 * time.Second)
+
+ // Test health endpoint
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ require.NoError(t, err, "Health endpoint should be accessible")
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ resp.Body.Close()
+
+ // Check server output
+ output := server.GetOutput()
+ assert.Contains(t, output, "Running KAgent Tools Server")
+ assert.Contains(t, output, fmt.Sprintf(":%d", config.Port))
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+
+ // Verify server is stopped
+ time.Sleep(1 * time.Second)
+ _, err = http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ assert.Error(t, err, "Server should not be accessible after stop")
+}
+
+// TestHTTPServerWithSpecificTools tests server with specific tools enabled
+func TestHTTPServerWithSpecificTools(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8086,
+ Tools: []string{"utils", "k8s"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for tool registration
+ output := server.GetOutput()
+ assert.Contains(t, output, "RegisterTools initialized", "Should register specified tools")
+ assert.Contains(t, output, "utils", "Should register utils tools")
+ assert.Contains(t, output, "k8s", "Should register k8s tools")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestHTTPServerWithAllTools tests server with all tools enabled (default)
+func TestHTTPServerWithAllTools(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8087,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for all tools registration
+ output := server.GetOutput()
+ assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools")
+
+ // Verify server is running (tools are implicitly registered when no specific tools are provided)
+ assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with all tools")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestHTTPServerWithKubeconfig tests server with kubeconfig parameter
+func TestHTTPServerWithKubeconfig(t *testing.T) {
+ ctx := context.Background()
+
+ // Create a temporary kubeconfig file
+ tempDir := t.TempDir()
+ kubeconfigPath := filepath.Join(tempDir, "kubeconfig")
+
+ kubeconfigContent := `apiVersion: v1
+kind: Config
+clusters:
+- cluster:
+ server: https://test-cluster
+ name: test-cluster
+contexts:
+- context:
+ cluster: test-cluster
+ user: test-user
+ name: test-context
+current-context: test-context
+users:
+- name: test-user
+ user:
+ token: test-token
+`
+
+ err := os.WriteFile(kubeconfigPath, []byte(kubeconfigContent), 0644)
+ require.NoError(t, err, "Should create temporary kubeconfig file")
+
+ config := TestServerConfig{
+ Port: 8088,
+ Kubeconfig: kubeconfigPath,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err = server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for kubeconfig setting
+ output := server.GetOutput()
+ assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools")
+ assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with kubeconfig")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestStdioServer tests STDIO server mode
+func TestStdioServer(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Stdio: true,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for STDIO mode
+ output := server.GetOutput()
+ assert.Contains(t, output, "Running KAgent Tools Server STDIO")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestServerGracefulShutdown tests graceful shutdown behavior
+func TestServerGracefulShutdown(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8100,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Stop server and measure shutdown time
+ start := time.Now()
+ err = server.Stop()
+ duration := time.Since(start)
+
+ require.NoError(t, err, "Server should stop gracefully")
+ assert.Less(t, duration, 10*time.Second, "Shutdown should complete within reasonable time")
+
+ // Wait a bit for shutdown logs to be captured
+ time.Sleep(3 * time.Second)
+
+ // Check server output for graceful shutdown
+ output := server.GetOutput()
+ // The main test is that the server started successfully and stopped without error
+ assert.Contains(t, output, "Running KAgent Tools Server", "Server should have started successfully")
+
+ // Try to verify the server is actually stopped by attempting to connect
+ _, err = http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ assert.Error(t, err, "Server should not be accessible after stop")
+}
+
+// TestServerWithInvalidTool tests server behavior with invalid tool names
+func TestServerWithInvalidTool(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8090,
+ Tools: []string{"invalid-tool", "utils"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start even with invalid tools")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for error about invalid tool
+ output := server.GetOutput()
+ assert.Contains(t, output, "Unknown tool specified")
+ assert.Contains(t, output, "invalid-tool")
+
+ // Valid tools should still be registered
+ assert.Contains(t, output, "RegisterTools initialized")
+ assert.Contains(t, output, "utils")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestServerVersionAndBuildInfo tests server version and build information
+func TestServerVersionAndBuildInfo(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8091,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for version information
+ output := server.GetOutput()
+ assert.Contains(t, output, "Starting kagent-tools-server")
+ assert.Contains(t, output, "version")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestConcurrentServerInstances tests running multiple server instances
+func TestConcurrentServerInstances(t *testing.T) {
+ ctx := context.Background()
+
+ var wg sync.WaitGroup
+ numServers := 3
+ servers := make([]*TestServer, numServers)
+
+ // Start multiple servers on different ports
+ for i := 0; i < numServers; i++ {
+ wg.Add(1)
+ go func(index int) {
+ defer wg.Done()
+
+ config := TestServerConfig{
+ Port: 8092 + index,
+ Tools: []string{"utils"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ servers[index] = server
+
+ err := server.Start(ctx, config)
+ assert.NoError(t, err, fmt.Sprintf("Server %d should start successfully", index))
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Test health endpoint
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ assert.NoError(t, err, fmt.Sprintf("Health endpoint should be accessible for server %d", index))
+ if resp != nil {
+ resp.Body.Close()
+ }
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Stop all servers
+ for i, server := range servers {
+ if server != nil {
+ err := server.Stop()
+ assert.NoError(t, err, fmt.Sprintf("Server %d should stop gracefully", i))
+ }
+ }
+}
+
+// TestServerEnvironmentVariables tests server with environment variables
+func TestServerEnvironmentVariables(t *testing.T) {
+ ctx := context.Background()
+
+ // Set environment variables
+ originalEnv := os.Environ()
+ defer func() {
+ os.Clearenv()
+ for _, env := range originalEnv {
+ parts := strings.SplitN(env, "=", 2)
+ if len(parts) == 2 {
+ os.Setenv(parts[0], parts[1])
+ }
+ }
+ }()
+
+ os.Setenv("LOG_LEVEL", "info")
+ os.Setenv("OTEL_SERVICE_NAME", "test-kagent-tools")
+
+ config := TestServerConfig{
+ Port: 8095,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ // Start server
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output
+ output := server.GetOutput()
+ assert.Contains(t, output, "Starting kagent-tools-server")
+
+ // Stop server
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestServerBuildAndExecution tests that the server binary exists and is executable
+func TestServerBuildAndExecution(t *testing.T) {
+ // Check if server binary exists
+ binaryName := getBinaryName()
+ binaryPath := fmt.Sprintf("../bin/%s", binaryName)
+ _, err := os.Stat(binaryPath)
+ if os.IsNotExist(err) {
+ t.Skip("Server binary not found, skipping test. Run 'make build' first.")
+ }
+ require.NoError(t, err, "Server binary should exist")
+
+ // Test --help flag
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ cmd := exec.CommandContext(ctx, binaryPath, "--help")
+ output, err := cmd.CombinedOutput()
+ require.NoError(t, err, "Server should respond to --help flag")
+
+ outputStr := string(output)
+ assert.Contains(t, outputStr, "KAgent tool server")
+ assert.Contains(t, outputStr, "--port")
+ assert.Contains(t, outputStr, "--stdio")
+ assert.Contains(t, outputStr, "--tools")
+ assert.Contains(t, outputStr, "--kubeconfig")
+}
+
+// Benchmark tests
+func BenchmarkServerStartup(b *testing.B) {
+ ctx := context.Background()
+
+ for i := 0; i < b.N; i++ {
+ config := TestServerConfig{
+ Port: 8096 + i,
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+
+ start := time.Now()
+ err := server.Start(ctx, config)
+ if err != nil {
+ b.Fatalf("Server startup failed: %v", err)
+ }
+
+ // Wait for server to be ready
+ time.Sleep(1 * time.Second)
+
+ duration := time.Since(start)
+ b.ReportMetric(float64(duration.Nanoseconds()), "startup_time_ns")
+
+ // Stop server
+ _ = server.Stop()
+ }
+}
+
+// Helper functions for test setup
+func init() {
+ // Ensure the binary exists before running tests
+ binaryName := getBinaryName()
+ binaryPath := fmt.Sprintf("../bin/%s", binaryName)
+ if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
+ // Try to build the binary
+ cmd := exec.Command("make", "build")
+ cmd.Dir = ".."
+ if err := cmd.Run(); err != nil {
+ panic(fmt.Sprintf("Failed to build server binary: %v", err))
+ }
+ }
+}
+
+// TestToolRegistrationValidation tests that tool registration works correctly
+func TestToolRegistrationValidation(t *testing.T) {
+ ctx := context.Background()
+
+ testCases := []struct {
+ name string
+ config TestServerConfig
+ expectedTools []string
+ shouldFail bool
+ }{
+ {
+ name: "Register single tool",
+ config: TestServerConfig{
+ Port: 8087,
+ Tools: []string{"k8s"},
+ Timeout: 30 * time.Second,
+ },
+ expectedTools: []string{"k8s"},
+ shouldFail: false,
+ },
+ {
+ name: "Register multiple tools",
+ config: TestServerConfig{
+ Port: 8088,
+ Tools: []string{"k8s", "prometheus", "utils"},
+ Timeout: 30 * time.Second,
+ },
+ expectedTools: []string{"k8s", "prometheus", "utils"},
+ shouldFail: false,
+ },
+ {
+ name: "Register invalid tool",
+ config: TestServerConfig{
+ Port: 8089,
+ Tools: []string{"invalid-tool"},
+ Timeout: 30 * time.Second,
+ },
+ shouldFail: false,
+ },
+ {
+ name: "Register all tools implicitly",
+ config: TestServerConfig{
+ Port: 8090,
+ Tools: []string{},
+ Timeout: 30 * time.Second,
+ },
+ expectedTools: []string{"utils", "k8s", "prometheus", "helm", "istio", "argo", "cilium"},
+ shouldFail: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ server := NewTestServer(tc.config)
+ err := server.Start(ctx, tc.config)
+
+ if tc.shouldFail {
+ require.Error(t, err, "Server should fail to start with invalid configuration")
+ return
+ }
+
+ require.NoError(t, err, "Server should start successfully")
+ defer func() {
+ if err := server.Stop(); err != nil {
+ t.Errorf("Failed to stop server: %v", err)
+ }
+ }()
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Verify registered tools
+ output := server.GetOutput()
+
+ // Special handling for invalid tool test case
+ if tc.name == "Register invalid tool" {
+ assert.Contains(t, output, "Unknown tool specified", "Should warn about invalid tool")
+ assert.Contains(t, output, "invalid-tool", "Should mention the invalid tool name")
+ } else {
+ if tc.name == "Register all tools implicitly" {
+ // For implicit all tools registration, check for RegisterTools initialized
+ assert.Contains(t, output, "RegisterTools initialized", "Should initialize RegisterTools")
+ // Don't check for individual tool names as they're not logged individually
+ assert.Contains(t, output, "Running KAgent Tools Server", "Should be running with all tools")
+ } else {
+ // For specific tools, check for Running server message and tool names
+ assert.Contains(t, output, "Running KAgent Tools Server", "Should be running server")
+ for _, tool := range tc.expectedTools {
+ assert.Contains(t, output, tool, fmt.Sprintf("Should register %s tool", tool))
+ }
+ }
+ }
+
+ // Test health endpoint
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", tc.config.Port))
+ require.NoError(t, err, "Health endpoint should be accessible")
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ resp.Body.Close()
+ })
+ }
+}
+
+// TestToolExecutionFlow tests the complete flow of tool execution
+func TestToolExecutionFlow(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8091,
+ Tools: []string{"utils"},
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+ defer func() {
+ if err := server.Stop(); err != nil {
+ t.Errorf("Failed to stop server: %v", err)
+ }
+ }()
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Test health endpoint (MCP server doesn't have REST endpoints for tool execution)
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ require.NoError(t, err, "Should execute request successfully")
+ defer resp.Body.Close()
+
+ // Check response
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "Should return OK status")
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err, "Should read response body")
+
+ // Response should contain "OK"
+ assert.Equal(t, "OK", string(body), "Should return OK response")
+}
+
+// TestServerTelemetry tests that telemetry is properly initialized and working
+func TestServerTelemetry(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8092,
+ Tools: []string{"utils"},
+ Timeout: 30 * time.Second,
+ }
+
+ // Set test environment variables for telemetry
+ os.Setenv("OTEL_SERVICE_NAME", "kagent-tools-test")
+ os.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317")
+ defer os.Unsetenv("OTEL_SERVICE_NAME")
+ defer os.Unsetenv("OTEL_EXPORTER_OTLP_ENDPOINT")
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+ defer func() {
+ if err := server.Stop(); err != nil {
+ t.Errorf("Failed to stop server: %v", err)
+ }
+ }()
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for telemetry initialization
+ output := server.GetOutput()
+ assert.Contains(t, output, "Starting kagent-tools-server", "Server should start with telemetry")
+
+ // Make a request to generate telemetry
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ require.NoError(t, err, "Health endpoint should be accessible")
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ resp.Body.Close()
+
+ // Check server output for successful startup (telemetry is initialized internally)
+ output = server.GetOutput()
+ assert.Contains(t, output, "Running KAgent Tools Server", "Server should be running with telemetry enabled")
+}
+
+// TestToolRegistrationWithInvalidNames tests server behavior with invalid tool names
+func TestToolRegistrationWithInvalidNames(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8087,
+ Tools: []string{"invalid-tool", "not-exists", "k8s"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully despite invalid tools")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Check server output for warning messages about invalid tools
+ output := server.GetOutput()
+ assert.Contains(t, output, "Unknown tool specified")
+ assert.Contains(t, output, "invalid-tool")
+ assert.Contains(t, output, "not-exists")
+
+ // Verify that valid tools were still registered
+ assert.Contains(t, output, "Running KAgent Tools Server")
+ assert.Contains(t, output, "k8s")
+
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestConcurrentToolExecution tests concurrent tool execution
+func TestConcurrentToolExecution(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8088,
+ Tools: []string{"utils", "k8s"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Create multiple concurrent requests
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ require.NoError(t, err, "Concurrent request %d should succeed", id)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ resp.Body.Close()
+ }(i)
+ }
+
+ wg.Wait()
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestServerErrorHandling tests server's error handling capabilities
+func TestServerErrorHandling(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8089,
+ Tools: []string{"utils"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Test malformed request
+ req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/nonexistent", config.Port), strings.NewReader("invalid json"))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ resp.Body.Close()
+
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestServerMetricsEndpoint tests the metrics endpoint functionality
+func TestServerMetricsEndpoint(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8090,
+ Tools: []string{"utils"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Test metrics endpoint
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/metrics", config.Port))
+ require.NoError(t, err, "Metrics endpoint should be accessible")
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // Read and verify metrics content
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ resp.Body.Close()
+
+ metricsContent := string(body)
+ assert.Contains(t, metricsContent, "go_")
+ assert.Contains(t, metricsContent, "process_")
+
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
+
+// TestToolSpecificFunctionality tests specific functionality of registered tools
+func TestToolSpecificFunctionality(t *testing.T) {
+ ctx := context.Background()
+
+ config := TestServerConfig{
+ Port: 8091,
+ Tools: []string{"utils", "k8s"},
+ Stdio: false,
+ Timeout: 30 * time.Second,
+ }
+
+ server := NewTestServer(config)
+ err := server.Start(ctx, config)
+ require.NoError(t, err, "Server should start successfully")
+
+ // Wait for server to be ready
+ time.Sleep(3 * time.Second)
+
+ // Test utils tool endpoint
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/health", config.Port))
+ require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ resp.Body.Close()
+
+ // Verify response format matches expected OK response
+ assert.Equal(t, "OK", string(body), "Should return OK response")
+
+ err = server.Stop()
+ require.NoError(t, err, "Server should stop gracefully")
+}
diff --git a/go.mod b/go.mod
index 08a9e04..220ea7f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,67 +1,47 @@
module github.com/kagent-dev/tools
-go 1.24.4
+go 1.24.5
require (
github.com/go-logr/logr v1.4.3
github.com/go-logr/stdr v1.2.2
github.com/joho/godotenv v1.5.1
- github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3
github.com/mark3labs/mcp-go v0.32.0
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
github.com/tmc/langchaingo v0.1.13
go.opentelemetry.io/otel v1.37.0
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
+ go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
go.opentelemetry.io/otel/metric v1.37.0
+ go.opentelemetry.io/otel/sdk v1.37.0
+ go.opentelemetry.io/otel/trace v1.37.0
)
require (
+ github.com/cenkalti/backoff/v4 v4.3.0 // indirect
+ github.com/cenkalti/backoff/v5 v5.0.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
- github.com/emicklei/go-restful/v3 v3.12.2 // indirect
- github.com/fxamacker/cbor/v2 v2.8.0 // indirect
- github.com/go-openapi/jsonpointer v0.21.1 // indirect
- github.com/go-openapi/jsonreference v0.21.0 // indirect
- github.com/go-openapi/swag v0.23.1 // indirect
- github.com/gogo/protobuf v1.3.2 // indirect
- github.com/google/gnostic-models v0.6.9 // indirect
- github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
+ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
- github.com/josharian/intern v1.0.0 // indirect
- github.com/json-iterator/go v1.1.12 // indirect
- github.com/mailru/easyjson v0.9.0 // indirect
- github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
- github.com/modern-go/reflect2 v1.0.2 // indirect
- github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
- github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/spf13/cast v1.9.2 // indirect
github.com/spf13/pflag v1.0.6 // indirect
- github.com/x448/float16 v0.8.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
- go.opentelemetry.io/otel/trace v1.37.0 // indirect
- go.uber.org/automaxprocs v1.6.0 // indirect
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 // indirect
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 // indirect
+ go.opentelemetry.io/proto/otlp v1.7.0 // indirect
golang.org/x/net v0.41.0 // indirect
- golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.33.0 // indirect
- golang.org/x/term v0.32.0 // indirect
golang.org/x/text v0.26.0 // indirect
- golang.org/x/time v0.12.0 // indirect
+ google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
+ google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
+ google.golang.org/grpc v1.73.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
- gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect
- gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- k8s.io/api v0.33.2 // indirect
- k8s.io/apimachinery v0.33.2 // indirect
- k8s.io/client-go v0.33.2 // indirect
- k8s.io/klog/v2 v2.130.1 // indirect
- k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a // indirect
- k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 // indirect
- sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect
- sigs.k8s.io/randfill v1.0.0 // indirect
- sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
)
diff --git a/go.sum b/go.sum
index 15455ac..3a1cf49 100644
--- a/go.sum
+++ b/go.sum
@@ -1,77 +1,40 @@
+github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
+github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
+github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
+github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
-github.com/emicklei/go-restful/v3 v3.12.2 h1:DhwDP0vY3k8ZzE0RunuJy8GhNpPL6zqLkDf9B/a0/xU=
-github.com/emicklei/go-restful/v3 v3.12.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
-github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU=
-github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
-github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic=
-github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk=
-github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
-github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4=
-github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU=
-github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0=
-github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
-github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
-github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
-github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
-github.com/google/gnostic-models v0.6.9 h1:MU/8wDLif2qCXZmzncUQ/BOfxWfthHi63KqpoNbWqVw=
-github.com/google/gnostic-models v0.6.9/go.mod h1:CiWsm0s6BSQd1hRn8/QmxqB6BesYcbSZxsz9b0KuDBw=
+github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
+github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
-github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a h1://KbezygeMJZCSHH+HgUZiTeSoiuFspbMg1ge+eFj18=
-github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
-github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
-github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
-github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
-github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
-github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3 h1:B5EkhSmYMG6bgn7DTsOfhal8sl1MmhjixSXP1PP/jNw=
-github.com/kagent-dev/kagent/go v0.0.0-20250707014726-aa7651a0e4e3/go.mod h1:hwTH7K+UkePRxA6DhXOXavNyXRK3nPmvipA07DSRUxI=
-github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
-github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
-github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
-github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
-github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus=
-github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
-github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
-github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -83,98 +46,54 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
-github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
-github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
-github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 h1:Ahq7pZmv87yiyn3jeFz/LekZmPLLdKejuO3NcK9MssM=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0/go.mod h1:MJTqhM0im3mRLw1i8uGHnCvUEeS7VwRyxlLC78PA18M=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 h1:EtFWSnwW9hGObjkIdmlnWSydO+Qs8OwzfzXLUPg4xOc=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0/go.mod h1:QjUEoiGCPkvFZ/MjK6ZZfNOS6mfVEVKYE99dFhuN2LI=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
+go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
+go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
+go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
+go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
+go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
+go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
-go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
-go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+go.opentelemetry.io/proto/otlp v1.7.0 h1:jX1VolD6nHuFzOYso2E73H85i92Mv8JQYk0K9vz09os=
+go.opentelemetry.io/proto/otlp v1.7.0/go.mod h1:fSKjH6YJ7HDlwzltzyMj036AJ3ejJLCgCSHGj4efDDo=
+go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
+go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
-golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
-golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
-golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
-golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
-golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
-golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY=
+google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
+google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
+google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4=
-gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
-gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
-gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-k8s.io/api v0.33.2 h1:YgwIS5jKfA+BZg//OQhkJNIfie/kmRsO0BmNaVSimvY=
-k8s.io/api v0.33.2/go.mod h1:fhrbphQJSM2cXzCWgqU29xLDuks4mu7ti9vveEnpSXs=
-k8s.io/apimachinery v0.33.2 h1:IHFVhqg59mb8PJWTLi8m1mAoepkUNYmptHsV+Z1m5jY=
-k8s.io/apimachinery v0.33.2/go.mod h1:BHW0YOu7n22fFv/JkYOEfkUYNRN0fj0BlvMFWA7b+SM=
-k8s.io/client-go v0.33.2 h1:z8CIcc0P581x/J1ZYf4CNzRKxRvQAwoAolYPbtQes+E=
-k8s.io/client-go v0.33.2/go.mod h1:9mCgT4wROvL948w6f6ArJNb7yQd7QsvqavDeZHvNmHo=
-k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
-k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
-k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a h1:ZV3Zr+/7s7aVbjNGICQt+ppKWsF1tehxggNfbM7XnG8=
-k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a/go.mod h1:5jIi+8yX4RIb8wk3XwBo5Pq2ccx4FP10ohkbSKCZoK8=
-k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 h1:hwvWFiBzdWw1FhfY1FooPn3kzWuJ8tmbZBHi4zVsl1Y=
-k8s.io/utils v0.0.0-20250604170112-4c0f3b243397/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
-sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7npUVDpVXUUOFfW6CgAqE=
-sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
-sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
-sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=
-sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
-sigs.k8s.io/structured-merge-diff/v4 v4.6.0 h1:IUA9nvMmnKWcj5jl84xn+T5MnlZKThmUW1TdblaLVAc=
-sigs.k8s.io/structured-merge-diff/v4 v4.6.0/go.mod h1:dDy58f92j70zLsuZVuUX5Wp9vtxXpaZnkPGWeqDfCps=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
new file mode 100644
index 0000000..4c2c105
--- /dev/null
+++ b/internal/cache/cache.go
@@ -0,0 +1,545 @@
+package cache
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/metric"
+
+ "github.com/kagent-dev/tools/internal/logger"
+ "github.com/kagent-dev/tools/internal/telemetry"
+)
+
+// CacheType represents the type of cache using enum pattern
+type CacheType int
+
+const (
+ CacheTypeKubernetes CacheType = iota
+ CacheTypeCommand
+ CacheTypeHelm
+ CacheTypeIstio
+)
+
+// String returns the string representation of CacheType
+func (ct CacheType) String() string {
+ switch ct {
+ case CacheTypeKubernetes:
+ return "kubernetes"
+ case CacheTypeCommand:
+ return "command"
+ case CacheTypeHelm:
+ return "helm"
+ case CacheTypeIstio:
+ return "istio"
+ default:
+ return "unknown"
+ }
+}
+
+// Command to cache type mapping
+var commandToCacheType = map[string]CacheType{
+ "kubectl": CacheTypeKubernetes,
+ "helm": CacheTypeHelm,
+ "istioctl": CacheTypeIstio,
+ "cilium": CacheTypeCommand, // Use command cache for cilium
+ "argo": CacheTypeCommand, // Use command cache for argo
+}
+
+// CacheEntry represents a cached item with TTL
+type CacheEntry[T any] struct {
+ Value T
+ CreatedAt time.Time
+ ExpiresAt time.Time
+ AccessedAt time.Time
+ AccessCount int64
+}
+
+// IsExpired checks if the cache entry has expired
+func (e *CacheEntry[T]) IsExpired() bool {
+ return time.Now().After(e.ExpiresAt)
+}
+
+// Cache is a thread-safe cache with TTL support
+type Cache[T any] struct {
+ mu sync.RWMutex
+ data map[string]*CacheEntry[T]
+ name string
+ defaultTTL time.Duration
+ maxSize int
+ cleanupInterval time.Duration
+ stopCleanup chan struct{}
+
+ // Metrics
+ hits metric.Int64Counter
+ misses metric.Int64Counter
+ evictions metric.Int64Counter
+ size metric.Int64UpDownCounter
+}
+
+// NewCache creates a new cache with specified configuration and name
+func NewCache[T any](name string, defaultTTL time.Duration, maxSize int, cleanupInterval time.Duration) *Cache[T] {
+ meter := otel.Meter(fmt.Sprintf("kagent-tools/cache/%s", name))
+
+ // Create metrics with cache name as a label
+ hits, _ := meter.Int64Counter(
+ "cache_hits_total",
+ metric.WithDescription("Total number of cache hits"),
+ )
+
+ misses, _ := meter.Int64Counter(
+ "cache_misses_total",
+ metric.WithDescription("Total number of cache misses"),
+ )
+
+ evictions, _ := meter.Int64Counter(
+ "cache_evictions_total",
+ metric.WithDescription("Total number of cache evictions"),
+ )
+
+ size, _ := meter.Int64UpDownCounter(
+ "cache_size",
+ metric.WithDescription("Current number of items in cache"),
+ )
+
+ cache := &Cache[T]{
+ data: make(map[string]*CacheEntry[T]),
+ name: name,
+ defaultTTL: defaultTTL,
+ maxSize: maxSize,
+ cleanupInterval: cleanupInterval,
+ stopCleanup: make(chan struct{}),
+ hits: hits,
+ misses: misses,
+ evictions: evictions,
+ size: size,
+ }
+
+ // Start background cleanup
+ go cache.cleanupExpired()
+
+ return cache
+}
+
+// Get retrieves a value from the cache
+func (c *Cache[T]) Get(key string) (T, bool) {
+ ctx := context.Background()
+ _, span := telemetry.StartSpan(ctx, "cache.get",
+ attribute.String("cache.name", c.name),
+ attribute.String("cache.key", key),
+ )
+ defer span.End()
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ entry, exists := c.data[key]
+ if !exists {
+ var zero T
+ c.recordMiss(key)
+ telemetry.AddEvent(span, "cache.miss",
+ attribute.String("cache.result", "miss"),
+ )
+ span.SetAttributes(attribute.String("cache.result", "miss"))
+ return zero, false
+ }
+
+ if entry.IsExpired() {
+ var zero T
+ c.recordMiss(key)
+ telemetry.AddEvent(span, "cache.miss",
+ attribute.String("cache.result", "miss"),
+ attribute.String("cache.miss_reason", "expired"),
+ )
+ span.SetAttributes(
+ attribute.String("cache.result", "miss"),
+ attribute.String("cache.miss_reason", "expired"),
+ )
+ return zero, false
+ }
+
+ // Update access time and count
+ entry.AccessedAt = time.Now()
+ entry.AccessCount++
+
+ c.recordHit(key)
+ telemetry.AddEvent(span, "cache.hit",
+ attribute.String("cache.result", "hit"),
+ attribute.Int64("cache.access_count", entry.AccessCount),
+ )
+ span.SetAttributes(
+ attribute.String("cache.result", "hit"),
+ attribute.Int64("cache.access_count", entry.AccessCount),
+ )
+
+ logger.Get().Debug("Cache hit", "key", key, "access_count", entry.AccessCount)
+ return entry.Value, true
+}
+
+// Set stores a value in the cache with default TTL
+func (c *Cache[T]) Set(key string, value T) {
+ c.SetWithTTL(key, value, c.defaultTTL)
+}
+
+// SetWithTTL stores a value in the cache with specified TTL
+func (c *Cache[T]) SetWithTTL(key string, value T, ttl time.Duration) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ now := time.Now()
+
+ // Check if we need to evict items to make room
+ if len(c.data) >= c.maxSize {
+ c.evictLRU()
+ }
+
+ entry := &CacheEntry[T]{
+ Value: value,
+ CreatedAt: now,
+ ExpiresAt: now.Add(ttl),
+ AccessedAt: now,
+ AccessCount: 1,
+ }
+
+ // Check if key already exists
+ if _, exists := c.data[key]; !exists {
+ c.size.Add(context.Background(), 1)
+ }
+
+ c.data[key] = entry
+
+ logger.Get().Debug("Cache set", "key", key, "ttl", ttl)
+}
+
+// Delete removes a value from the cache
+func (c *Cache[T]) Delete(key string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if _, exists := c.data[key]; exists {
+ delete(c.data, key)
+ c.size.Add(context.Background(), -1)
+ logger.Get().Debug("Cache delete", "key", key)
+ }
+}
+
+// Clear removes all items from the cache
+func (c *Cache[T]) Clear() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ count := len(c.data)
+ c.data = make(map[string]*CacheEntry[T])
+ c.size.Add(context.Background(), -int64(count))
+
+ logger.Get().Info("Cache cleared", "items_removed", count)
+}
+
+// Size returns the current number of items in the cache
+func (c *Cache[T]) Size() int {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return len(c.data)
+}
+
+// Name returns the name of the cache
+func (c *Cache[T]) Name() string {
+ return c.name
+}
+
+// Stats returns cache statistics
+func (c *Cache[T]) Stats() CacheStats {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ stats := CacheStats{
+ Size: len(c.data),
+ MaxSize: c.maxSize,
+ Expired: 0,
+ Oldest: time.Now(),
+ Newest: time.Time{},
+ }
+
+ for _, entry := range c.data {
+ if entry.IsExpired() {
+ stats.Expired++
+ }
+
+ if entry.CreatedAt.Before(stats.Oldest) {
+ stats.Oldest = entry.CreatedAt
+ }
+
+ if entry.CreatedAt.After(stats.Newest) {
+ stats.Newest = entry.CreatedAt
+ }
+ }
+
+ return stats
+}
+
+// CacheStats represents cache statistics
+type CacheStats struct {
+ Size int `json:"size"`
+ MaxSize int `json:"max_size"`
+ Expired int `json:"expired"`
+ Oldest time.Time `json:"oldest"`
+ Newest time.Time `json:"newest"`
+}
+
+// cleanupExpired removes expired entries from the cache
+func (c *Cache[T]) cleanupExpired() {
+ ticker := time.NewTicker(c.cleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ c.performCleanup()
+ case <-c.stopCleanup:
+ return
+ }
+ }
+}
+
+// performCleanup removes expired entries
+func (c *Cache[T]) performCleanup() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ keysToDelete := make([]string, 0)
+
+ for key, entry := range c.data {
+ if entry.IsExpired() {
+ keysToDelete = append(keysToDelete, key)
+ }
+ }
+
+ if len(keysToDelete) > 0 {
+ for _, key := range keysToDelete {
+ delete(c.data, key)
+ c.evictions.Add(context.Background(), 1)
+ }
+
+ c.size.Add(context.Background(), -int64(len(keysToDelete)))
+ logger.Get().Debug("Cache cleanup", "expired_items", len(keysToDelete))
+ }
+}
+
+// evictLRU removes the least recently used item
+func (c *Cache[T]) evictLRU() {
+ var oldestKey string
+ var oldestTime time.Time = time.Now()
+
+ for key, entry := range c.data {
+ if entry.AccessedAt.Before(oldestTime) {
+ oldestTime = entry.AccessedAt
+ oldestKey = key
+ }
+ }
+
+ if oldestKey != "" {
+ delete(c.data, oldestKey)
+ c.evictions.Add(context.Background(), 1)
+ c.size.Add(context.Background(), -1)
+ logger.Get().Debug("Cache LRU eviction", "key", oldestKey)
+ }
+}
+
+// recordHit records a cache hit
+func (c *Cache[T]) recordHit(key string) {
+ c.hits.Add(context.Background(), 1, metric.WithAttributes(
+ attribute.String("cache.key", key),
+ attribute.String("cache.result", "hit"),
+ attribute.String("cache.name", c.name),
+ ))
+}
+
+// recordMiss records a cache miss
+func (c *Cache[T]) recordMiss(key string) {
+ c.misses.Add(context.Background(), 1, metric.WithAttributes(
+ attribute.String("cache.key", key),
+ attribute.String("cache.result", "miss"),
+ attribute.String("cache.name", c.name),
+ ))
+}
+
+// Close stops the cache cleanup goroutine
+func (c *Cache[T]) Close() {
+ close(c.stopCleanup)
+}
+
+// InvalidateByType clears the entire cache for a specific cache type
+func InvalidateByType(cacheType CacheType) {
+ ctx := context.Background()
+ _, span := telemetry.StartSpan(ctx, "cache.invalidate",
+ attribute.String("cache.type", cacheType.String()),
+ attribute.String("cache.operation", "invalidate"),
+ )
+ defer span.End()
+
+ InitCaches()
+ if cache, exists := cacheRegistry[cacheType]; exists {
+ oldSize := cache.Size()
+ cache.Clear()
+
+ telemetry.AddEvent(span, "cache.invalidated",
+ attribute.String("cache.name", cache.name),
+ attribute.Int("cache.items_cleared", oldSize),
+ )
+ span.SetAttributes(
+ attribute.String("cache.name", cache.name),
+ attribute.Int("cache.items_cleared", oldSize),
+ )
+ telemetry.RecordSuccess(span, "Cache invalidated successfully")
+
+ logger.Get().Info("Cache invalidated", "cache_type", cacheType.String(), "reason", "modification_command", "items_cleared", oldSize)
+ } else {
+ telemetry.RecordError(span, fmt.Errorf("cache type not found: %s", cacheType.String()), "Cache type not found")
+ }
+}
+
+// InvalidateKubernetesCache clears the Kubernetes cache
+func InvalidateKubernetesCache() {
+ InvalidateByType(CacheTypeKubernetes)
+}
+
+// InvalidateHelmCache clears the Helm cache
+func InvalidateHelmCache() {
+ InvalidateByType(CacheTypeHelm)
+}
+
+// InvalidateIstioCache clears the Istio cache
+func InvalidateIstioCache() {
+ InvalidateByType(CacheTypeIstio)
+}
+
+// InvalidateCommandCache clears the Command cache
+func InvalidateCommandCache() {
+ InvalidateByType(CacheTypeCommand)
+}
+
+// InvalidateCacheForCommand invalidates the appropriate cache based on command type
+func InvalidateCacheForCommand(command string) {
+ if cacheType, exists := commandToCacheType[command]; exists {
+ InvalidateByType(cacheType)
+ } else {
+ // Default to command cache for unknown commands
+ InvalidateCommandCache()
+ }
+}
+
+// Global cache instances for different use cases
+var (
+ // cacheRegistry holds all cache instances by type
+ cacheRegistry = make(map[CacheType]*Cache[string])
+ once sync.Once
+)
+
+// InitCaches initializes all global cache instances
+func InitCaches() {
+ once.Do(func() {
+ // Initialize caches with optimized TTL values based on use case
+ // Kubernetes: 45s - K8s resources change frequently, users expect fresh data
+ cacheRegistry[CacheTypeKubernetes] = NewCache[string](CacheTypeKubernetes.String(), 45*time.Second, 1000, 1*time.Minute)
+
+ // Istio: 1m - Service mesh config more stable than pods, but proxy status can change
+ cacheRegistry[CacheTypeIstio] = NewCache[string](CacheTypeIstio.String(), 1*time.Minute, 500, 1*time.Minute)
+
+ // Helm: 2m - Releases change less frequently, chart info is stable
+ cacheRegistry[CacheTypeHelm] = NewCache[string](CacheTypeHelm.String(), 2*time.Minute, 300, 2*time.Minute)
+
+ // Command: 3m - General CLI commands have stable output, status commands don't change rapidly
+ cacheRegistry[CacheTypeCommand] = NewCache[string](CacheTypeCommand.String(), 3*time.Minute, 200, 1*time.Minute)
+
+ logger.Get().Info("Caches initialized")
+ })
+}
+
+// GetCacheByType returns a cache instance by cache type
+func GetCacheByType(cacheType CacheType) *Cache[string] {
+ InitCaches()
+ if cache, exists := cacheRegistry[cacheType]; exists {
+ return cache
+ }
+ // Fallback to command cache if type not found
+ return cacheRegistry[CacheTypeCommand]
+}
+
+// GetCacheByCommand returns a cache instance based on the command name
+func GetCacheByCommand(command string) *Cache[string] {
+ InitCaches()
+ if cacheType, exists := commandToCacheType[command]; exists {
+ return GetCacheByType(cacheType)
+ }
+ // Default to command cache for unknown commands
+ return GetCacheByType(CacheTypeCommand)
+}
+
+// CacheKey generates a consistent cache key from components
+func CacheKey(components ...string) string {
+ result := ""
+ for i, component := range components {
+ if i > 0 {
+ result += ":"
+ }
+ result += component
+ }
+ return result
+}
+
+// CacheResult is a helper function to cache the result of a function
+func CacheResult[T any](cache *Cache[T], key string, ttl time.Duration, fn func() (T, error)) (T, error) {
+ ctx := context.Background()
+ _, span := telemetry.StartSpan(ctx, "cache.result",
+ attribute.String("cache.name", cache.name),
+ attribute.String("cache.key", key),
+ attribute.String("cache.ttl", ttl.String()),
+ )
+ defer span.End()
+
+ var zero T
+
+ // Try to get from cache first
+ if cachedResult, found := cache.Get(key); found {
+ telemetry.AddEvent(span, "cache.result.hit",
+ attribute.String("cache.operation", "get"),
+ attribute.String("cache.result", "hit"),
+ )
+ span.SetAttributes(
+ attribute.String("cache.operation", "get"),
+ attribute.String("cache.result", "hit"),
+ )
+ telemetry.RecordSuccess(span, "Cache hit - returning cached result")
+ return cachedResult, nil
+ }
+
+ // Not in cache, execute function
+ telemetry.AddEvent(span, "cache.result.miss",
+ attribute.String("cache.operation", "compute"),
+ attribute.String("cache.result", "miss"),
+ )
+ span.SetAttributes(
+ attribute.String("cache.operation", "compute"),
+ attribute.String("cache.result", "miss"),
+ )
+
+ result, err := fn()
+ if err != nil {
+ telemetry.RecordError(span, err, "Function execution failed")
+ return zero, err
+ }
+
+ // Store in cache
+ cache.SetWithTTL(key, result, ttl)
+
+ telemetry.AddEvent(span, "cache.result.stored",
+ attribute.String("cache.operation", "set"),
+ )
+ span.SetAttributes(attribute.String("cache.operation", "set"))
+ telemetry.RecordSuccess(span, "Function executed and result cached")
+
+ return result, nil
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
new file mode 100644
index 0000000..cc7cf64
--- /dev/null
+++ b/internal/cache/cache_test.go
@@ -0,0 +1,488 @@
+package cache
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewCache(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+
+ if cache.defaultTTL != 1*time.Minute {
+ t.Errorf("Expected default TTL of 1 minute, got %v", cache.defaultTTL)
+ }
+
+ if cache.maxSize != 100 {
+ t.Errorf("Expected max size of 100, got %d", cache.maxSize)
+ }
+
+ if cache.cleanupInterval != 10*time.Second {
+ t.Errorf("Expected cleanup interval of 10 seconds, got %v", cache.cleanupInterval)
+ }
+
+ if cache.name != "test-cache" {
+ t.Errorf("Expected cache name 'test-cache', got %s", cache.name)
+ }
+
+ cache.Close()
+}
+
+func TestCacheName(t *testing.T) {
+ cache := NewCache[string]("my-test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ if cache.Name() != "my-test-cache" {
+ t.Errorf("Expected cache name 'my-test-cache', got %s", cache.Name())
+ }
+}
+
+func TestCacheSetAndGet(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ // Test set and get
+ cache.Set("key1", "value1")
+ value, found := cache.Get("key1")
+
+ if !found {
+ t.Error("Expected to find key1")
+ }
+
+ if value != "value1" {
+ t.Errorf("Expected value1, got %v", value)
+ }
+}
+
+func TestCacheSetWithTTL(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ // Test set with custom TTL
+ cache.SetWithTTL("key1", "value1", 100*time.Millisecond)
+
+ // Should be found immediately
+ value, found := cache.Get("key1")
+ if !found {
+ t.Error("Expected to find key1")
+ }
+ if value != "value1" {
+ t.Errorf("Expected value1, got %v", value)
+ }
+
+ // Wait for expiration
+ time.Sleep(150 * time.Millisecond)
+
+ // Should not be found after expiration
+ _, found = cache.Get("key1")
+ if found {
+ t.Error("Expected key1 to be expired")
+ }
+}
+
+func TestCacheDelete(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ cache.Set("key1", "value1")
+ cache.Delete("key1")
+
+ _, found := cache.Get("key1")
+ if found {
+ t.Error("Expected key1 to be deleted")
+ }
+}
+
+func TestCacheClear(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ cache.Set("key1", "value1")
+ cache.Set("key2", "value2")
+
+ if cache.Size() != 2 {
+ t.Errorf("Expected size 2, got %d", cache.Size())
+ }
+
+ cache.Clear()
+
+ if cache.Size() != 0 {
+ t.Errorf("Expected size 0 after clear, got %d", cache.Size())
+ }
+}
+
+func TestCacheEviction(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 2, 10*time.Second) // Small cache
+ defer cache.Close()
+
+ // Fill cache to capacity
+ cache.Set("key1", "value1")
+ cache.Set("key2", "value2")
+
+ // Add one more item - should evict LRU
+ cache.Set("key3", "value3")
+
+ // key1 should be evicted (oldest)
+ _, found := cache.Get("key1")
+ if found {
+ t.Error("Expected key1 to be evicted")
+ }
+
+ // key2 and key3 should still be there
+ _, found = cache.Get("key2")
+ if !found {
+ t.Error("Expected key2 to be present")
+ }
+
+ _, found = cache.Get("key3")
+ if !found {
+ t.Error("Expected key3 to be present")
+ }
+}
+
+func TestCacheExpiration(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 50*time.Millisecond) // Fast cleanup
+ defer cache.Close()
+
+ // Set item with short TTL
+ cache.SetWithTTL("key1", "value1", 100*time.Millisecond)
+
+ // Wait for cleanup to run
+ time.Sleep(200 * time.Millisecond)
+
+ // Item should be cleaned up
+ _, found := cache.Get("key1")
+ if found {
+ t.Error("Expected key1 to be cleaned up")
+ }
+}
+
+func TestCacheStats(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ cache.Set("key1", "value1")
+ cache.Set("key2", "value2")
+
+ stats := cache.Stats()
+
+ if stats.Size != 2 {
+ t.Errorf("Expected stats size 2, got %d", stats.Size)
+ }
+
+ if stats.MaxSize != 100 {
+ t.Errorf("Expected stats max size 100, got %d", stats.MaxSize)
+ }
+
+ if stats.Expired != 0 {
+ t.Errorf("Expected 0 expired items, got %d", stats.Expired)
+ }
+}
+
+func TestCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ components []string
+ expected string
+ }{
+ {"single component", []string{"key1"}, "key1"},
+ {"multiple components", []string{"key1", "key2", "key3"}, "key1:key2:key3"},
+ {"empty components", []string{}, ""},
+ {"empty string component", []string{"key1", "", "key3"}, "key1::key3"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := CacheKey(tt.components...)
+ if result != tt.expected {
+ t.Errorf("Expected %q, got %q", tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestCacheResult(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ callCount := 0
+ testFunction := func() (string, error) {
+ callCount++
+ return "result", nil
+ }
+
+ // First call should execute function
+ result, err := CacheResult(cache, "test-key", 1*time.Minute, testFunction)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if result != "result" {
+ t.Errorf("Expected 'result', got %q", result)
+ }
+ if callCount != 1 {
+ t.Errorf("Expected function to be called once, got %d", callCount)
+ }
+
+ // Second call should use cache
+ result, err = CacheResult(cache, "test-key", 1*time.Minute, testFunction)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if result != "result" {
+ t.Errorf("Expected 'result', got %q", result)
+ }
+ if callCount != 1 {
+ t.Errorf("Expected function to be called once (cached), got %d", callCount)
+ }
+}
+
+func TestCacheResultWithError(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ testFunction := func() (string, error) {
+ return "", &testError{message: "test error"}
+ }
+
+ result, err := CacheResult(cache, "test-key", 1*time.Minute, testFunction)
+ if err == nil {
+ t.Error("Expected error")
+ }
+ if result != "" {
+ t.Errorf("Expected empty result, got %q", result)
+ }
+
+ // Check that error result is not cached
+ _, found := cache.Get("test-key")
+ if found {
+ t.Error("Expected error result not to be cached")
+ }
+}
+
+func TestCacheInitialization(t *testing.T) {
+ // Test that all cache types are properly initialized
+ types := []CacheType{
+ CacheTypeKubernetes,
+ CacheTypeCommand,
+ CacheTypeHelm,
+ CacheTypeIstio,
+ }
+
+ for _, cacheType := range types {
+ t.Run(cacheType.String(), func(t *testing.T) {
+ cache := GetCacheByType(cacheType)
+ if cache == nil {
+ t.Errorf("Expected cache for type %s to be initialized", cacheType.String())
+ }
+ if cache.Name() != cacheType.String() {
+ t.Errorf("Expected cache name %s, got %s", cacheType.String(), cache.Name())
+ }
+ })
+ }
+}
+
+func TestCacheEntry(t *testing.T) {
+ now := time.Now()
+ entry := &CacheEntry[string]{
+ Value: "test",
+ CreatedAt: now,
+ ExpiresAt: now.Add(1 * time.Minute),
+ AccessedAt: now,
+ AccessCount: 1,
+ }
+
+ // Should not be expired
+ if entry.IsExpired() {
+ t.Error("Expected entry not to be expired")
+ }
+
+ // Make it expired
+ entry.ExpiresAt = now.Add(-1 * time.Minute)
+
+ // Should be expired
+ if !entry.IsExpired() {
+ t.Error("Expected entry to be expired")
+ }
+}
+
+func TestCachePerformCleanup(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second)
+ defer cache.Close()
+
+ // Add expired item
+ cache.SetWithTTL("expired", "value", -1*time.Minute)
+
+ // Add valid item
+ cache.Set("valid", "value")
+
+ // Perform cleanup
+ cache.performCleanup()
+
+ // Expired item should be removed
+ _, found := cache.Get("expired")
+ if found {
+ t.Error("Expected expired item to be removed")
+ }
+
+ // Valid item should remain
+ _, found = cache.Get("valid")
+ if !found {
+ t.Error("Expected valid item to remain")
+ }
+}
+
+func TestCacheConcurrency(t *testing.T) {
+ cache := NewCache[string]("test-cache", 1*time.Minute, 1000, 10*time.Second)
+ defer cache.Close()
+
+ // Test concurrent operations
+ done := make(chan bool)
+
+ // Writer goroutine
+ go func() {
+ for i := 0; i < 100; i++ {
+ cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i))
+ }
+ done <- true
+ }()
+
+ // Reader goroutine
+ go func() {
+ for i := 0; i < 100; i++ {
+ cache.Get(fmt.Sprintf("key%d", i))
+ }
+ done <- true
+ }()
+
+ // Wait for both goroutines
+ <-done
+ <-done
+
+ // Cache should have items
+ if cache.Size() == 0 {
+ t.Error("Expected cache to have items")
+ }
+}
+
+// Helper types for testing
+type testError struct {
+ message string
+}
+
+func (e *testError) Error() string {
+ return e.message
+}
+
+func TestCacheTypeString(t *testing.T) {
+ tests := []struct {
+ cacheType CacheType
+ expected string
+ }{
+ {CacheTypeKubernetes, "kubernetes"},
+ {CacheTypeCommand, "command"},
+ {CacheTypeHelm, "helm"},
+ {CacheTypeIstio, "istio"},
+ {CacheType(999), "unknown"}, // Test unknown type
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.expected, func(t *testing.T) {
+ result := tt.cacheType.String()
+ if result != tt.expected {
+ t.Errorf("Expected %q, got %q", tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestGetCacheByType(t *testing.T) {
+ // Test all valid cache types
+ types := []CacheType{
+ CacheTypeKubernetes,
+ CacheTypeCommand,
+ CacheTypeHelm,
+ CacheTypeIstio,
+ }
+
+ for _, cacheType := range types {
+ t.Run(cacheType.String(), func(t *testing.T) {
+ cache := GetCacheByType(cacheType)
+ if cache == nil {
+ t.Errorf("Expected cache for type %s, got nil", cacheType.String())
+ }
+ if cache.Name() != cacheType.String() {
+ t.Errorf("Expected cache name %s, got %s", cacheType.String(), cache.Name())
+ }
+ })
+ }
+}
+
+func TestGetCacheByCommand(t *testing.T) {
+ tests := []struct {
+ command string
+ expectedType CacheType
+ }{
+ {"kubectl", CacheTypeKubernetes},
+ {"helm", CacheTypeHelm},
+ {"istioctl", CacheTypeIstio},
+ {"cilium", CacheTypeCommand},
+ {"argo", CacheTypeCommand},
+ {"unknown-command", CacheTypeCommand}, // Should default to command cache
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.command, func(t *testing.T) {
+ cache := GetCacheByCommand(tt.command)
+ if cache == nil {
+ t.Errorf("Expected cache for command %s, got nil", tt.command)
+ }
+ if cache.Name() != tt.expectedType.String() {
+ t.Errorf("Expected cache name %s for command %s, got %s",
+ tt.expectedType.String(), tt.command, cache.Name())
+ }
+ })
+ }
+}
+
+func TestCacheOTelTracing(t *testing.T) {
+ // This test verifies that OTEL tracing calls don't panic
+ // The actual tracing verification would require setting up an OTEL test environment
+ cache := NewCache[string]("test-tracing", 1*time.Minute, 10, 5*time.Minute)
+ defer cache.Close()
+
+ // Test cache miss with tracing
+ _, found := cache.Get("missing-key")
+ assert.False(t, found)
+
+ // Test cache hit with tracing
+ cache.Set("test-key", "test-value")
+ value, found := cache.Get("test-key")
+ assert.True(t, found)
+ assert.Equal(t, "test-value", value)
+
+ // Test CacheResult with tracing
+ callCount := 0
+ result, err := CacheResult(cache, "result-key", 1*time.Minute, func() (string, error) {
+ callCount++
+ return "computed-value", nil
+ })
+ assert.NoError(t, err)
+ assert.Equal(t, "computed-value", result)
+ assert.Equal(t, 1, callCount)
+
+ // Test cache hit on second call
+ result2, err := CacheResult(cache, "result-key", 1*time.Minute, func() (string, error) {
+ callCount++
+ return "computed-value", nil
+ })
+ assert.NoError(t, err)
+ assert.Equal(t, "computed-value", result2)
+ assert.Equal(t, 1, callCount) // Should not increment due to cache hit
+
+ // Test cache invalidation with tracing
+ oldSize := cache.Size()
+ InvalidateByType(CacheTypeCommand)
+ assert.True(t, oldSize > 0) // Verify we had items to clear
+}
diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go
new file mode 100644
index 0000000..3061006
--- /dev/null
+++ b/internal/cmd/cmd.go
@@ -0,0 +1,69 @@
+package cmd
+
+import (
+ "context"
+ "os/exec"
+ "time"
+
+ "github.com/kagent-dev/tools/internal/logger"
+)
+
+// ShellExecutor defines the interface for executing shell commands
+type ShellExecutor interface {
+ Exec(ctx context.Context, command string, args ...string) (output []byte, err error)
+}
+
+// DefaultShellExecutor implements ShellExecutor using os/exec
+type DefaultShellExecutor struct{}
+
+// Exec executes a command using os/exec.CommandContext
+func (e *DefaultShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) {
+ log := logger.WithContext(ctx)
+ startTime := time.Now()
+
+ log.Info("executing command",
+ "command", command,
+ "args", args,
+ )
+
+ cmd := exec.CommandContext(ctx, command, args...)
+ output, err := cmd.CombinedOutput()
+
+ duration := time.Since(startTime)
+
+ if err != nil {
+ log.Error("command execution failed",
+ "command", command,
+ "args", args,
+ "error", err,
+ "output", string(output),
+ "duration", duration.Seconds(),
+ )
+ } else {
+ log.Info("command execution successful",
+ "command", command,
+ "args", args,
+ "duration", duration.Seconds(),
+ )
+ }
+
+ return output, err
+}
+
+// Context key for shell executor injection
+type contextKey string
+
+const shellExecutorKey contextKey = "shellExecutor"
+
+// WithShellExecutor returns a context with the given shell executor
+func WithShellExecutor(ctx context.Context, executor ShellExecutor) context.Context {
+ return context.WithValue(ctx, shellExecutorKey, executor)
+}
+
+// GetShellExecutor retrieves the shell executor from context, or returns default
+func GetShellExecutor(ctx context.Context) ShellExecutor {
+ if executor, ok := ctx.Value(shellExecutorKey).(ShellExecutor); ok {
+ return executor
+ }
+ return &DefaultShellExecutor{}
+}
diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go
new file mode 100644
index 0000000..f902d4c
--- /dev/null
+++ b/internal/cmd/cmd_test.go
@@ -0,0 +1,58 @@
+package cmd
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDefaultShellExecutor(t *testing.T) {
+ executor := &DefaultShellExecutor{}
+
+ // Test successful command
+ output, err := executor.Exec(context.Background(), "echo", "hello")
+ assert.NoError(t, err)
+ assert.Equal(t, "hello\n", string(output))
+
+ // Test command with error
+ _, err = executor.Exec(context.Background(), "nonexistent-command")
+ assert.Error(t, err)
+}
+
+func TestMockShellExecutor(t *testing.T) {
+ mock := NewMockShellExecutor()
+
+ t.Run("unmocked command returns error", func(t *testing.T) {
+ _, err := mock.Exec(context.Background(), "unmocked", "command")
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no mock found for command")
+ })
+
+ t.Run("mocked command returns expected result", func(t *testing.T) {
+ expectedOutput := "mocked output"
+ mock.AddCommandString("kubectl", []string{"get", "pods"}, expectedOutput, nil)
+
+ output, err := mock.Exec(context.Background(), "kubectl", "get", "pods")
+ assert.NoError(t, err)
+ assert.Equal(t, expectedOutput, string(output))
+ })
+}
+
+func TestContextShellExecutor(t *testing.T) {
+ t.Run("default executor when no context value", func(t *testing.T) {
+ ctx := context.Background()
+ executor := GetShellExecutor(ctx)
+
+ _, ok := executor.(*DefaultShellExecutor)
+ assert.True(t, ok, "should return DefaultShellExecutor when no context value")
+ })
+
+ t.Run("mock executor from context", func(t *testing.T) {
+ mock := NewMockShellExecutor()
+ ctx := WithShellExecutor(context.Background(), mock)
+
+ executor := GetShellExecutor(ctx)
+ assert.Equal(t, mock, executor, "should return the mock executor from context")
+ })
+}
diff --git a/internal/cmd/mock.go b/internal/cmd/mock.go
new file mode 100644
index 0000000..3f13c47
--- /dev/null
+++ b/internal/cmd/mock.go
@@ -0,0 +1,120 @@
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+)
+
+// MockCall represents a recorded command execution for testing
+type MockCall struct {
+ Command string
+ Args []string
+}
+
+// MockShellExecutor is a mock implementation of ShellExecutor for testing
+type MockShellExecutor struct {
+ mu sync.Mutex
+ callLog []MockCall
+ commandMocks map[string]map[string]struct {
+ output string
+ err error
+ }
+ partialMatchers []struct {
+ command string
+ args []string
+ output string
+ err error
+ }
+}
+
+// NewMockShellExecutor creates a new mock shell executor
+func NewMockShellExecutor() *MockShellExecutor {
+ return &MockShellExecutor{
+ commandMocks: make(map[string]map[string]struct {
+ output string
+ err error
+ }),
+ }
+}
+
+// AddCommandString mocks a command with specific arguments and a string output
+func (m *MockShellExecutor) AddCommandString(command string, args []string, output string, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ argsKey := strings.Join(args, " ")
+ if _, ok := m.commandMocks[command]; !ok {
+ m.commandMocks[command] = make(map[string]struct {
+ output string
+ err error
+ })
+ }
+ m.commandMocks[command][argsKey] = struct {
+ output string
+ err error
+ }{output, err}
+}
+
+// AddPartialMatcherString mocks a command with partial argument matching
+func (m *MockShellExecutor) AddPartialMatcherString(command string, args []string, output string, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.partialMatchers = append(m.partialMatchers, struct {
+ command string
+ args []string
+ output string
+ err error
+ }{command, args, output, err})
+}
+
+// Exec records the call and returns a mocked output or error
+func (m *MockShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.callLog = append(m.callLog, MockCall{Command: command, Args: args})
+
+ // Check for exact match first
+ argsKey := strings.Join(args, " ")
+ if mocks, ok := m.commandMocks[command]; ok {
+ if mock, ok := mocks[argsKey]; ok {
+ return []byte(mock.output), mock.err
+ }
+ }
+
+ // Check for partial match
+ for _, matcher := range m.partialMatchers {
+ if matcher.command == command && argsContain(args, matcher.args) {
+ return []byte(matcher.output), matcher.err
+ }
+ }
+
+ return nil, fmt.Errorf("no mock found for command: %s %v", command, args)
+}
+
+// GetCallLog returns the history of commands executed
+func (m *MockShellExecutor) GetCallLog() []MockCall {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.callLog
+}
+
+// argsContain checks if all elements of subset are in set
+func argsContain(set, subset []string) bool {
+ for _, sub := range subset {
+ found := false
+ for _, s := range set {
+ if strings.Contains(s, sub) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
+}
diff --git a/internal/commands/builder.go b/internal/commands/builder.go
new file mode 100644
index 0000000..a6bd8e6
--- /dev/null
+++ b/internal/commands/builder.go
@@ -0,0 +1,747 @@
+package commands
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/kagent-dev/tools/internal/cache"
+ "github.com/kagent-dev/tools/internal/cmd"
+ "github.com/kagent-dev/tools/internal/errors"
+ "github.com/kagent-dev/tools/internal/logger"
+ "github.com/kagent-dev/tools/internal/security"
+ "github.com/kagent-dev/tools/internal/telemetry"
+ "go.opentelemetry.io/otel/attribute"
+)
+
+// CommandBuilder provides a fluent interface for building CLI commands
+type CommandBuilder struct {
+ command string
+ args []string
+ namespace string
+ context string
+ kubeconfig string
+ output string
+ labels map[string]string
+ annotations map[string]string
+ timeout time.Duration
+ dryRun bool
+ force bool
+ wait bool
+ validate bool
+ cached bool
+ cacheTTL time.Duration
+ cacheKey string
+}
+
+// NewCommandBuilder creates a new command builder
+func NewCommandBuilder(command string) *CommandBuilder {
+ return &CommandBuilder{
+ command: command,
+ args: make([]string, 0),
+ labels: make(map[string]string),
+ annotations: make(map[string]string),
+ timeout: 30 * time.Second,
+ validate: true,
+ cacheTTL: 5 * time.Minute,
+ }
+}
+
+// KubectlBuilder creates a kubectl command builder
+func KubectlBuilder() *CommandBuilder {
+ return NewCommandBuilder("kubectl")
+}
+
+// HelmBuilder creates a helm command builder
+func HelmBuilder() *CommandBuilder {
+ return NewCommandBuilder("helm")
+}
+
+// IstioCtlBuilder creates an istioctl command builder
+func IstioCtlBuilder() *CommandBuilder {
+ return NewCommandBuilder("istioctl")
+}
+
+// CiliumBuilder creates a cilium command builder
+func CiliumBuilder() *CommandBuilder {
+ return NewCommandBuilder("cilium")
+}
+
+// ArgoRolloutsBuilder creates an argo rollouts command builder
+func ArgoRolloutsBuilder() *CommandBuilder {
+ return NewCommandBuilder("kubectl").WithArgs("argo", "rollouts")
+}
+
+// WithArgs adds arguments to the command
+func (cb *CommandBuilder) WithArgs(args ...string) *CommandBuilder {
+ cb.args = append(cb.args, args...)
+ return cb
+}
+
+// WithNamespace sets the namespace
+func (cb *CommandBuilder) WithNamespace(namespace string) *CommandBuilder {
+ if err := security.ValidateNamespace(namespace); err != nil {
+ logger.Get().Error("Invalid namespace", "namespace", namespace, "error", err)
+ return cb
+ }
+ cb.namespace = namespace
+ return cb
+}
+
+// WithContext sets the Kubernetes context
+func (cb *CommandBuilder) WithContext(context string) *CommandBuilder {
+ if err := security.ValidateCommandInput(context); err != nil {
+ logger.Get().Error("Invalid context", "context", context, "error", err)
+ return cb
+ }
+ cb.context = context
+ return cb
+}
+
+// WithKubeconfig sets the kubeconfig file
+func (cb *CommandBuilder) WithKubeconfig(kubeconfig string) *CommandBuilder {
+ if err := security.ValidateFilePath(kubeconfig); err != nil {
+ logger.Get().Error("Invalid kubeconfig path", "kubeconfig", kubeconfig, "error", err)
+ return cb
+ }
+ cb.kubeconfig = kubeconfig
+ return cb
+}
+
+// WithOutput sets the output format
+func (cb *CommandBuilder) WithOutput(output string) *CommandBuilder {
+ validOutputs := []string{"json", "yaml", "wide", "name", "custom-columns", "custom-columns-file", "go-template", "go-template-file", "jsonpath", "jsonpath-file"}
+
+ valid := false
+ for _, validOutput := range validOutputs {
+ if output == validOutput {
+ valid = true
+ break
+ }
+ }
+
+ if !valid {
+ logger.Get().Error("Invalid output format", "output", output)
+ return cb
+ }
+
+ cb.output = output
+ return cb
+}
+
+// WithLabel adds a label selector
+func (cb *CommandBuilder) WithLabel(key, value string) *CommandBuilder {
+ if err := security.ValidateK8sLabel(key, value); err != nil {
+ logger.Get().Error("Invalid label", "key", key, "value", value, "error", err)
+ return cb
+ }
+ cb.labels[key] = value
+ return cb
+}
+
+// WithLabels adds multiple label selectors
+func (cb *CommandBuilder) WithLabels(labels map[string]string) *CommandBuilder {
+ for key, value := range labels {
+ cb.WithLabel(key, value)
+ }
+ return cb
+}
+
+// WithAnnotation adds an annotation
+func (cb *CommandBuilder) WithAnnotation(key, value string) *CommandBuilder {
+ if err := security.ValidateK8sLabel(key, value); err != nil {
+ logger.Get().Error("Invalid annotation", "key", key, "value", value, "error", err)
+ return cb
+ }
+ cb.annotations[key] = value
+ return cb
+}
+
+// WithTimeout sets the command timeout
+func (cb *CommandBuilder) WithTimeout(timeout time.Duration) *CommandBuilder {
+ cb.timeout = timeout
+ return cb
+}
+
+// WithDryRun enables dry run mode
+func (cb *CommandBuilder) WithDryRun(dryRun bool) *CommandBuilder {
+ cb.dryRun = dryRun
+ return cb
+}
+
+// WithForce enables force mode
+func (cb *CommandBuilder) WithForce(force bool) *CommandBuilder {
+ cb.force = force
+ return cb
+}
+
+// WithWait enables wait mode
+func (cb *CommandBuilder) WithWait(wait bool) *CommandBuilder {
+ cb.wait = wait
+ return cb
+}
+
+// WithValidation enables/disables validation
+func (cb *CommandBuilder) WithValidation(validate bool) *CommandBuilder {
+ cb.validate = validate
+ return cb
+}
+
+// WithCache enables caching of the command result
+func (cb *CommandBuilder) WithCache(cached bool) *CommandBuilder {
+ cb.cached = cached
+ return cb
+}
+
+// WithCacheTTL sets the cache TTL
+func (cb *CommandBuilder) WithCacheTTL(ttl time.Duration) *CommandBuilder {
+ cb.cacheTTL = ttl
+ return cb
+}
+
+// WithCacheKey sets a custom cache key
+func (cb *CommandBuilder) WithCacheKey(key string) *CommandBuilder {
+ cb.cacheKey = key
+ return cb
+}
+
+// Build constructs the final command arguments
+func (cb *CommandBuilder) Build() (string, []string, error) {
+ args := make([]string, 0, len(cb.args)+20)
+
+ // Add main arguments
+ args = append(args, cb.args...)
+
+ // Add namespace if specified
+ if cb.namespace != "" {
+ args = append(args, "--namespace", cb.namespace)
+ }
+
+ // Add context if specified
+ if cb.context != "" {
+ args = append(args, "--context", cb.context)
+ }
+
+ // Add kubeconfig if specified
+ if cb.kubeconfig != "" {
+ args = append(args, "--kubeconfig", cb.kubeconfig)
+ }
+
+ // Add output format
+ if cb.output != "" {
+ args = append(args, "--output", cb.output)
+ }
+
+ // Add label selectors
+ if len(cb.labels) > 0 {
+ var labelSelectors []string
+ for key, value := range cb.labels {
+ if value != "" {
+ labelSelectors = append(labelSelectors, fmt.Sprintf("%s=%s", key, value))
+ } else {
+ labelSelectors = append(labelSelectors, key)
+ }
+ }
+ if len(labelSelectors) > 0 {
+ args = append(args, "--selector", strings.Join(labelSelectors, ","))
+ }
+ }
+
+ // Add timeout only for commands that support it
+ if cb.timeout > 0 {
+ if cb.supportsTimeout() {
+ args = append(args, "--timeout", cb.timeout.String())
+ }
+ }
+
+ // Add dry run
+ if cb.dryRun {
+ args = append(args, "--dry-run=client")
+ }
+
+ // Add force
+ if cb.force {
+ args = append(args, "--force")
+ }
+
+ // Add wait
+ if cb.wait {
+ args = append(args, "--wait")
+ }
+
+ // Add validation
+ if !cb.validate {
+ args = append(args, "--validate=false")
+ }
+
+ return cb.command, args, nil
+}
+
+// supportsTimeout checks if the command supports the --timeout flag
+func (cb *CommandBuilder) supportsTimeout() bool {
+ // For kubectl, many commands support --timeout
+ if cb.command == "kubectl" {
+ if len(cb.args) == 0 {
+ return false
+ }
+
+ // Check the first argument (subcommand)
+ subcommand := cb.args[0]
+ switch subcommand {
+ case "wait":
+ return true
+ case "delete":
+ // kubectl delete supports --timeout when waiting for deletion
+ return true
+ case "rollout":
+ // kubectl rollout status supports --timeout
+ if len(cb.args) > 1 && cb.args[1] == "status" {
+ return true
+ }
+ return false
+ case "apply":
+ // kubectl apply supports --timeout when used with --wait
+ return cb.wait
+ case "annotate", "label":
+ // kubectl annotate and label support --timeout
+ return true
+ case "create":
+ // kubectl create supports --timeout
+ return true
+ case "argo":
+ // kubectl argo rollouts commands support --timeout
+ if len(cb.args) > 1 && cb.args[1] == "rollouts" {
+ return true
+ }
+ return false
+ case "get":
+ // kubectl get supports --timeout for some operations
+ return false // Most get operations don't need timeout, they're read-only
+ default:
+ return false
+ }
+ }
+
+ // For other commands (helm, istioctl, cilium), assume they support timeout
+ // unless we find specific cases where they don't
+ return true
+}
+
+// Execute runs the command
+func (cb *CommandBuilder) Execute(ctx context.Context) (string, error) {
+ log := logger.WithContext(ctx)
+ _, span := telemetry.StartSpan(ctx, "commands.execute",
+ attribute.String("command", cb.command),
+ attribute.StringSlice("args", cb.args),
+ attribute.Bool("cached", cb.cached),
+ )
+ defer span.End()
+
+ command, args, err := cb.Build()
+ if err != nil {
+ telemetry.RecordError(span, err, "Command build failed")
+ log.Error("failed to build command",
+ "command", cb.command,
+ "error", err,
+ )
+ return "", err
+ }
+
+ span.SetAttributes(
+ attribute.String("built_command", command),
+ attribute.StringSlice("built_args", args),
+ )
+
+ log.Debug("executing command",
+ "command", command,
+ "args", args,
+ "cached", cb.cached,
+ )
+
+ // Generate cache key if caching is enabled
+ if cb.cached {
+ telemetry.AddEvent(span, "execution.cached")
+ return cb.executeWithCache(ctx, command, args)
+ }
+
+ // Execute the command
+ telemetry.AddEvent(span, "execution.direct")
+ result, err := cb.executeCommand(ctx, command, args)
+ if err != nil {
+ telemetry.RecordError(span, err, "Command execution failed")
+ return "", err
+ }
+
+ telemetry.RecordSuccess(span, "Command executed successfully")
+ span.SetAttributes(
+ attribute.Int("result_length", len(result)),
+ )
+
+ return result, nil
+}
+
+func (cb *CommandBuilder) executeWithCache(ctx context.Context, command string, args []string) (string, error) {
+ log := logger.WithContext(ctx)
+ _, span := telemetry.StartSpan(ctx, "commands.executeWithCache",
+ attribute.String("command", command),
+ attribute.StringSlice("args", args),
+ attribute.Bool("cached", true),
+ )
+ defer span.End()
+
+ cacheKey := cb.cacheKey
+ if cacheKey == "" {
+ cacheKey = cache.CacheKey(append([]string{command}, args...)...)
+ }
+
+ log.Info("executing cached command",
+ "command", command,
+ "args", args,
+ "cache_key", cacheKey,
+ "cache_ttl", cb.cacheTTL.String(),
+ )
+
+ // Try to get from cache first
+ cacheInstance := cache.GetCacheByCommand(command)
+
+ telemetry.AddEvent(span, "cache.lookup",
+ attribute.String("cache_key", cacheKey),
+ attribute.String("cache_ttl", cb.cacheTTL.String()),
+ )
+
+ result, err := cache.CacheResult(cacheInstance, cacheKey, cb.cacheTTL, func() (string, error) {
+ telemetry.AddEvent(span, "cache.miss.executing_command")
+ log.Debug("cache miss, executing command",
+ "command", command,
+ "args", args,
+ )
+ return cb.executeCommand(ctx, command, args)
+ })
+
+ if err != nil {
+ telemetry.RecordError(span, err, "Cached command execution failed")
+ log.Error("cached command execution failed",
+ "command", command,
+ "args", args,
+ "cache_key", cacheKey,
+ "error", err,
+ )
+ return "", err
+ }
+
+ telemetry.RecordSuccess(span, "Cached command executed successfully")
+ log.Info("cached command execution successful",
+ "command", command,
+ "args", args,
+ "cache_key", cacheKey,
+ "result_length", len(result),
+ )
+
+ span.SetAttributes(
+ attribute.String("cache_key", cacheKey),
+ attribute.Int("result_length", len(result)),
+ )
+
+ return result, nil
+}
+
+// executeCommand executes the actual command
+func (cb *CommandBuilder) executeCommand(ctx context.Context, command string, args []string) (string, error) {
+ executor := cmd.GetShellExecutor(ctx)
+ output, err := executor.Exec(ctx, command, args...)
+ if err != nil {
+ // Create appropriate error based on command type
+ var toolError *errors.ToolError
+ switch command {
+ case "kubectl":
+ toolError = errors.NewKubernetesError(strings.Join(args, " "), err)
+ case "helm":
+ toolError = errors.NewHelmError(strings.Join(args, " "), err)
+ case "istioctl":
+ toolError = errors.NewIstioError(strings.Join(args, " "), err)
+ case "cilium":
+ toolError = errors.NewCiliumError(strings.Join(args, " "), err)
+ default:
+ toolError = errors.NewCommandError(command, err)
+ }
+
+ return "", toolError
+ }
+
+ return string(output), nil
+}
+
+// Common command patterns as helper functions
+
+// GetPods creates a command to get pods
+func GetPods(namespace string, labels map[string]string) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("get", "pods")
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if len(labels) > 0 {
+ builder = builder.WithLabels(labels)
+ }
+
+ return builder.WithCache(true).WithOutput("json")
+}
+
+// GetServices creates a command to get services
+func GetServices(namespace string, labels map[string]string) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("get", "services")
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if len(labels) > 0 {
+ builder = builder.WithLabels(labels)
+ }
+
+ return builder.WithCache(true).WithOutput("json")
+}
+
+// GetDeployments creates a command to get deployments
+func GetDeployments(namespace string, labels map[string]string) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("get", "deployments")
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if len(labels) > 0 {
+ builder = builder.WithLabels(labels)
+ }
+
+ return builder.WithCache(true).WithOutput("json")
+}
+
+// DescribeResource creates a command to describe a resource
+func DescribeResource(resourceType, resourceName, namespace string) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("describe", resourceType, resourceName)
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ return builder.WithCache(true).WithCacheTTL(2 * time.Minute)
+}
+
+// GetLogs creates a command to get logs
+func GetLogs(podName, namespace string, options LogOptions) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("logs", podName)
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if options.Container != "" {
+ builder = builder.WithArgs("--container", options.Container)
+ }
+
+ if options.Follow {
+ builder = builder.WithArgs("--follow")
+ }
+
+ if options.Previous {
+ builder = builder.WithArgs("--previous")
+ }
+
+ if options.Timestamps {
+ builder = builder.WithArgs("--timestamps")
+ }
+
+ if options.TailLines > 0 {
+ builder = builder.WithArgs("--tail", fmt.Sprintf("%d", options.TailLines))
+ }
+
+ if options.SinceTime != "" {
+ builder = builder.WithArgs("--since-time", options.SinceTime)
+ }
+
+ if options.SinceDuration != "" {
+ builder = builder.WithArgs("--since", options.SinceDuration)
+ }
+
+ // Don't cache logs by default as they change frequently
+ return builder.WithCache(false)
+}
+
+// LogOptions represents options for log commands
+type LogOptions struct {
+ Container string
+ Follow bool
+ Previous bool
+ Timestamps bool
+ TailLines int
+ SinceTime string
+ SinceDuration string
+}
+
+// ApplyResource creates a command to apply a resource
+func ApplyResource(filename string, namespace string, options ApplyOptions) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("apply", "-f", filename)
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if options.DryRun {
+ builder = builder.WithDryRun(true)
+ }
+
+ if options.Force {
+ builder = builder.WithForce(true)
+ }
+
+ if options.Wait {
+ builder = builder.WithWait(true)
+ }
+
+ if !options.Validate {
+ builder = builder.WithValidation(false)
+ }
+
+ return builder.WithCache(false) // Don't cache apply operations
+}
+
+// ApplyOptions represents options for apply commands
+type ApplyOptions struct {
+ DryRun bool
+ Force bool
+ Wait bool
+ Validate bool
+}
+
+// DeleteResource creates a command to delete a resource
+func DeleteResource(resourceType, resourceName, namespace string, options DeleteOptions) *CommandBuilder {
+ builder := KubectlBuilder().WithArgs("delete", resourceType, resourceName)
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if options.Force {
+ builder = builder.WithForce(true)
+ }
+
+ if options.GracePeriod >= 0 {
+ builder = builder.WithArgs("--grace-period", fmt.Sprintf("%d", options.GracePeriod))
+ }
+
+ if options.Wait {
+ builder = builder.WithWait(true)
+ }
+
+ return builder.WithCache(false) // Don't cache delete operations
+}
+
+// DeleteOptions represents options for delete commands
+type DeleteOptions struct {
+ Force bool
+ GracePeriod int
+ Wait bool
+}
+
+// HelmInstall creates a command to install a Helm chart
+func HelmInstall(releaseName, chart, namespace string, options HelmInstallOptions) *CommandBuilder {
+ builder := HelmBuilder().WithArgs("install", releaseName, chart)
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if options.CreateNamespace {
+ builder = builder.WithArgs("--create-namespace")
+ }
+
+ if options.DryRun {
+ builder = builder.WithDryRun(true)
+ }
+
+ if options.Wait {
+ builder = builder.WithWait(true)
+ }
+
+ if options.ValuesFile != "" {
+ builder = builder.WithArgs("--values", options.ValuesFile)
+ }
+
+ for key, value := range options.SetValues {
+ builder = builder.WithArgs("--set", fmt.Sprintf("%s=%s", key, value))
+ }
+
+ return builder.WithCache(false) // Don't cache install operations
+}
+
+// HelmInstallOptions represents options for Helm install commands
+type HelmInstallOptions struct {
+ CreateNamespace bool
+ DryRun bool
+ Wait bool
+ ValuesFile string
+ SetValues map[string]string
+}
+
+// HelmList creates a command to list Helm releases
+func HelmList(namespace string, options HelmListOptions) *CommandBuilder {
+ builder := HelmBuilder().WithArgs("list")
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if options.AllNamespaces {
+ builder = builder.WithArgs("--all-namespaces")
+ }
+
+ if options.Output != "" {
+ builder = builder.WithOutput(options.Output)
+ }
+
+ return builder.WithCache(true).WithCacheTTL(2 * time.Minute)
+}
+
+// HelmListOptions represents options for Helm list commands
+type HelmListOptions struct {
+ AllNamespaces bool
+ Output string
+}
+
+// IstioProxyStatus creates a command to get Istio proxy status
+func IstioProxyStatus(podName, namespace string) *CommandBuilder {
+ builder := IstioCtlBuilder().WithArgs("proxy-status")
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ if podName != "" {
+ builder = builder.WithArgs(podName)
+ }
+
+ return builder.WithCache(true).WithCacheTTL(30 * time.Second)
+}
+
+// CiliumStatus creates a command to get Cilium status
+func CiliumStatus() *CommandBuilder {
+ return CiliumBuilder().WithArgs("status").WithCache(true).WithCacheTTL(30 * time.Second)
+}
+
+// ArgoRolloutsGet creates a command to get Argo rollouts
+func ArgoRolloutsGet(rolloutName, namespace string) *CommandBuilder {
+ builder := ArgoRolloutsBuilder().WithArgs("get", "rollout")
+
+ if rolloutName != "" {
+ builder = builder.WithArgs(rolloutName)
+ }
+
+ if namespace != "" {
+ builder = builder.WithNamespace(namespace)
+ }
+
+ return builder.WithCache(true).WithCacheTTL(1 * time.Minute)
+}
diff --git a/internal/commands/builder_test.go b/internal/commands/builder_test.go
new file mode 100644
index 0000000..e326a76
--- /dev/null
+++ b/internal/commands/builder_test.go
@@ -0,0 +1,585 @@
+package commands
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewCommandBuilder(t *testing.T) {
+ cb := NewCommandBuilder("test-command")
+
+ assert.Equal(t, "test-command", cb.command)
+ assert.Empty(t, cb.args)
+ assert.Empty(t, cb.namespace)
+ assert.Empty(t, cb.context)
+ assert.Empty(t, cb.kubeconfig)
+ assert.Empty(t, cb.output)
+ assert.NotNil(t, cb.labels)
+ assert.NotNil(t, cb.annotations)
+ assert.Equal(t, 30*time.Second, cb.timeout)
+ assert.Equal(t, 5*time.Minute, cb.cacheTTL)
+ assert.True(t, cb.validate)
+ assert.False(t, cb.cached)
+ assert.False(t, cb.dryRun)
+ assert.False(t, cb.force)
+ assert.False(t, cb.wait)
+}
+
+func TestCommandBuilderFactories(t *testing.T) {
+ tests := []struct {
+ name string
+ factory func() *CommandBuilder
+ expected string
+ }{
+ {"kubectl", KubectlBuilder, "kubectl"},
+ {"helm", HelmBuilder, "helm"},
+ {"istioctl", IstioCtlBuilder, "istioctl"},
+ {"cilium", CiliumBuilder, "cilium"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cb := tt.factory()
+ assert.Equal(t, tt.expected, cb.command)
+ })
+ }
+}
+
+func TestArgoRolloutsBuilder(t *testing.T) {
+ cb := ArgoRolloutsBuilder()
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Equal(t, []string{"argo", "rollouts"}, cb.args)
+}
+
+func TestCommandBuilderWithArgs(t *testing.T) {
+ cb := NewCommandBuilder("test").WithArgs("arg1", "arg2")
+
+ assert.Equal(t, []string{"arg1", "arg2"}, cb.args)
+
+ // Test chaining
+ cb.WithArgs("arg3")
+ assert.Equal(t, []string{"arg1", "arg2", "arg3"}, cb.args)
+}
+
+func TestCommandBuilderWithNamespace(t *testing.T) {
+ cb := NewCommandBuilder("test").WithNamespace("default")
+
+ assert.Equal(t, "default", cb.namespace)
+
+ // Test invalid namespace - should not set the namespace
+ cb.WithNamespace("invalid..namespace")
+ assert.Equal(t, "default", cb.namespace) // Should remain unchanged
+}
+
+func TestCommandBuilderWithContext(t *testing.T) {
+ cb := NewCommandBuilder("test").WithContext("minikube")
+
+ assert.Equal(t, "minikube", cb.context)
+}
+
+func TestCommandBuilderWithKubeconfig(t *testing.T) {
+ cb := NewCommandBuilder("test").WithKubeconfig("/path/to/config")
+
+ assert.Equal(t, "/path/to/config", cb.kubeconfig)
+}
+
+func TestCommandBuilderWithOutput(t *testing.T) {
+ validOutputs := []string{"json", "yaml", "wide", "name"}
+
+ for _, output := range validOutputs {
+ cb := NewCommandBuilder("test").WithOutput(output)
+ assert.Equal(t, output, cb.output)
+ }
+
+ // Test invalid output
+ cb := NewCommandBuilder("test").WithOutput("invalid")
+ assert.Empty(t, cb.output)
+}
+
+func TestCommandBuilderWithLabel(t *testing.T) {
+ cb := NewCommandBuilder("test").WithLabel("app", "web")
+
+ assert.Equal(t, "web", cb.labels["app"])
+}
+
+func TestCommandBuilderWithLabels(t *testing.T) {
+ labels := map[string]string{
+ "app": "web",
+ "version": "v1.0.0",
+ }
+
+ cb := NewCommandBuilder("test").WithLabels(labels)
+
+ assert.Equal(t, labels["app"], cb.labels["app"])
+ assert.Equal(t, labels["version"], cb.labels["version"])
+}
+
+func TestCommandBuilderWithAnnotation(t *testing.T) {
+ cb := NewCommandBuilder("test").WithAnnotation("simple-key", "value")
+
+ // The annotation should be accepted if it's a valid format
+ assert.Equal(t, "value", cb.annotations["simple-key"])
+
+ // Test with invalid annotation - still gets added but logs an error
+ cb2 := NewCommandBuilder("test").WithAnnotation("invalid..key", "value")
+ assert.Equal(t, "value", cb2.annotations["invalid..key"]) // Invalid annotations are still added but logged
+}
+
+func TestCommandBuilderWithTimeout(t *testing.T) {
+ timeout := 60 * time.Second
+ cb := NewCommandBuilder("test").WithTimeout(timeout)
+
+ assert.Equal(t, timeout, cb.timeout)
+}
+
+func TestCommandBuilderWithFlags(t *testing.T) {
+ cb := NewCommandBuilder("test").
+ WithDryRun(true).
+ WithForce(true).
+ WithWait(true).
+ WithValidation(false)
+
+ assert.True(t, cb.dryRun)
+ assert.True(t, cb.force)
+ assert.True(t, cb.wait)
+ assert.False(t, cb.validate)
+}
+
+func TestCommandBuilderWithCache(t *testing.T) {
+ cb := NewCommandBuilder("test").WithCache(true)
+
+ assert.True(t, cb.cached)
+}
+
+func TestCommandBuilderWithCacheTTL(t *testing.T) {
+ ttl := 10 * time.Minute
+ cb := NewCommandBuilder("test").WithCacheTTL(ttl)
+
+ assert.Equal(t, ttl, cb.cacheTTL)
+}
+
+func TestCommandBuilderWithCacheKey(t *testing.T) {
+ cb := NewCommandBuilder("test").WithCacheKey("custom-key")
+
+ assert.Equal(t, "custom-key", cb.cacheKey)
+}
+
+func TestCommandBuilderBuild(t *testing.T) {
+ cb := NewCommandBuilder("kubectl").
+ WithArgs("get", "pods").
+ WithNamespace("default").
+ WithContext("minikube").
+ WithKubeconfig("/path/to/config").
+ WithOutput("json").
+ WithLabel("app", "web").
+ WithDryRun(true).
+ WithForce(true).
+ WithWait(true).
+ WithValidation(false)
+
+ command, args, err := cb.Build()
+ require.NoError(t, err)
+
+ assert.Equal(t, "kubectl", command)
+ assert.Contains(t, args, "get")
+ assert.Contains(t, args, "pods")
+ assert.Contains(t, args, "--namespace")
+ assert.Contains(t, args, "default")
+ assert.Contains(t, args, "--context")
+ assert.Contains(t, args, "minikube")
+ assert.Contains(t, args, "--kubeconfig")
+ assert.Contains(t, args, "/path/to/config")
+ assert.Contains(t, args, "--output")
+ assert.Contains(t, args, "json")
+ assert.Contains(t, args, "--selector")
+ assert.Contains(t, args, "app=web")
+ assert.Contains(t, args, "--dry-run=client")
+ assert.Contains(t, args, "--force")
+ assert.Contains(t, args, "--wait")
+ assert.Contains(t, args, "--validate=false")
+}
+
+func TestCommandBuilderBuildWithTimeout(t *testing.T) {
+ cb := NewCommandBuilder("kubectl").
+ WithArgs("delete", "pod", "test-pod").
+ WithTimeout(45 * time.Second)
+
+ command, args, err := cb.Build()
+ require.NoError(t, err)
+
+ assert.Equal(t, "kubectl", command)
+ assert.Contains(t, args, "--timeout")
+ assert.Contains(t, args, "45s")
+}
+
+func TestCommandBuilderBuildWithMultipleLabels(t *testing.T) {
+ cb := NewCommandBuilder("kubectl").
+ WithArgs("get", "pods").
+ WithLabel("app", "web").
+ WithLabel("version", "v1.0.0")
+
+ command, args, err := cb.Build()
+ require.NoError(t, err)
+
+ assert.Equal(t, "kubectl", command)
+ assert.Contains(t, args, "--selector")
+
+ // Find the selector argument
+ var selectorValue string
+ for i, arg := range args {
+ if arg == "--selector" && i+1 < len(args) {
+ selectorValue = args[i+1]
+ break
+ }
+ }
+
+ assert.Contains(t, selectorValue, "app=web")
+ assert.Contains(t, selectorValue, "version=v1.0.0")
+}
+
+func TestGetPods(t *testing.T) {
+ namespace := "default"
+ labels := map[string]string{"app": "web"}
+
+ cb := GetPods(namespace, labels)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "get")
+ assert.Contains(t, cb.args, "pods")
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Equal(t, labels, cb.labels)
+ assert.True(t, cb.cached)
+ assert.Equal(t, "json", cb.output)
+}
+
+func TestGetServices(t *testing.T) {
+ namespace := "default"
+ labels := map[string]string{"app": "web"}
+
+ cb := GetServices(namespace, labels)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "get")
+ assert.Contains(t, cb.args, "services")
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Equal(t, labels, cb.labels)
+ assert.True(t, cb.cached)
+ assert.Equal(t, "json", cb.output)
+}
+
+func TestGetDeployments(t *testing.T) {
+ namespace := "default"
+ labels := map[string]string{"app": "web"}
+
+ cb := GetDeployments(namespace, labels)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "get")
+ assert.Contains(t, cb.args, "deployments")
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Equal(t, labels, cb.labels)
+ assert.True(t, cb.cached)
+ assert.Equal(t, "json", cb.output)
+}
+
+func TestDescribeResource(t *testing.T) {
+ resourceType := "pod"
+ resourceName := "test-pod"
+ namespace := "default"
+
+ cb := DescribeResource(resourceType, resourceName, namespace)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "describe")
+ assert.Contains(t, cb.args, resourceType)
+ assert.Contains(t, cb.args, resourceName)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.True(t, cb.cached)
+ assert.Equal(t, 2*time.Minute, cb.cacheTTL)
+}
+
+func TestGetLogs(t *testing.T) {
+ podName := "test-pod"
+ namespace := "default"
+ options := LogOptions{
+ Container: "app",
+ Follow: true,
+ Previous: false,
+ Timestamps: true,
+ TailLines: 100,
+ SinceTime: "2023-01-01T00:00:00Z",
+ SinceDuration: "1h",
+ }
+
+ cb := GetLogs(podName, namespace, options)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "logs")
+ assert.Contains(t, cb.args, podName)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Contains(t, cb.args, "--container")
+ assert.Contains(t, cb.args, "app")
+ assert.Contains(t, cb.args, "--follow")
+ assert.Contains(t, cb.args, "--timestamps")
+ assert.Contains(t, cb.args, "--tail")
+ assert.Contains(t, cb.args, "100")
+ assert.Contains(t, cb.args, "--since-time")
+ assert.Contains(t, cb.args, "2023-01-01T00:00:00Z")
+ assert.Contains(t, cb.args, "--since")
+ assert.Contains(t, cb.args, "1h")
+ assert.False(t, cb.cached)
+}
+
+func TestGetLogsWithPrevious(t *testing.T) {
+ podName := "test-pod"
+ namespace := "default"
+ options := LogOptions{
+ Previous: true,
+ }
+
+ cb := GetLogs(podName, namespace, options)
+
+ assert.Contains(t, cb.args, "--previous")
+}
+
+func TestApplyResource(t *testing.T) {
+ filename := "/path/to/resource.yaml"
+ namespace := "default"
+ options := ApplyOptions{
+ DryRun: true,
+ Force: true,
+ Wait: true,
+ Validate: false,
+ }
+
+ cb := ApplyResource(filename, namespace, options)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "apply")
+ assert.Contains(t, cb.args, "-f")
+ assert.Contains(t, cb.args, filename)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.True(t, cb.dryRun)
+ assert.True(t, cb.force)
+ assert.True(t, cb.wait)
+ assert.False(t, cb.validate)
+ assert.False(t, cb.cached)
+}
+
+func TestDeleteResource(t *testing.T) {
+ resourceType := "pod"
+ resourceName := "test-pod"
+ namespace := "default"
+ options := DeleteOptions{
+ Force: true,
+ GracePeriod: 30,
+ Wait: true,
+ }
+
+ cb := DeleteResource(resourceType, resourceName, namespace, options)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "delete")
+ assert.Contains(t, cb.args, resourceType)
+ assert.Contains(t, cb.args, resourceName)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.True(t, cb.force)
+ assert.True(t, cb.wait)
+ assert.False(t, cb.cached)
+}
+
+func TestHelmInstall(t *testing.T) {
+ releaseName := "test-release"
+ chart := "bitnami/nginx"
+ namespace := "default"
+ options := HelmInstallOptions{
+ CreateNamespace: true,
+ DryRun: true,
+ Wait: true,
+ ValuesFile: "/path/to/values.yaml",
+ SetValues: map[string]string{"image.tag": "1.20"},
+ }
+
+ cb := HelmInstall(releaseName, chart, namespace, options)
+
+ assert.Equal(t, "helm", cb.command)
+ assert.Contains(t, cb.args, "install")
+ assert.Contains(t, cb.args, releaseName)
+ assert.Contains(t, cb.args, chart)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.True(t, cb.dryRun)
+ assert.True(t, cb.wait)
+ assert.False(t, cb.cached)
+}
+
+func TestHelmList(t *testing.T) {
+ namespace := "default"
+ options := HelmListOptions{
+ AllNamespaces: true,
+ Output: "json",
+ }
+
+ cb := HelmList(namespace, options)
+
+ assert.Equal(t, "helm", cb.command)
+ assert.Contains(t, cb.args, "list")
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Equal(t, "json", cb.output)
+ assert.True(t, cb.cached)
+}
+
+func TestIstioProxyStatus(t *testing.T) {
+ podName := "test-pod"
+ namespace := "default"
+
+ cb := IstioProxyStatus(podName, namespace)
+
+ assert.Equal(t, "istioctl", cb.command)
+ assert.Contains(t, cb.args, "proxy-status")
+ assert.Contains(t, cb.args, podName)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.True(t, cb.cached)
+}
+
+func TestCiliumStatus(t *testing.T) {
+ cb := CiliumStatus()
+
+ assert.Equal(t, "cilium", cb.command)
+ assert.Contains(t, cb.args, "status")
+ assert.Empty(t, cb.output) // CiliumStatus doesn't set output format
+ assert.True(t, cb.cached)
+}
+
+func TestArgoRolloutsGet(t *testing.T) {
+ rolloutName := "test-rollout"
+ namespace := "default"
+
+ cb := ArgoRolloutsGet(rolloutName, namespace)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Contains(t, cb.args, "argo")
+ assert.Contains(t, cb.args, "rollouts")
+ assert.Contains(t, cb.args, "get")
+ assert.Contains(t, cb.args, "rollout")
+ assert.Contains(t, cb.args, rolloutName)
+ assert.Equal(t, namespace, cb.namespace)
+ assert.Empty(t, cb.output) // ArgoRolloutsGet doesn't set output format
+ assert.True(t, cb.cached)
+}
+
+func TestCommandBuilderChaining(t *testing.T) {
+ cb := NewCommandBuilder("kubectl").
+ WithArgs("get", "pods").
+ WithNamespace("default").
+ WithOutput("json").
+ WithLabel("app", "web").
+ WithTimeout(60 * time.Second).
+ WithCache(true).
+ WithCacheTTL(10 * time.Minute)
+
+ assert.Equal(t, "kubectl", cb.command)
+ assert.Equal(t, []string{"get", "pods"}, cb.args)
+ assert.Equal(t, "default", cb.namespace)
+ assert.Equal(t, "json", cb.output)
+ assert.Equal(t, "web", cb.labels["app"])
+ assert.Equal(t, 60*time.Second, cb.timeout)
+ assert.True(t, cb.cached)
+ assert.Equal(t, 10*time.Minute, cb.cacheTTL)
+}
+
+func TestCommandBuilderEmptyNamespace(t *testing.T) {
+ cb := GetPods("", nil)
+
+ assert.Empty(t, cb.namespace)
+}
+
+func TestCommandBuilderEmptyLabels(t *testing.T) {
+ cb := GetPods("default", nil)
+
+ assert.Empty(t, cb.labels)
+}
+
+func TestLogOptionsDefaults(t *testing.T) {
+ options := LogOptions{}
+
+ assert.False(t, options.Follow)
+ assert.False(t, options.Previous)
+ assert.False(t, options.Timestamps)
+ assert.Equal(t, 0, options.TailLines)
+ assert.Empty(t, options.SinceTime)
+ assert.Empty(t, options.SinceDuration)
+}
+
+func TestApplyOptionsDefaults(t *testing.T) {
+ options := ApplyOptions{}
+
+ assert.False(t, options.DryRun)
+ assert.False(t, options.Force)
+ assert.False(t, options.Wait)
+ assert.False(t, options.Validate)
+}
+
+func TestDeleteOptionsDefaults(t *testing.T) {
+ options := DeleteOptions{}
+
+ assert.False(t, options.Force)
+ assert.Equal(t, 0, options.GracePeriod)
+ assert.False(t, options.Wait)
+}
+
+func TestHelmInstallOptionsDefaults(t *testing.T) {
+ options := HelmInstallOptions{}
+
+ assert.False(t, options.CreateNamespace)
+ assert.False(t, options.DryRun)
+ assert.False(t, options.Wait)
+ assert.Empty(t, options.ValuesFile)
+ assert.Nil(t, options.SetValues)
+}
+
+func TestHelmListOptionsDefaults(t *testing.T) {
+ options := HelmListOptions{}
+
+ assert.False(t, options.AllNamespaces)
+ assert.Empty(t, options.Output)
+}
+
+// Mock tests for Execute method - these would need a mock for utils.RunCommandWithContext
+func TestCommandBuilderExecuteWithoutCache(t *testing.T) {
+ cb := NewCommandBuilder("echo").
+ WithArgs("hello", "world").
+ WithCache(false)
+
+ // This test would need mocking to work properly
+ // For now, we'll just verify the command building part
+ command, args, err := cb.Build()
+ require.NoError(t, err)
+
+ assert.Equal(t, "echo", command)
+ assert.Contains(t, args, "hello")
+ assert.Contains(t, args, "world")
+ assert.Contains(t, args, "--timeout")
+ assert.Contains(t, args, "30s")
+}
+
+func TestCommandBuilderExecuteWithCache(t *testing.T) {
+ cb := NewCommandBuilder("echo").
+ WithArgs("hello", "world").
+ WithCache(true)
+
+ // This test would need mocking to work properly
+ // For now, we'll just verify the command building part
+ command, args, err := cb.Build()
+ require.NoError(t, err)
+
+ assert.Equal(t, "echo", command)
+ assert.Contains(t, args, "hello")
+ assert.Contains(t, args, "world")
+ assert.Contains(t, args, "--timeout")
+ assert.Contains(t, args, "30s")
+ assert.True(t, cb.cached)
+}
diff --git a/internal/errors/tool_errors.go b/internal/errors/tool_errors.go
new file mode 100644
index 0000000..2677164
--- /dev/null
+++ b/internal/errors/tool_errors.go
@@ -0,0 +1,352 @@
+package errors
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// ToolError represents a structured error with context and recovery suggestions
+type ToolError struct {
+ Operation string `json:"operation"`
+ Cause error `json:"cause"`
+ Suggestions []string `json:"suggestions"`
+ IsRetryable bool `json:"is_retryable"`
+ Timestamp time.Time `json:"timestamp"`
+ ErrorCode string `json:"error_code"`
+ Component string `json:"component"`
+ ResourceType string `json:"resource_type,omitempty"`
+ ResourceName string `json:"resource_name,omitempty"`
+ Context map[string]interface{} `json:"context,omitempty"`
+}
+
+// Error implements the error interface
+func (e *ToolError) Error() string {
+ return fmt.Sprintf("[%s] %s failed: %v", e.Component, e.Operation, e.Cause)
+}
+
+// ToMCPResult converts the error to an MCP result with rich context
+func (e *ToolError) ToMCPResult() *mcp.CallToolResult {
+ var message strings.Builder
+
+ // Format the error message with context
+ message.WriteString(fmt.Sprintf("❌ **%s Error**\n\n", e.Component))
+ message.WriteString(fmt.Sprintf("**Operation**: %s\n", e.Operation))
+ message.WriteString(fmt.Sprintf("**Error**: %s\n", e.Cause.Error()))
+
+ if e.ResourceType != "" {
+ message.WriteString(fmt.Sprintf("**Resource Type**: %s\n", e.ResourceType))
+ }
+
+ if e.ResourceName != "" {
+ message.WriteString(fmt.Sprintf("**Resource Name**: %s\n", e.ResourceName))
+ }
+
+ message.WriteString(fmt.Sprintf("**Error Code**: %s\n", e.ErrorCode))
+ message.WriteString(fmt.Sprintf("**Timestamp**: %s\n", e.Timestamp.Format(time.RFC3339)))
+
+ if e.IsRetryable {
+ message.WriteString("**Retryable**: Yes\n")
+ } else {
+ message.WriteString("**Retryable**: No\n")
+ }
+
+ if len(e.Suggestions) > 0 {
+ message.WriteString("\n**💡 Suggestions**:\n")
+ for i, suggestion := range e.Suggestions {
+ message.WriteString(fmt.Sprintf("%d. %s\n", i+1, suggestion))
+ }
+ }
+
+ if len(e.Context) > 0 {
+ message.WriteString("\n**📋 Context**:\n")
+ for key, value := range e.Context {
+ message.WriteString(fmt.Sprintf("- %s: %v\n", key, value))
+ }
+ }
+
+ return mcp.NewToolResultError(message.String())
+}
+
+// NewToolError creates a new structured tool error
+func NewToolError(component, operation string, cause error) *ToolError {
+ return &ToolError{
+ Operation: operation,
+ Cause: cause,
+ Suggestions: []string{},
+ IsRetryable: false,
+ Timestamp: time.Now(),
+ ErrorCode: "UNKNOWN",
+ Component: component,
+ Context: make(map[string]interface{}),
+ }
+}
+
+// WithSuggestions adds recovery suggestions to the error
+func (e *ToolError) WithSuggestions(suggestions ...string) *ToolError {
+ e.Suggestions = append(e.Suggestions, suggestions...)
+ return e
+}
+
+// WithRetryable sets whether the error is retryable
+func (e *ToolError) WithRetryable(retryable bool) *ToolError {
+ e.IsRetryable = retryable
+ return e
+}
+
+// WithErrorCode sets the error code
+func (e *ToolError) WithErrorCode(code string) *ToolError {
+ e.ErrorCode = code
+ return e
+}
+
+// WithResource adds resource information to the error
+func (e *ToolError) WithResource(resourceType, resourceName string) *ToolError {
+ e.ResourceType = resourceType
+ e.ResourceName = resourceName
+ return e
+}
+
+// WithContext adds contextual information to the error
+func (e *ToolError) WithContext(key string, value interface{}) *ToolError {
+ e.Context[key] = value
+ return e
+}
+
+// Common error creators for different components
+
+// NewKubernetesError creates a Kubernetes-specific error
+func NewKubernetesError(operation string, cause error) *ToolError {
+ err := NewToolError("Kubernetes", operation, cause)
+
+ // Add Kubernetes-specific suggestions based on common errors
+ if strings.Contains(cause.Error(), "connection refused") {
+ err = err.WithSuggestions(
+ "Check if the Kubernetes cluster is running",
+ "Verify your kubeconfig is correct",
+ "Ensure network connectivity to the cluster",
+ ).WithRetryable(true).WithErrorCode("K8S_CONNECTION_ERROR")
+ } else if strings.Contains(cause.Error(), "forbidden") {
+ err = err.WithSuggestions(
+ "Check your RBAC permissions",
+ "Verify your service account has the required permissions",
+ "Contact your cluster administrator",
+ ).WithRetryable(false).WithErrorCode("K8S_PERMISSION_ERROR")
+ } else if strings.Contains(cause.Error(), "not found") {
+ err = err.WithSuggestions(
+ "Check if the resource exists",
+ "Verify the resource name and namespace",
+ "List available resources to confirm",
+ ).WithRetryable(false).WithErrorCode("K8S_RESOURCE_NOT_FOUND")
+ } else if strings.Contains(cause.Error(), "already exists") {
+ err = err.WithSuggestions(
+ "Use a different name for the resource",
+ "Delete the existing resource first",
+ "Use 'kubectl apply' instead of 'kubectl create'",
+ ).WithRetryable(false).WithErrorCode("K8S_RESOURCE_EXISTS")
+ } else {
+ err = err.WithSuggestions(
+ "Check the kubectl command syntax",
+ "Verify your kubeconfig is valid",
+ "Check cluster connectivity",
+ ).WithRetryable(true).WithErrorCode("K8S_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewHelmError creates a Helm-specific error
+func NewHelmError(operation string, cause error) *ToolError {
+ err := NewToolError("Helm", operation, cause)
+
+ if strings.Contains(cause.Error(), "not found") {
+ err = err.WithSuggestions(
+ "Check if the Helm release exists",
+ "Verify the release name and namespace",
+ "Use 'helm list' to see available releases",
+ ).WithRetryable(false).WithErrorCode("HELM_RELEASE_NOT_FOUND")
+ } else if strings.Contains(cause.Error(), "already exists") {
+ err = err.WithSuggestions(
+ "Use a different release name",
+ "Upgrade the existing release instead",
+ "Uninstall the existing release first",
+ ).WithRetryable(false).WithErrorCode("HELM_RELEASE_EXISTS")
+ } else if strings.Contains(cause.Error(), "repository") {
+ err = err.WithSuggestions(
+ "Add the required Helm repository",
+ "Update your Helm repositories",
+ "Check repository URL and credentials",
+ ).WithRetryable(true).WithErrorCode("HELM_REPOSITORY_ERROR")
+ } else {
+ err = err.WithSuggestions(
+ "Check the Helm command syntax",
+ "Verify your kubeconfig is valid",
+ "Ensure Helm is properly installed",
+ ).WithRetryable(true).WithErrorCode("HELM_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewIstioError creates an Istio-specific error
+func NewIstioError(operation string, cause error) *ToolError {
+ err := NewToolError("Istio", operation, cause)
+
+ if strings.Contains(cause.Error(), "not found") {
+ err = err.WithSuggestions(
+ "Check if Istio is installed in the cluster",
+ "Verify the pod/service name and namespace",
+ "Ensure Istio sidecar is injected",
+ ).WithRetryable(false).WithErrorCode("ISTIO_RESOURCE_NOT_FOUND")
+ } else if strings.Contains(cause.Error(), "connection refused") {
+ err = err.WithSuggestions(
+ "Check if Istio control plane is running",
+ "Verify Istio proxy is healthy",
+ "Check network policies",
+ ).WithRetryable(true).WithErrorCode("ISTIO_CONNECTION_ERROR")
+ } else {
+ err = err.WithSuggestions(
+ "Check istioctl command syntax",
+ "Verify Istio installation",
+ "Check Istio proxy status",
+ ).WithRetryable(true).WithErrorCode("ISTIO_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewPrometheusError creates a Prometheus-specific error
+func NewPrometheusError(operation string, cause error) *ToolError {
+ err := NewToolError("Prometheus", operation, cause)
+
+ if strings.Contains(cause.Error(), "connection refused") {
+ err = err.WithSuggestions(
+ "Check if Prometheus server is running",
+ "Verify the Prometheus URL",
+ "Check network connectivity",
+ ).WithRetryable(true).WithErrorCode("PROMETHEUS_CONNECTION_ERROR")
+ } else if strings.Contains(cause.Error(), "parse error") {
+ err = err.WithSuggestions(
+ "Check your PromQL query syntax",
+ "Verify metric names and labels",
+ "Test the query in Prometheus UI",
+ ).WithRetryable(false).WithErrorCode("PROMETHEUS_QUERY_ERROR")
+ } else {
+ err = err.WithSuggestions(
+ "Check Prometheus server status",
+ "Verify the query format",
+ "Check authentication if required",
+ ).WithRetryable(true).WithErrorCode("PROMETHEUS_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewArgoError creates an Argo-specific error
+func NewArgoError(operation string, cause error) *ToolError {
+ err := NewToolError("Argo Rollouts", operation, cause)
+
+ if strings.Contains(cause.Error(), "not found") {
+ err = err.WithSuggestions(
+ "Check if Argo Rollouts is installed",
+ "Verify the rollout name and namespace",
+ "Use 'kubectl get rollouts' to list available rollouts",
+ ).WithRetryable(false).WithErrorCode("ARGO_ROLLOUT_NOT_FOUND")
+ } else if strings.Contains(cause.Error(), "plugin") {
+ err = err.WithSuggestions(
+ "Install the kubectl argo rollouts plugin",
+ "Check plugin version compatibility",
+ "Verify plugin installation path",
+ ).WithRetryable(true).WithErrorCode("ARGO_PLUGIN_ERROR")
+ } else {
+ err = err.WithSuggestions(
+ "Check Argo Rollouts installation",
+ "Verify the command syntax",
+ "Check RBAC permissions",
+ ).WithRetryable(true).WithErrorCode("ARGO_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewCiliumError creates a Cilium-specific error
+func NewCiliumError(operation string, cause error) *ToolError {
+ err := NewToolError("Cilium", operation, cause)
+
+ if strings.Contains(cause.Error(), "not found") {
+ err = err.WithSuggestions(
+ "Check if Cilium is installed",
+ "Verify the cilium CLI is installed",
+ "Check Cilium agent status",
+ ).WithRetryable(false).WithErrorCode("CILIUM_NOT_FOUND")
+ } else if strings.Contains(cause.Error(), "connection") {
+ err = err.WithSuggestions(
+ "Check Cilium agent connectivity",
+ "Verify cluster mesh configuration",
+ "Check Cilium operator status",
+ ).WithRetryable(true).WithErrorCode("CILIUM_CONNECTION_ERROR")
+ } else {
+ err = err.WithSuggestions(
+ "Check Cilium installation",
+ "Verify cilium CLI version",
+ "Check Cilium system pods",
+ ).WithRetryable(true).WithErrorCode("CILIUM_GENERIC_ERROR")
+ }
+
+ return err
+}
+
+// NewValidationError creates a validation error
+func NewValidationError(field, message string) *ToolError {
+ err := NewToolError("Validation", fmt.Sprintf("validate %s", field), fmt.Errorf("%s", message))
+
+ err = err.WithSuggestions(
+ "Check the input format",
+ "Refer to the documentation for valid values",
+ "Verify the parameter requirements",
+ ).WithRetryable(false).WithErrorCode("VALIDATION_ERROR")
+
+ return err
+}
+
+// NewSecurityError creates a security-related error
+func NewSecurityError(operation string, cause error) *ToolError {
+ err := NewToolError("Security", operation, cause)
+
+ err = err.WithSuggestions(
+ "Review the input for potentially dangerous content",
+ "Use only trusted input sources",
+ "Contact security team if needed",
+ ).WithRetryable(false).WithErrorCode("SECURITY_ERROR")
+
+ return err
+}
+
+// NewTimeoutError creates a timeout error
+func NewTimeoutError(operation string, timeout time.Duration) *ToolError {
+ cause := fmt.Errorf("operation timed out after %v", timeout)
+ err := NewToolError("Timeout", operation, cause)
+
+ err = err.WithSuggestions(
+ "Try the operation again",
+ "Check network connectivity",
+ "Increase timeout if possible",
+ ).WithRetryable(true).WithErrorCode("TIMEOUT_ERROR")
+
+ return err
+}
+
+// NewCommandError creates a command execution error
+func NewCommandError(command string, cause error) *ToolError {
+ err := NewToolError("Command", fmt.Sprintf("execute %s", command), cause)
+
+ err = err.WithSuggestions(
+ "Check if the command exists in PATH",
+ "Verify command syntax and arguments",
+ "Check system permissions",
+ ).WithRetryable(true).WithErrorCode("COMMAND_ERROR")
+
+ return err
+}
diff --git a/internal/errors/tool_errors_test.go b/internal/errors/tool_errors_test.go
new file mode 100644
index 0000000..bfa2f24
--- /dev/null
+++ b/internal/errors/tool_errors_test.go
@@ -0,0 +1,366 @@
+package errors
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewToolError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+ assert.Equal(t, "TestComponent", err.Component)
+ assert.Equal(t, "UNKNOWN", err.ErrorCode)
+ assert.False(t, err.IsRetryable)
+ assert.Empty(t, err.Suggestions)
+ assert.NotNil(t, err.Context)
+ assert.WithinDuration(t, time.Now(), err.Timestamp, time.Second)
+}
+
+func TestToolErrorError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ result := err.Error()
+ expected := "[TestComponent] test operation failed: test error"
+ assert.Equal(t, expected, result)
+}
+
+func TestToolErrorWithSuggestions(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ err = err.WithSuggestions("suggestion 1", "suggestion 2")
+
+ assert.Equal(t, []string{"suggestion 1", "suggestion 2"}, err.Suggestions)
+
+ // Test chaining
+ err = err.WithSuggestions("suggestion 3")
+ assert.Equal(t, []string{"suggestion 1", "suggestion 2", "suggestion 3"}, err.Suggestions)
+}
+
+func TestToolErrorWithRetryable(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ err = err.WithRetryable(true)
+ assert.True(t, err.IsRetryable)
+
+ err = err.WithRetryable(false)
+ assert.False(t, err.IsRetryable)
+}
+
+func TestToolErrorWithErrorCode(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ err = err.WithErrorCode("TEST_ERROR")
+ assert.Equal(t, "TEST_ERROR", err.ErrorCode)
+}
+
+func TestToolErrorWithResource(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ err = err.WithResource("Pod", "test-pod")
+ assert.Equal(t, "Pod", err.ResourceType)
+ assert.Equal(t, "test-pod", err.ResourceName)
+}
+
+func TestToolErrorWithContext(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ err = err.WithContext("key1", "value1")
+ err = err.WithContext("key2", 42)
+
+ assert.Equal(t, "value1", err.Context["key1"])
+ assert.Equal(t, 42, err.Context["key2"])
+}
+
+func TestToolErrorToMCPResult(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause).
+ WithErrorCode("TEST_ERROR").
+ WithResource("Pod", "test-pod").
+ WithSuggestions("suggestion 1", "suggestion 2").
+ WithContext("key1", "value1").
+ WithRetryable(true)
+
+ result := err.ToMCPResult()
+
+ assert.NotNil(t, result)
+ assert.True(t, result.IsError)
+ assert.NotEmpty(t, result.Content)
+
+ // Check content (assuming it's text content)
+ if len(result.Content) > 0 {
+ content := result.Content[0]
+ // This depends on the actual MCP implementation
+ // We'll just check that it's not empty
+ assert.NotNil(t, content)
+ }
+}
+
+func TestNewKubernetesError(t *testing.T) {
+ tests := []struct {
+ name string
+ causeError string
+ expectedCode string
+ expectedRetry bool
+ expectedSuggs int
+ }{
+ {
+ name: "connection refused",
+ causeError: "connection refused",
+ expectedCode: "K8S_CONNECTION_ERROR",
+ expectedRetry: true,
+ expectedSuggs: 3,
+ },
+ {
+ name: "forbidden",
+ causeError: "forbidden",
+ expectedCode: "K8S_PERMISSION_ERROR",
+ expectedRetry: false,
+ expectedSuggs: 3,
+ },
+ {
+ name: "not found",
+ causeError: "not found",
+ expectedCode: "K8S_RESOURCE_NOT_FOUND",
+ expectedRetry: false,
+ expectedSuggs: 3,
+ },
+ {
+ name: "already exists",
+ causeError: "already exists",
+ expectedCode: "K8S_RESOURCE_EXISTS",
+ expectedRetry: false,
+ expectedSuggs: 3,
+ },
+ {
+ name: "generic error",
+ causeError: "some other error",
+ expectedCode: "K8S_GENERIC_ERROR",
+ expectedRetry: true,
+ expectedSuggs: 3,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cause := errors.New(tt.causeError)
+ err := NewKubernetesError("test operation", cause)
+
+ assert.Equal(t, "Kubernetes", err.Component)
+ assert.Equal(t, tt.expectedCode, err.ErrorCode)
+ assert.Equal(t, tt.expectedRetry, err.IsRetryable)
+ assert.Len(t, err.Suggestions, tt.expectedSuggs)
+ })
+ }
+}
+
+func TestNewHelmError(t *testing.T) {
+ tests := []struct {
+ name string
+ causeError string
+ expectedCode string
+ expectedRetry bool
+ expectedSuggs int
+ }{
+ {
+ name: "not found",
+ causeError: "not found",
+ expectedCode: "HELM_RELEASE_NOT_FOUND",
+ expectedRetry: false,
+ expectedSuggs: 3,
+ },
+ {
+ name: "already exists",
+ causeError: "already exists",
+ expectedCode: "HELM_RELEASE_EXISTS",
+ expectedRetry: false,
+ expectedSuggs: 3,
+ },
+ {
+ name: "repository error",
+ causeError: "repository error",
+ expectedCode: "HELM_REPOSITORY_ERROR",
+ expectedRetry: true,
+ expectedSuggs: 3,
+ },
+ {
+ name: "generic error",
+ causeError: "some other error",
+ expectedCode: "HELM_GENERIC_ERROR",
+ expectedRetry: true,
+ expectedSuggs: 3,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cause := errors.New(tt.causeError)
+ err := NewHelmError("test operation", cause)
+
+ assert.Equal(t, "Helm", err.Component)
+ assert.Equal(t, tt.expectedCode, err.ErrorCode)
+ assert.Equal(t, tt.expectedRetry, err.IsRetryable)
+ assert.Len(t, err.Suggestions, tt.expectedSuggs)
+ })
+ }
+}
+
+func TestNewIstioError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewIstioError("test operation", cause)
+
+ assert.Equal(t, "Istio", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+}
+
+func TestNewPrometheusError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewPrometheusError("test operation", cause)
+
+ assert.Equal(t, "Prometheus", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+}
+
+func TestNewArgoError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewArgoError("test operation", cause)
+
+ assert.Equal(t, "Argo Rollouts", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+}
+
+func TestNewCiliumError(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewCiliumError("test operation", cause)
+
+ assert.Equal(t, "Cilium", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+}
+
+func TestNewValidationError(t *testing.T) {
+ err := NewValidationError("test-field", "validation failed")
+
+ assert.Equal(t, "Validation", err.Component)
+ assert.Equal(t, "validate test-field", err.Operation)
+ assert.Equal(t, "VALIDATION_ERROR", err.ErrorCode)
+ assert.False(t, err.IsRetryable)
+ assert.Contains(t, err.Cause.Error(), "validation failed")
+}
+
+func TestNewSecurityError(t *testing.T) {
+ cause := errors.New("security violation")
+ err := NewSecurityError("test operation", cause)
+
+ assert.Equal(t, "Security", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+ assert.Equal(t, "SECURITY_ERROR", err.ErrorCode)
+ assert.False(t, err.IsRetryable)
+}
+
+func TestNewTimeoutError(t *testing.T) {
+ timeout := 30 * time.Second
+ err := NewTimeoutError("test operation", timeout)
+
+ assert.Equal(t, "Timeout", err.Component)
+ assert.Equal(t, "test operation", err.Operation)
+ assert.Equal(t, "TIMEOUT_ERROR", err.ErrorCode)
+ assert.True(t, err.IsRetryable)
+ assert.Contains(t, err.Cause.Error(), "30s")
+}
+
+func TestNewCommandError(t *testing.T) {
+ cause := errors.New("command failed")
+ err := NewCommandError("test-command", cause)
+
+ assert.Equal(t, "Command", err.Component)
+ assert.Equal(t, "execute test-command", err.Operation)
+ assert.Equal(t, cause, err.Cause)
+ assert.Equal(t, "COMMAND_ERROR", err.ErrorCode)
+ assert.True(t, err.IsRetryable)
+}
+
+func TestToolErrorChaining(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause).
+ WithErrorCode("TEST_ERROR").
+ WithResource("Pod", "test-pod").
+ WithSuggestions("suggestion 1").
+ WithContext("key1", "value1").
+ WithRetryable(true)
+
+ // Test that all methods return the same instance for chaining
+ assert.Equal(t, "TEST_ERROR", err.ErrorCode)
+ assert.Equal(t, "Pod", err.ResourceType)
+ assert.Equal(t, "test-pod", err.ResourceName)
+ assert.Equal(t, []string{"suggestion 1"}, err.Suggestions)
+ assert.Equal(t, "value1", err.Context["key1"])
+ assert.True(t, err.IsRetryable)
+}
+
+func TestToolErrorStringRepresentation(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ errorStr := err.Error()
+ assert.Contains(t, errorStr, "TestComponent")
+ assert.Contains(t, errorStr, "test operation")
+ assert.Contains(t, errorStr, "test error")
+ assert.Contains(t, errorStr, "failed")
+}
+
+func TestToolErrorTimestamp(t *testing.T) {
+ before := time.Now()
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+ after := time.Now()
+
+ assert.True(t, err.Timestamp.After(before) || err.Timestamp.Equal(before))
+ assert.True(t, err.Timestamp.Before(after) || err.Timestamp.Equal(after))
+}
+
+func TestToolErrorContextInitialization(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause)
+
+ // Context should be initialized but empty
+ assert.NotNil(t, err.Context)
+ assert.Empty(t, err.Context)
+
+ // Should be able to add to context
+ err = err.WithContext("test", "value")
+ assert.Equal(t, "value", err.Context["test"])
+}
+
+func TestMCPResultContainsExpectedFields(t *testing.T) {
+ cause := errors.New("test error")
+ err := NewToolError("TestComponent", "test operation", cause).
+ WithErrorCode("TEST_ERROR").
+ WithResource("Pod", "test-pod").
+ WithSuggestions("suggestion 1").
+ WithContext("key1", "value1").
+ WithRetryable(true)
+
+ result := err.ToMCPResult()
+
+ // The result should be an error result
+ assert.True(t, result.IsError)
+
+ // Should have content
+ assert.NotEmpty(t, result.Content)
+}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 0000000..041d499
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,76 @@
+package logger
+
+import (
+ "context"
+ "log/slog"
+ "os"
+
+ "go.opentelemetry.io/otel/trace"
+)
+
+var globalLogger *slog.Logger
+
+func Init() {
+ opts := &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }
+
+ if os.Getenv("KAGENT_LOG_FORMAT") == "json" {
+ globalLogger = slog.New(slog.NewJSONHandler(os.Stdout, opts))
+ } else {
+ globalLogger = slog.New(slog.NewTextHandler(os.Stdout, opts))
+ }
+
+ slog.SetDefault(globalLogger)
+}
+
+func Get() *slog.Logger {
+ if globalLogger == nil {
+ Init()
+ }
+ return globalLogger
+}
+
+func WithContext(ctx context.Context) *slog.Logger {
+ logger := Get()
+ span := trace.SpanFromContext(ctx)
+ if span.SpanContext().IsValid() {
+ logger = logger.With(
+ "trace_id", span.SpanContext().TraceID().String(),
+ "span_id", span.SpanContext().SpanID().String(),
+ )
+ }
+ return logger
+}
+
+func LogExecCommand(ctx context.Context, logger *slog.Logger, command string, args []string, caller string) {
+ logger.Info("executing command",
+ "command", command,
+ "args", args,
+ "caller", caller,
+ )
+}
+
+func LogExecCommandResult(ctx context.Context, logger *slog.Logger, command string, args []string, output string, err error, duration float64, caller string) {
+ if err != nil {
+ logger.Error("command execution failed",
+ "command", command,
+ "args", args,
+ "error", err.Error(),
+ "duration_seconds", duration,
+ "caller", caller,
+ )
+ } else {
+ logger.Info("command execution successful",
+ "command", command,
+ "args", args,
+ "output", output,
+ "duration_seconds", duration,
+ "caller", caller,
+ )
+ }
+}
+
+func Sync() {
+ // No-op for slog, but kept for compatibility
+}
diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go
new file mode 100644
index 0000000..ad5c988
--- /dev/null
+++ b/internal/logger/logger_test.go
@@ -0,0 +1,72 @@
+package logger
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "log/slog"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/trace/noop"
+)
+
+func TestLogExecCommand(t *testing.T) {
+ var buf bytes.Buffer
+ logger := slog.New(slog.NewTextHandler(&buf, nil))
+
+ ctx := context.Background()
+ LogExecCommand(ctx, logger, "test-command", []string{"arg1", "arg2"}, "test.go:123")
+
+ output := buf.String()
+ assert.Contains(t, output, "executing command")
+ assert.Contains(t, output, "test-command")
+ assert.Contains(t, output, "arg1")
+ assert.Contains(t, output, "arg2")
+}
+
+func TestLogExecCommandResult(t *testing.T) {
+ var buf bytes.Buffer
+ logger := slog.New(slog.NewTextHandler(&buf, nil))
+
+ ctx := context.Background()
+ LogExecCommandResult(ctx, logger, "test-command", []string{"arg1"}, "success output", nil, 1.5, "test.go:123")
+ assert.Contains(t, buf.String(), "command execution successful")
+
+ buf.Reset()
+ LogExecCommandResult(ctx, logger, "test-command", []string{"arg1"}, "error output", assert.AnError, 0.5, "test.go:123")
+ assert.Contains(t, buf.String(), "command execution failed")
+}
+
+func TestWithContextAddsTraceID(t *testing.T) {
+ var buf bytes.Buffer
+ logger := slog.New(slog.NewJSONHandler(&buf, nil))
+
+ // Create a context with a mock span
+ tp := noop.NewTracerProvider()
+ ctx, span := tp.Tracer("test").Start(context.Background(), "test-span")
+ defer span.End()
+
+ loggerWithTrace := logger.With("trace_id", span.SpanContext().TraceID().String())
+ loggerWithTrace.InfoContext(ctx, "test message")
+
+ var logOutput map[string]interface{}
+ err := json.Unmarshal(buf.Bytes(), &logOutput)
+ require.NoError(t, err)
+
+ traceID := span.SpanContext().TraceID().String()
+ assert.Equal(t, traceID, logOutput["trace_id"])
+}
+
+func TestGet(t *testing.T) {
+ assert.NotNil(t, Get())
+}
+
+func TestInit(t *testing.T) {
+ assert.NotPanics(t, Init)
+}
+
+func TestSync(t *testing.T) {
+ assert.NotPanics(t, Sync)
+}
diff --git a/internal/security/validation.go b/internal/security/validation.go
new file mode 100644
index 0000000..6aadc38
--- /dev/null
+++ b/internal/security/validation.go
@@ -0,0 +1,291 @@
+package security
+
+import (
+ "fmt"
+ "regexp"
+ "strings"
+)
+
+// ValidationError represents a validation error
+type ValidationError struct {
+ Field string
+ Message string
+}
+
+func (e ValidationError) Error() string {
+ return fmt.Sprintf("validation error in field '%s': %s", e.Field, e.Message)
+}
+
+// Common validation patterns
+var (
+ // K8s resource name pattern (RFC 1123)
+ k8sNamePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`)
+
+ // Namespace pattern
+ namespacePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`)
+
+ // Container image pattern
+ imagePattern = regexp.MustCompile(`^[a-z0-9]+(([._-][a-z0-9]+)*(/[a-z0-9]+(([._-][a-z0-9]+)*)?)*)?(:([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?))$`)
+
+ // Path pattern (no directory traversal)
+ pathPattern = regexp.MustCompile(`^[a-zA-Z0-9._/-]+$`)
+
+ // Command injection patterns to reject
+ commandInjectionPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`[;&|` + "`" + `$(){}[\]\\<>*?~!#\n\r\t]`),
+ regexp.MustCompile(`\.\./`),
+ regexp.MustCompile(`\$\{`),
+ regexp.MustCompile(`\$\(`),
+ regexp.MustCompile(`\|\|`),
+ regexp.MustCompile(`&&`),
+ }
+)
+
+// ValidateK8sResourceName validates a Kubernetes resource name
+func ValidateK8sResourceName(name string) error {
+ if name == "" {
+ return ValidationError{Field: "name", Message: "cannot be empty"}
+ }
+
+ if len(name) > 63 {
+ return ValidationError{Field: "name", Message: "cannot exceed 63 characters"}
+ }
+
+ if !k8sNamePattern.MatchString(name) {
+ return ValidationError{Field: "name", Message: "must follow RFC 1123 naming convention"}
+ }
+
+ return nil
+}
+
+// ValidateNamespace validates a Kubernetes namespace
+func ValidateNamespace(namespace string) error {
+ if namespace == "" {
+ return nil // Empty namespace is allowed (defaults to 'default')
+ }
+
+ if len(namespace) > 63 {
+ return ValidationError{Field: "namespace", Message: "cannot exceed 63 characters"}
+ }
+
+ if !namespacePattern.MatchString(namespace) {
+ return ValidationError{Field: "namespace", Message: "must follow RFC 1123 naming convention"}
+ }
+
+ // Reserved namespaces
+ reserved := []string{"kube-system", "kube-public", "kube-node-lease"}
+ for _, res := range reserved {
+ if namespace == res {
+ return ValidationError{Field: "namespace", Message: fmt.Sprintf("'%s' is a reserved namespace", namespace)}
+ }
+ }
+
+ return nil
+}
+
+// ValidateContainerImage validates a container image reference
+func ValidateContainerImage(image string) error {
+ if image == "" {
+ return ValidationError{Field: "image", Message: "cannot be empty"}
+ }
+
+ if len(image) > 255 {
+ return ValidationError{Field: "image", Message: "cannot exceed 255 characters"}
+ }
+
+ if !imagePattern.MatchString(image) {
+ return ValidationError{Field: "image", Message: "invalid image format"}
+ }
+
+ return nil
+}
+
+// ValidateFilePath validates a file path for security
+func ValidateFilePath(path string) error {
+ if path == "" {
+ return ValidationError{Field: "path", Message: "cannot be empty"}
+ }
+
+ if len(path) > 4096 {
+ return ValidationError{Field: "path", Message: "path too long"}
+ }
+
+ if strings.Contains(path, "..") {
+ return ValidationError{Field: "path", Message: "path traversal not allowed"}
+ }
+
+ if !pathPattern.MatchString(path) {
+ return ValidationError{Field: "path", Message: "contains invalid characters"}
+ }
+
+ return nil
+}
+
+// ValidateCommandInput validates command inputs for injection attacks
+func ValidateCommandInput(input string) error {
+ if input == "" {
+ return ValidationError{Field: "input", Message: "cannot be empty"}
+ }
+
+ if len(input) > 1024 {
+ return ValidationError{Field: "input", Message: "input too long"}
+ }
+
+ for _, pattern := range commandInjectionPatterns {
+ if pattern.MatchString(input) {
+ return ValidationError{Field: "input", Message: "potentially dangerous characters detected"}
+ }
+ }
+
+ return nil
+}
+
+// SanitizeInput sanitizes input strings by replacing potentially dangerous characters
+func SanitizeInput(input string) string {
+ // Replace dangerous characters with safe alternatives
+ sanitized := strings.ReplaceAll(input, "\n", " ")
+ sanitized = strings.ReplaceAll(sanitized, "\r", " ")
+ sanitized = strings.ReplaceAll(sanitized, "\t", " ")
+
+ // Replace multiple spaces with single space
+ spacePattern := regexp.MustCompile(`\s+`)
+ sanitized = spacePattern.ReplaceAllString(sanitized, " ")
+
+ sanitized = strings.TrimSpace(sanitized)
+
+ return sanitized
+}
+
+// ValidateK8sLabel validates a Kubernetes label key and value
+func ValidateK8sLabel(key, value string) error {
+ if key == "" {
+ return ValidationError{Field: "label_key", Message: "cannot be empty"}
+ }
+
+ if len(key) > 63 {
+ return ValidationError{Field: "label_key", Message: "cannot exceed 63 characters"}
+ }
+
+ if len(value) > 63 {
+ return ValidationError{Field: "label_value", Message: "cannot exceed 63 characters"}
+ }
+
+ // Label key validation
+ labelKeyPattern := regexp.MustCompile(`^[a-z0-9A-Z]([a-z0-9A-Z._-]*[a-z0-9A-Z])?$`)
+ if !labelKeyPattern.MatchString(key) {
+ return ValidationError{Field: "label_key", Message: "invalid label key format"}
+ }
+
+ // Label value validation (can be empty)
+ if value != "" {
+ labelValuePattern := regexp.MustCompile(`^[a-z0-9A-Z]([a-z0-9A-Z._-]*[a-z0-9A-Z])?$`)
+ if !labelValuePattern.MatchString(value) {
+ return ValidationError{Field: "label_value", Message: "invalid label value format"}
+ }
+ }
+
+ return nil
+}
+
+// ValidatePromQLQuery validates a PromQL query for basic security
+func ValidatePromQLQuery(query string) error {
+ if query == "" {
+ return ValidationError{Field: "query", Message: "cannot be empty"}
+ }
+
+ if len(query) > 8192 {
+ return ValidationError{Field: "query", Message: "query too long"}
+ }
+
+ // Basic PromQL validation - no shell commands
+ dangerousPatterns := []string{
+ "`", "$", "$(", "${", "&&", "||", ";", "|", ">", "<", "&",
+ }
+
+ for _, pattern := range dangerousPatterns {
+ if strings.Contains(query, pattern) {
+ return ValidationError{Field: "query", Message: "potentially dangerous characters in query"}
+ }
+ }
+
+ return nil
+}
+
+// ValidateYAMLContent validates YAML content for basic security
+func ValidateYAMLContent(content string) error {
+ if content == "" {
+ return ValidationError{Field: "content", Message: "cannot be empty"}
+ }
+
+ if len(content) > 1024*1024 { // 1MB limit
+ return ValidationError{Field: "content", Message: "content too large"}
+ }
+
+ // Check for potentially dangerous YAML content
+ dangerousPatterns := []string{
+ "!!python/object/apply",
+ "!!python/object/new",
+ "!!python/object",
+ "__import__",
+ "eval(",
+ "exec(",
+ }
+
+ for _, pattern := range dangerousPatterns {
+ if strings.Contains(content, pattern) {
+ return ValidationError{Field: "content", Message: "potentially dangerous YAML content detected"}
+ }
+ }
+
+ return nil
+}
+
+// ValidateHelmReleaseName validates a Helm release name
+func ValidateHelmReleaseName(name string) error {
+ if name == "" {
+ return ValidationError{Field: "release_name", Message: "cannot be empty"}
+ }
+
+ if len(name) > 53 {
+ return ValidationError{Field: "release_name", Message: "cannot exceed 53 characters"}
+ }
+
+ // Helm release name pattern
+ helmNamePattern := regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`)
+ if !helmNamePattern.MatchString(name) {
+ return ValidationError{Field: "release_name", Message: "must follow DNS naming convention"}
+ }
+
+ return nil
+}
+
+// ValidateURL validates a URL for basic security
+func ValidateURL(url string) error {
+ if url == "" {
+ return ValidationError{Field: "url", Message: "cannot be empty"}
+ }
+
+ if len(url) > 2048 {
+ return ValidationError{Field: "url", Message: "URL too long"}
+ }
+
+ // Basic URL validation
+ if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
+ return ValidationError{Field: "url", Message: "must start with http:// or https://"}
+ }
+
+ // Check for dangerous URL patterns
+ dangerousPatterns := []string{
+ "javascript:", "data:", "file:", "ftp:",
+ "", true},
+ {"too long URL", "https://" + string(make([]byte, 3000)), true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ValidateURL(tt.input)
+ if tt.expectError && err == nil {
+ t.Errorf("Expected error for input %q, but got none", tt.input)
+ }
+ if !tt.expectError && err != nil {
+ t.Errorf("Unexpected error for input %q: %v", tt.input, err)
+ }
+ })
+ }
+}
+
+func TestValidationError(t *testing.T) {
+ err := ValidationError{
+ Field: "test_field",
+ Message: "test message",
+ }
+
+ expected := "validation error in field 'test_field': test message"
+ if err.Error() != expected {
+ t.Errorf("Expected error message %q, got %q", expected, err.Error())
+ }
+}
diff --git a/internal/telemetry/config.go b/internal/telemetry/config.go
new file mode 100644
index 0000000..56b266e
--- /dev/null
+++ b/internal/telemetry/config.go
@@ -0,0 +1,74 @@
+package telemetry
+
+import (
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// Telemetry holds all telemetry-related configuration.
+type Telemetry struct {
+ ServiceName string
+ ServiceVersion string
+ Environment string
+ Endpoint string
+ Protocol string
+ SamplingRatio float64
+ Insecure bool
+ Disabled bool
+}
+
+// Config holds all application configuration.
+type Config struct {
+ Telemetry Telemetry
+}
+
+var (
+ once sync.Once
+ config *Config
+)
+
+// LoadOtelCfg initializes and returns the application configuration.
+func LoadOtelCfg() *Config {
+ once.Do(func() {
+ config = &Config{
+ Telemetry: Telemetry{
+ ServiceName: getEnv("OTEL_SERVICE_NAME", "kagent-tools"),
+ ServiceVersion: getEnv("OTEL_SERVICE_VERSION", "dev"),
+ Environment: getEnv("OTEL_ENVIRONMENT", "development"),
+ Endpoint: getEnv("OTEL_EXPORTER_OTLP_ENDPOINT", ""),
+ Protocol: getEnv("OTEL_EXPORTER_OTLP_PROTOCOL", "auto"),
+ SamplingRatio: getEnvFloat("OTEL_TRACES_SAMPLER_ARG", 1.0),
+ Insecure: getEnvBool("OTEL_EXPORTER_OTLP_TRACES_INSECURE", false),
+ Disabled: getEnvBool("OTEL_SDK_DISABLED", false),
+ },
+ }
+ })
+ return config
+}
+
+func getEnv(key, fallback string) string {
+ if value, ok := os.LookupEnv(key); ok {
+ return value
+ }
+ return fallback
+}
+
+func getEnvFloat(key string, fallback float64) float64 {
+ if valueStr, ok := os.LookupEnv(key); ok {
+ if value, err := strconv.ParseFloat(valueStr, 64); err == nil {
+ return value
+ }
+ }
+ return fallback
+}
+
+func getEnvBool(key string, fallback bool) bool {
+ if valueStr, ok := os.LookupEnv(key); ok {
+ if value, err := strconv.ParseBool(strings.ToLower(valueStr)); err == nil {
+ return value
+ }
+ }
+ return fallback
+}
diff --git a/internal/telemetry/config_test.go b/internal/telemetry/config_test.go
new file mode 100644
index 0000000..fe6454b
--- /dev/null
+++ b/internal/telemetry/config_test.go
@@ -0,0 +1,49 @@
+package telemetry
+
+import (
+ "os"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLoad(t *testing.T) {
+ // Reset singleton for testing
+ once = sync.Once{}
+ config = nil
+
+ os.Setenv("OTEL_SERVICE_NAME", "test-service")
+ os.Setenv("OTEL_EXPORTER_OTLP_TRACES_INSECURE", "true")
+ defer func() {
+ os.Unsetenv("OTEL_SERVICE_NAME")
+ os.Unsetenv("OTEL_EXPORTER_OTLP_TRACES_INSECURE")
+ }()
+
+ cfg := LoadOtelCfg()
+ assert.Equal(t, "test-service", cfg.Telemetry.ServiceName)
+ assert.True(t, cfg.Telemetry.Insecure)
+}
+
+func TestLoadDefaults(t *testing.T) {
+ // Reset singleton for testing
+ once = sync.Once{}
+ config = nil
+
+ cfg := LoadOtelCfg()
+ assert.Equal(t, "kagent-tools", cfg.Telemetry.ServiceName)
+ assert.False(t, cfg.Telemetry.Insecure)
+ assert.Equal(t, 1.0, cfg.Telemetry.SamplingRatio)
+}
+
+func TestLoadDevelopmentSampling(t *testing.T) {
+ // Reset singleton for testing
+ once = sync.Once{}
+ config = nil
+
+ os.Setenv("OTEL_ENVIRONMENT", "development")
+ defer os.Unsetenv("OTEL_ENVIRONMENT")
+
+ cfg := LoadOtelCfg()
+ assert.Equal(t, 1.0, cfg.Telemetry.SamplingRatio)
+}
diff --git a/internal/telemetry/middleware.go b/internal/telemetry/middleware.go
new file mode 100644
index 0000000..720a99b
--- /dev/null
+++ b/internal/telemetry/middleware.go
@@ -0,0 +1,179 @@
+package telemetry
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/propagation"
+ "go.opentelemetry.io/otel/trace"
+)
+
+type ToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
+
+// contextKey is used for storing HTTP context in the request context
+type contextKey string
+
+const (
+ HTTPHeadersKey contextKey = "http_headers"
+ TraceIDKey contextKey = "trace_id"
+ SpanIDKey contextKey = "span_id"
+)
+
+// HTTPMiddleware wraps an HTTP handler to extract headers and propagate context
+func HTTPMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+
+ // Extract OpenTelemetry context from HTTP headers
+ propagator := otel.GetTextMapPropagator()
+ ctx = propagator.Extract(ctx, propagation.HeaderCarrier(r.Header))
+
+ // Store relevant HTTP headers in context for tool handlers
+ headers := make(map[string]string)
+ for name, values := range r.Header {
+ if len(values) > 0 {
+ // Store important headers for debugging/tracing
+ switch name {
+ case "X-Request-ID", "X-Correlation-ID", "X-Trace-ID",
+ "User-Agent", "Authorization", "X-Forwarded-For":
+ headers[name] = values[0]
+ }
+ }
+ }
+
+ // Add headers to context
+ ctx = context.WithValue(ctx, HTTPHeadersKey, headers)
+
+ // Extract trace information if available
+ span := trace.SpanFromContext(ctx)
+ if span.SpanContext().HasTraceID() {
+ ctx = context.WithValue(ctx, TraceIDKey, span.SpanContext().TraceID().String())
+ ctx = context.WithValue(ctx, SpanIDKey, span.SpanContext().SpanID().String())
+ }
+
+ // Call next handler with enhanced context
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
+
+// ExtractHTTPHeaders retrieves HTTP headers from context
+func ExtractHTTPHeaders(ctx context.Context) map[string]string {
+ if headers, ok := ctx.Value(HTTPHeadersKey).(map[string]string); ok {
+ return headers
+ }
+ return make(map[string]string)
+}
+
+// ExtractTraceInfo retrieves trace information from context
+func ExtractTraceInfo(ctx context.Context) (traceID, spanID string) {
+ if tid, ok := ctx.Value(TraceIDKey).(string); ok {
+ traceID = tid
+ }
+ if sid, ok := ctx.Value(SpanIDKey).(string); ok {
+ spanID = sid
+ }
+ return traceID, spanID
+}
+
+func WithTracing(toolName string, handler ToolHandler) ToolHandler {
+ return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ tracer := otel.Tracer("kagent-tools/mcp")
+
+ spanName := fmt.Sprintf("mcp.tool.%s", toolName)
+ ctx, span := tracer.Start(ctx, spanName)
+ defer span.End()
+
+ // Extract HTTP headers from context and add as span attributes
+ headers := ExtractHTTPHeaders(ctx)
+ for key, value := range headers {
+ span.SetAttributes(attribute.String(fmt.Sprintf("http.header.%s", key), value))
+ }
+
+ // Extract parent trace information
+ parentTraceID, parentSpanID := ExtractTraceInfo(ctx)
+ if parentTraceID != "" {
+ span.SetAttributes(
+ attribute.String("http.parent_trace_id", parentTraceID),
+ attribute.String("http.parent_span_id", parentSpanID),
+ )
+ }
+
+ span.SetAttributes(
+ attribute.String("mcp.tool.name", toolName),
+ attribute.String("mcp.request.id", request.Params.Name),
+ )
+
+ if request.Params.Arguments != nil {
+ if argsJSON, err := json.Marshal(request.Params.Arguments); err == nil {
+ span.SetAttributes(attribute.String("mcp.request.arguments", string(argsJSON)))
+ }
+ }
+
+ span.AddEvent("tool.execution.start")
+ startTime := time.Now()
+
+ result, err := handler(ctx, request)
+
+ duration := time.Since(startTime)
+ span.SetAttributes(attribute.Float64("mcp.tool.duration_seconds", duration.Seconds()))
+
+ if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
+ span.AddEvent("tool.execution.error", trace.WithAttributes(
+ attribute.String("error.message", err.Error()),
+ ))
+ } else {
+ span.SetStatus(codes.Ok, "tool execution completed successfully")
+ span.AddEvent("tool.execution.success")
+
+ if result != nil {
+ span.SetAttributes(attribute.Bool("mcp.result.is_error", result.IsError))
+ if result.Content != nil {
+ span.SetAttributes(attribute.Int("mcp.result.content_count", len(result.Content)))
+ }
+ }
+ }
+
+ return result, err
+ }
+}
+
+func StartSpan(ctx context.Context, operationName string, attrs ...attribute.KeyValue) (context.Context, trace.Span) {
+ tracer := otel.Tracer("kagent-tools")
+ ctx, span := tracer.Start(ctx, operationName)
+
+ if len(attrs) > 0 {
+ span.SetAttributes(attrs...)
+ }
+
+ return ctx, span
+}
+
+func RecordError(span trace.Span, err error, message string) {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, message)
+}
+
+func RecordSuccess(span trace.Span, message string) {
+ span.SetStatus(codes.Ok, message)
+}
+
+func AddEvent(span trace.Span, name string, attrs ...attribute.KeyValue) {
+ span.AddEvent(name, trace.WithAttributes(attrs...))
+}
+
+// AdaptToolHandler adapts a telemetry.ToolHandler to a server.ToolHandlerFunc.
+func AdaptToolHandler(th ToolHandler) server.ToolHandlerFunc {
+ return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return th(ctx, req)
+ }
+}
diff --git a/internal/telemetry/middleware_test.go b/internal/telemetry/middleware_test.go
new file mode 100644
index 0000000..bcbf494
--- /dev/null
+++ b/internal/telemetry/middleware_test.go
@@ -0,0 +1,801 @@
+package telemetry
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/trace/noop"
+)
+
+// InMemoryExporter is a simple in-memory exporter for testing
+type InMemoryExporter struct {
+ spans []trace.ReadOnlySpan
+}
+
+func (e *InMemoryExporter) ExportSpans(ctx context.Context, spans []trace.ReadOnlySpan) error {
+ e.spans = append(e.spans, spans...)
+ return nil
+}
+
+func (e *InMemoryExporter) Shutdown(ctx context.Context) error {
+ return nil
+}
+
+func (e *InMemoryExporter) GetSpans() []trace.ReadOnlySpan {
+ return e.spans
+}
+
+// setupTracing initializes OpenTelemetry with in-memory exporter for testing
+func setupTracing() (*trace.TracerProvider, *InMemoryExporter) {
+ exporter := &InMemoryExporter{}
+ provider := trace.NewTracerProvider(
+ trace.WithSampler(trace.AlwaysSample()),
+ trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter)),
+ )
+ otel.SetTracerProvider(provider)
+ return provider, exporter
+}
+
+func TestWithTracing(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ Arguments: map[string]interface{}{
+ "param1": "value1",
+ "param2": 42,
+ },
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Len(t, result.Content, 1)
+ textContent, ok := mcp.AsTextContent(result.Content[0])
+ require.True(t, ok)
+ assert.Equal(t, "test response", textContent.Text)
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+ assert.Equal(t, codes.Ok, span.Status().Code)
+ // Note: SDK may not preserve description in test environment
+ // assert.Equal(t, "tool execution completed successfully", span.Status().Description)
+
+ // Verify attributes
+ attributes := span.Attributes()
+ hasToolName := false
+ hasRequestID := false
+ hasIsError := false
+ hasContentCount := false
+
+ for _, attr := range attributes {
+ if attr.Key == "mcp.tool.name" && attr.Value.AsString() == "test-tool" {
+ hasToolName = true
+ }
+ if attr.Key == "mcp.request.id" && attr.Value.AsString() == "test-tool" {
+ hasRequestID = true
+ }
+ if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == false {
+ hasIsError = true
+ }
+ if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 {
+ hasContentCount = true
+ }
+ }
+
+ assert.True(t, hasToolName)
+ assert.True(t, hasRequestID)
+ assert.True(t, hasIsError)
+ assert.True(t, hasContentCount)
+
+ // Verify events
+ events := span.Events()
+ assert.Len(t, events, 2)
+ assert.Equal(t, "tool.execution.start", events[0].Name)
+ assert.Equal(t, "tool.execution.success", events[1].Name)
+}
+
+func TestWithTracingError(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler that returns an error
+ testError := errors.New("test error")
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return nil, testError
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ assert.Error(t, err)
+ assert.Equal(t, testError, err)
+ assert.Nil(t, result)
+
+ // Verify span was created with error
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+ assert.Equal(t, codes.Error, span.Status().Code)
+ // Note: SDK may not preserve description in test environment
+ // assert.Equal(t, "test error", span.Status().Description)
+
+ // Verify events - span.RecordError() adds an "exception" event, plus our custom events
+ events := span.Events()
+ assert.Len(t, events, 3)
+ assert.Equal(t, "tool.execution.start", events[0].Name)
+ assert.Equal(t, "exception", events[1].Name) // Added by span.RecordError()
+ assert.Equal(t, "tool.execution.error", events[2].Name)
+}
+
+func TestWithTracingErrorResult(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler that returns an error result
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("error occurred")
+ return &mcp.CallToolResult{
+ IsError: true,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.True(t, result.IsError)
+
+ // Verify span was created successfully (no error from handler)
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+ assert.Equal(t, codes.Ok, span.Status().Code)
+
+ // Verify attributes
+ attributes := span.Attributes()
+ hasIsError := false
+ hasContentCount := false
+
+ for _, attr := range attributes {
+ if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == true {
+ hasIsError = true
+ }
+ if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 {
+ hasContentCount = true
+ }
+ }
+
+ assert.True(t, hasIsError)
+ assert.True(t, hasContentCount)
+}
+
+func TestWithTracingWithArguments(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request with arguments
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ Arguments: map[string]interface{}{
+ "string_param": "hello",
+ "number_param": 42,
+ "bool_param": true,
+ "array_param": []interface{}{"a", "b", "c"},
+ "object_param": map[string]interface{}{
+ "nested": "value",
+ },
+ },
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+
+ // Verify that arguments were added as an attribute (they are JSON-encoded)
+ attributes := span.Attributes()
+ hasArguments := false
+
+ for _, attr := range attributes {
+ if attr.Key == "mcp.request.arguments" {
+ hasArguments = true
+ // Arguments should be JSON-encoded
+ assert.NotEmpty(t, attr.Value.AsString())
+ }
+ }
+
+ assert.True(t, hasArguments)
+}
+
+func TestWithTracingNilArguments(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request without arguments
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+}
+
+func TestStartSpan(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span
+ _, span := StartSpan(context.Background(), "test-span",
+ attribute.String("key1", "value1"),
+ attribute.Int("key2", 42),
+ )
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+}
+
+func TestStartSpanNoAttributes(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span without attributes
+ _, span := StartSpan(context.Background(), "test-span")
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+}
+
+func TestRecordError(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span
+ _, span := StartSpan(context.Background(), "test-span")
+
+ // Record an error
+ testError := errors.New("test error")
+ RecordError(span, testError, "test error")
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created with error
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+ assert.Equal(t, codes.Error, resultSpan.Status().Code)
+ assert.Equal(t, "test error", resultSpan.Status().Description)
+}
+
+func TestRecordSuccess(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span
+ _, span := StartSpan(context.Background(), "test-span")
+
+ // Record success
+ RecordSuccess(span, "operation completed successfully")
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created with success
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+ assert.Equal(t, codes.Ok, resultSpan.Status().Code)
+ // Note: SDK may not preserve description in test environment
+ // assert.Equal(t, "operation completed successfully", resultSpan.Status().Description)
+}
+
+func TestAddEvent(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span
+ _, span := StartSpan(context.Background(), "test-span")
+
+ // Add an event
+ AddEvent(span, "test-event",
+ attribute.String("event_key", "event_value"),
+ attribute.Int("event_num", 123),
+ )
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created with event
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+
+ // Verify event
+ events := resultSpan.Events()
+ assert.Len(t, events, 1)
+ assert.Equal(t, "test-event", events[0].Name)
+}
+
+func TestAddEventNoAttributes(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Start a span
+ _, span := StartSpan(context.Background(), "test-span")
+
+ // Add an event without attributes
+ AddEvent(span, "test-event")
+
+ // End the span
+ span.End()
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify span was created with event
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ resultSpan := spans[0]
+ assert.Equal(t, "test-span", resultSpan.Name())
+
+ // Verify event
+ events := resultSpan.Events()
+ assert.Len(t, events, 1)
+ assert.Equal(t, "test-event", events[0].Name)
+}
+
+func TestAdaptToolHandler(t *testing.T) {
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Adapt the handler
+ adapted := AdaptToolHandler(testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the adapted handler
+ result, err := adapted(context.Background(), request)
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Len(t, result.Content, 1)
+ textContent, ok := mcp.AsTextContent(result.Content[0])
+ require.True(t, ok)
+ assert.Equal(t, "test response", textContent.Text)
+}
+
+func TestWithTracingNilResult(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler that returns nil result
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return nil, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.Nil(t, result)
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+ assert.Equal(t, codes.Ok, span.Status().Code)
+}
+
+func TestWithTracingNoContent(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, exporter := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler that returns result with no content
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Force flush to ensure spans are exported
+ if err := provider.ForceFlush(context.Background()); err != nil {
+ t.Errorf("Failed to flush provider: %v", err)
+ }
+
+ // Verify result
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Len(t, result.Content, 0)
+
+ // Verify span was created
+ spans := exporter.GetSpans()
+ assert.Len(t, spans, 1)
+
+ span := spans[0]
+ assert.Equal(t, "mcp.tool.test-tool", span.Name())
+ assert.Equal(t, codes.Ok, span.Status().Code)
+
+ // Verify attributes
+ attributes := span.Attributes()
+ hasContentCount := false
+
+ for _, attr := range attributes {
+ if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 0 {
+ hasContentCount = true
+ }
+ }
+
+ assert.True(t, hasContentCount)
+}
+
+func TestWithTracingNoopTracer(t *testing.T) {
+ // Set up noop tracer provider
+ otel.SetTracerProvider(noop.NewTracerProvider())
+
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Execute the handler
+ result, err := tracedHandler(context.Background(), request)
+
+ // Verify result (should work normally with noop tracer)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Len(t, result.Content, 1)
+ textContent, ok := mcp.AsTextContent(result.Content[0])
+ require.True(t, ok)
+ assert.Equal(t, "test response", textContent.Text)
+}
+
+func TestWithTracingPerformance(t *testing.T) {
+ // Initialize OpenTelemetry
+ provider, _ := setupTracing()
+ defer func() {
+ if err := provider.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown provider: %v", err)
+ }
+ }()
+
+ // Create a test handler
+ testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ textContent := mcp.NewTextContent("test response")
+ return &mcp.CallToolResult{
+ IsError: false,
+ Content: []mcp.Content{textContent},
+ }, nil
+ }
+
+ // Wrap with tracing
+ tracedHandler := WithTracing("test-tool", testHandler)
+
+ // Create test request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "test-tool",
+ },
+ }
+
+ // Time execution
+ start := time.Now()
+ for i := 0; i < 100; i++ {
+ _, err := tracedHandler(context.Background(), request)
+ require.NoError(t, err)
+ }
+ duration := time.Since(start)
+
+ // Verify performance is reasonable (should complete in less than 1 second)
+ assert.Less(t, duration, time.Second)
+}
diff --git a/internal/telemetry/tracing.go b/internal/telemetry/tracing.go
new file mode 100644
index 0000000..6b6f720
--- /dev/null
+++ b/internal/telemetry/tracing.go
@@ -0,0 +1,282 @@
+package telemetry
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
+ "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
+ "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
+ "go.opentelemetry.io/otel/propagation"
+ "go.opentelemetry.io/otel/sdk/resource"
+ sdktrace "go.opentelemetry.io/otel/sdk/trace"
+ semconv "go.opentelemetry.io/otel/semconv/v1.32.0"
+ "go.opentelemetry.io/otel/trace/noop"
+
+ "github.com/kagent-dev/tools/internal/logger"
+)
+
+// Standard OpenTelemetry environment variable names
+// These follow the official OTLP specification
+const (
+ // Service identification
+ OtelServiceName = "OTEL_SERVICE_NAME"
+ OtelServiceVersion = "OTEL_SERVICE_VERSION"
+ OtelEnvironment = "OTEL_ENVIRONMENT" // Custom extension, not in official spec
+
+ // OTLP Exporter configuration
+ OtelExporterOtlpEndpoint = "OTEL_EXPORTER_OTLP_ENDPOINT"
+ OtelExporterOtlpProtocol = "OTEL_EXPORTER_OTLP_PROTOCOL"
+ OtelExporterOtlpHeaders = "OTEL_EXPORTER_OTLP_HEADERS"
+
+ // Trace-specific OTLP configuration
+ OtelExporterOtlpTracesInsecure = "OTEL_EXPORTER_OTLP_TRACES_INSECURE"
+
+ // Sampling configuration
+ OtelTracesSamplerArg = "OTEL_TRACES_SAMPLER_ARG"
+
+ // SDK control
+ OtelSdkDisabled = "OTEL_SDK_DISABLED"
+)
+
+// OTLP Protocol constants
+const (
+ ProtocolGRPC = "grpc"
+ ProtocolHTTP = "http/protobuf"
+ ProtocolAuto = "auto" // Custom extension for automatic protocol detection
+)
+
+// Standard OTLP port numbers
+// These are the official OTLP default ports as per OpenTelemetry specification
+const (
+ DefaultOtlpGrpcPort = "4317" // Standard OTLP/gRPC port
+ DefaultOtlpHttpPort = "4318" // Standard OTLP/HTTP port
+)
+
+// Default endpoint paths
+const (
+ DefaultHttpTracesPath = "/v1/traces"
+)
+
+// SetupOTelSDK initializes the OpenTelemetry SDK
+func SetupOTelSDK(ctx context.Context) error {
+ log := logger.WithContext(ctx)
+ cfg := LoadOtelCfg()
+ telemetryConfig := cfg.Telemetry
+
+ // If tracing is disabled, set a no-op tracer provider and return.
+ // This prevents further initialization and ensures no traces are exported.
+ if cfg.Telemetry.Disabled {
+ otel.SetTracerProvider(noop.NewTracerProvider())
+ return nil
+ }
+
+ res, err := resource.New(ctx,
+ resource.WithDetectors(), // Detectors for cloud provider, k8s, etc.
+ resource.WithAttributes(
+ semconv.ServiceNameKey.String(telemetryConfig.ServiceName),
+ semconv.ServiceVersionKey.String(telemetryConfig.ServiceVersion),
+ attribute.String("deployment.environment", telemetryConfig.Environment),
+ ),
+ )
+ if err != nil {
+ log.Error("failed to create resource", "error", err)
+ return fmt.Errorf("failed to create resource: %w", err)
+ }
+
+ // Set up propagator
+ prop := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})
+ otel.SetTextMapPropagator(prop)
+
+ exporter, err := createExporter(ctx, &telemetryConfig)
+ if err != nil {
+ log.Error("failed to create exporter", "error", err)
+ return fmt.Errorf("failed to create exporter: %w", err)
+ }
+
+ // Set up trace provider
+ tracerProvider, err := newTracerProvider(ctx, &telemetryConfig, exporter, res)
+ if err != nil {
+ log.Error("failed to create tracer provider", "error", err)
+ return fmt.Errorf("failed to create tracer provider: %w", err)
+ }
+ otel.SetTracerProvider(tracerProvider)
+
+ log.Info("OpenTelemetry SDK successfully initialized")
+ //start goroutine and wait for ctx cancellation
+ go func() {
+ <-ctx.Done()
+ if err := tracerProvider.Shutdown(ctx); err != nil {
+ log.Error("failed to shutdown tracer provider", "error", err)
+ } else {
+ log.Info("OpenTelemetry SDK shutdown successfully")
+ }
+ }()
+ return nil
+}
+
+// newTracerProvider creates a new trace provider
+func newTracerProvider(ctx context.Context, cfg *Telemetry, exporter sdktrace.SpanExporter, res *resource.Resource) (*sdktrace.TracerProvider, error) {
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ sampler := sdktrace.AlwaysSample()
+
+ tp := sdktrace.NewTracerProvider(
+ sdktrace.WithSampler(sampler),
+ sdktrace.WithBatcher(exporter),
+ sdktrace.WithResource(res),
+ )
+ return tp, nil
+}
+
+// createExporter creates a OTLP exporter
+func createExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) {
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ if cfg.Endpoint == "" {
+ return stdouttrace.New(stdouttrace.WithPrettyPrint())
+ }
+
+ // Determine protocol
+ protocol := cfg.Protocol
+ if protocol == ProtocolAuto || protocol == "" {
+ protocol = detectProtocol(cfg.Endpoint)
+ }
+
+ switch strings.ToLower(protocol) {
+ case ProtocolGRPC:
+ return createGRPCExporter(ctx, cfg)
+ case ProtocolHTTP:
+ return createHTTPExporter(ctx, cfg)
+ default:
+ return nil, fmt.Errorf("unsupported protocol: %s (supported: %s, %s)", protocol, ProtocolGRPC, ProtocolHTTP)
+ }
+}
+
+// detectProtocol determines the protocol based on the endpoint URL
+func detectProtocol(endpoint string) string {
+ // Parse URL to extract port
+ if parsedURL, err := url.Parse(endpoint); err == nil {
+ port := parsedURL.Port()
+ if port == "" {
+ // Check for default ports in hostname
+ if strings.Contains(parsedURL.Host, ":"+DefaultOtlpGrpcPort) {
+ return ProtocolGRPC
+ }
+ if strings.Contains(parsedURL.Host, ":"+DefaultOtlpHttpPort) {
+ return ProtocolHTTP
+ }
+ } else {
+ switch port {
+ case DefaultOtlpGrpcPort:
+ return ProtocolGRPC
+ case DefaultOtlpHttpPort:
+ return ProtocolHTTP
+ }
+ }
+ }
+
+ // Check if endpoint contains port info directly
+ if strings.Contains(endpoint, ":"+DefaultOtlpGrpcPort) {
+ return ProtocolGRPC
+ }
+ if strings.Contains(endpoint, ":"+DefaultOtlpHttpPort) {
+ return ProtocolHTTP
+ }
+
+ // Default to HTTP for backward compatibility
+ return ProtocolHTTP
+}
+
+// createGRPCExporter creates a gRPC OTLP exporter
+func createGRPCExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) {
+ opts := []otlptracegrpc.Option{
+ otlptracegrpc.WithEndpoint(normalizeGRPCEndpoint(cfg.Endpoint)),
+ otlptracegrpc.WithTimeout(30 * time.Second),
+ }
+
+ // Use insecure connection if explicitly configured
+ if cfg.Insecure {
+ opts = append(opts, otlptracegrpc.WithInsecure())
+ }
+
+ if authToken := os.Getenv(OtelExporterOtlpHeaders); authToken != "" {
+ opts = append(opts, otlptracegrpc.WithHeaders(parseHeaders(authToken)))
+ }
+
+ return otlptracegrpc.New(ctx, opts...)
+}
+
+// createHTTPExporter creates an HTTP OTLP exporter
+func createHTTPExporter(ctx context.Context, cfg *Telemetry) (sdktrace.SpanExporter, error) {
+ opts := []otlptracehttp.Option{
+ otlptracehttp.WithEndpointURL(normalizeHTTPEndpoint(cfg.Endpoint, cfg.Insecure)),
+ otlptracehttp.WithTimeout(30 * time.Second),
+ }
+
+ // Use insecure connection if explicitly configured
+ if cfg.Insecure {
+ opts = append(opts, otlptracehttp.WithInsecure())
+ }
+
+ if authToken := os.Getenv(OtelExporterOtlpHeaders); authToken != "" {
+ opts = append(opts, otlptracehttp.WithHeaders(parseHeaders(authToken)))
+ }
+
+ return otlptracehttp.New(ctx, opts...)
+}
+
+// normalizeGRPCEndpoint normalizes the endpoint for gRPC usage
+func normalizeGRPCEndpoint(endpoint string) string {
+ if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
+ return endpoint
+ }
+
+ u, err := url.Parse(endpoint)
+ if err != nil {
+ return endpoint // Should not happen with the check above, but as a safeguard
+ }
+
+ return u.Host + u.Path
+}
+
+// normalizeHTTPEndpoint normalizes the endpoint for HTTP usage
+func normalizeHTTPEndpoint(endpoint string, insecure bool) string {
+ // Ensure we have a proper HTTP URL
+ if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
+ // Use HTTP if insecure is true or if endpoint contains localhost/127.0.0.1/docker.internal
+ if insecure || strings.Contains(endpoint, "localhost") || strings.Contains(endpoint, "127.0.0.1") || strings.Contains(endpoint, "docker.internal") {
+ endpoint = "http://" + endpoint
+ } else {
+ endpoint = "https://" + endpoint
+ }
+ }
+
+ // Add /v1/traces suffix if not present
+ if !strings.HasSuffix(endpoint, DefaultHttpTracesPath) {
+ endpoint = strings.TrimSuffix(endpoint, "/") + DefaultHttpTracesPath
+ }
+
+ return endpoint
+}
+
+// parseHeaders parses a comma-separated string of headers into a map
+func parseHeaders(headers string) map[string]string {
+ headerMap := make(map[string]string)
+ for _, h := range strings.Split(headers, ",") {
+ if parts := strings.SplitN(h, "=", 2); len(parts) == 2 {
+ headerMap[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
+ }
+ }
+ return headerMap
+}
diff --git a/internal/telemetry/tracing_test.go b/internal/telemetry/tracing_test.go
new file mode 100644
index 0000000..f26f3bd
--- /dev/null
+++ b/internal/telemetry/tracing_test.go
@@ -0,0 +1,374 @@
+package telemetry
+
+import (
+ "context"
+ "os"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
+ "go.opentelemetry.io/otel/sdk/resource"
+ "go.opentelemetry.io/otel/trace/noop"
+)
+
+// Test protocol constants for additional test scenarios
+const (
+ ProtocolInvalid = "invalid"
+)
+
+// resetConfig is a helper to reset the singleton config for tests
+func resetConfig() {
+ once = sync.Once{}
+ config = nil
+}
+
+func TestSetupOTelSDK_Disabled(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ err := os.Setenv("OTEL_SDK_DISABLED", "true")
+ require.NoError(t, err)
+ defer func() {
+ _ = os.Unsetenv("OTEL_SDK_DISABLED")
+ }()
+ resetConfig()
+
+ err = SetupOTelSDK(ctx)
+ require.NoError(t, err)
+
+ // In a disabled state, the tracer provider should be a no-op provider
+ tp := otel.GetTracerProvider()
+ assert.IsType(t, noop.NewTracerProvider(), tp)
+
+ // Shutdown should be a no-op function
+ assert.NoError(t, err)
+}
+
+func TestSetupOTelSDKEnabled(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ err := os.Setenv(OtelSdkDisabled, "false")
+ require.NoError(t, err)
+ defer func() {
+ _ = os.Unsetenv(OtelSdkDisabled)
+ }()
+
+ err = SetupOTelSDK(ctx)
+ require.NoError(t, err)
+}
+
+func TestNewTracerProviderDevelopment(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ res := resource.NewSchemaless()
+ cfg := &Telemetry{
+ Environment: "development",
+ }
+ exporter, _ := stdouttrace.New()
+
+ tp, err := newTracerProvider(ctx, cfg, exporter, res)
+ require.NoError(t, err)
+ assert.NotNil(t, tp)
+}
+
+func TestNewTracerProviderProduction(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ res := resource.NewSchemaless()
+ cfg := &Telemetry{
+ Environment: "production",
+ SamplingRatio: 0.5,
+ }
+ exporter, _ := stdouttrace.New()
+
+ tp, err := newTracerProvider(ctx, cfg, exporter, res)
+ require.NoError(t, err)
+ assert.NotNil(t, tp)
+}
+
+func TestCreateExporterDevelopment(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ cfg := &Telemetry{
+ Environment: "development",
+ }
+
+ exporter, err := createExporter(ctx, cfg)
+ require.NoError(t, err)
+ assert.NotNil(t, exporter)
+ assert.IsType(t, &stdouttrace.Exporter{}, exporter)
+}
+
+func TestCreateExporterNoEndpoint(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ cfg := &Telemetry{
+ Environment: "production",
+ }
+
+ exporter, err := createExporter(ctx, cfg)
+ require.NoError(t, err)
+ assert.NotNil(t, exporter)
+ assert.IsType(t, &stdouttrace.Exporter{}, exporter)
+}
+
+func TestCreateExporterWithEndpoint(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ cfg := &Telemetry{
+ Environment: "production",
+ Endpoint: "http://localhost:4317",
+ Protocol: ProtocolAuto,
+ }
+
+ exporter, err := createExporter(ctx, cfg)
+ require.NoError(t, err)
+ assert.NotNil(t, exporter)
+}
+
+func TestCreateExporterWithInsecure(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ cfg := &Telemetry{
+ Environment: "production",
+ Endpoint: "localhost:4317",
+ Insecure: true,
+ }
+
+ // This should not fail, as insecure is handled by the exporters
+ _, err := createExporter(ctx, cfg)
+ require.NoError(t, err)
+}
+
+func TestCreateExporterWithAuthHeaders(t *testing.T) {
+ resetConfig()
+ ctx := context.Background()
+ cfg := &Telemetry{
+ Environment: "production",
+ Endpoint: "http://localhost:4317",
+ Protocol: ProtocolAuto,
+ }
+
+ // Set auth header
+ err := os.Setenv(OtelExporterOtlpHeaders, "Authorization=Bearer token123")
+ require.NoError(t, err)
+ defer func() {
+ _ = os.Unsetenv(OtelExporterOtlpHeaders)
+ }()
+
+ exporter, err := createExporter(ctx, cfg)
+ require.NoError(t, err)
+ assert.NotNil(t, exporter)
+
+ // Clean up
+ err = exporter.Shutdown(ctx)
+ assert.NoError(t, err)
+}
+
+func TestSetupOTelSDKWithCancellation(t *testing.T) {
+ resetConfig()
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // Cancel context immediately
+
+ err := SetupOTelSDK(ctx)
+ require.Error(t, err) // Expect an error due to context cancellation
+}
+
+func TestProtocolDetection(t *testing.T) {
+ tests := []struct {
+ name string
+ endpoint string
+ expected string
+ }{
+ {"gRPC port 4317", "localhost:4317", ProtocolGRPC},
+ {"HTTP port 4318", "localhost:4318", ProtocolHTTP},
+ {"gRPC port 4317 without scheme", "localhost:4317", ProtocolGRPC},
+ {"HTTP port 4318 without scheme", "localhost:4318", ProtocolHTTP},
+ {"gRPC with docker internal", "host.docker.internal:4317", ProtocolGRPC},
+ {"HTTP with docker internal", "host.docker.internal:4318", ProtocolHTTP},
+ {"No port specified", "localhost", ProtocolHTTP},
+ {"Unknown port", "localhost:1234", ProtocolHTTP},
+ {"HTTPS with gRPC port", "https://localhost:4317", ProtocolGRPC},
+ {"HTTPS with HTTP port", "https://localhost:4318", ProtocolHTTP},
+ {"gRPC with path", "localhost:4317/v1/traces", ProtocolGRPC},
+ {"HTTP with path", "localhost:4318/v1/traces", ProtocolHTTP},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := detectProtocol(tt.endpoint)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestEndpointNormalization(t *testing.T) {
+ tests := []struct {
+ name string
+ endpoint string
+ expected string
+ }{
+ {"Basic gRPC endpoint", "localhost:4317", "localhost:4317"},
+ {"gRPC with path", "localhost:4317/v1/traces", "localhost:4317/v1/traces"},
+ {"gRPC without scheme", "localhost:4317", "localhost:4317"},
+ {"gRPC with HTTPS", "https://localhost:4317", "localhost:4317"},
+ {"Docker internal gRPC", "host.docker.internal:4317", "host.docker.internal:4317"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := normalizeGRPCEndpoint(tt.endpoint)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestHTTPEndpointNormalization(t *testing.T) {
+ tests := []struct {
+ name string
+ endpoint string
+ insecure bool
+ expected string
+ }{
+ {"Basic HTTP endpoint", "http://localhost:4318", false, "http://localhost:4318/v1/traces"},
+ {"HTTP with path", "http://localhost:4318/v1/traces", false, "http://localhost:4318/v1/traces"},
+ {"HTTP without scheme - secure localhost", "localhost:4318", false, "http://localhost:4318/v1/traces"},
+ {"HTTP without scheme - insecure localhost", "localhost:4318", true, "http://localhost:4318/v1/traces"},
+ {"HTTP with trailing slash", "http://localhost:4318/", false, "http://localhost:4318/v1/traces"},
+ {"Docker internal HTTP - secure", "host.docker.internal:4318", false, "http://host.docker.internal:4318/v1/traces"},
+ {"Docker internal HTTP - insecure", "host.docker.internal:4318", true, "http://host.docker.internal:4318/v1/traces"},
+ {"Remote endpoint - secure", "collector.example.com:4318", false, "https://collector.example.com:4318/v1/traces"},
+ {"Remote endpoint - insecure", "collector.example.com:4318", true, "http://collector.example.com:4318/v1/traces"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := normalizeHTTPEndpoint(tt.endpoint, tt.insecure)
+ assert.Equal(t, tt.expected, result, "HTTP endpoint normalization failed for: %s", tt.endpoint)
+ })
+ }
+}
+
+func TestParseHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ headers string
+ want map[string]string
+ }{
+ {"Empty string", "", map[string]string{}},
+ {"Single header", "key=value", map[string]string{"key": "value"}},
+ {"Multiple headers", "key1=value1,key2=value2", map[string]string{"key1": "value1", "key2": "value2"}},
+ {"Headers with spaces", " key1 = value1 , key2 = value2 ", map[string]string{"key1": "value1", "key2": "value2"}},
+ {"Invalid header format", "key-value,key2", map[string]string{}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := parseHeaders(tt.headers)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestCreateExporterWithProtocol(t *testing.T) {
+
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ config *Telemetry
+ shouldError bool
+ description string
+ }{
+ {
+ "gRPC protocol",
+ &Telemetry{
+ Environment: "development",
+ Endpoint: "localhost:4317",
+ Protocol: ProtocolGRPC,
+ },
+ false,
+ "Should create gRPC exporter",
+ },
+ {
+ "HTTP protocol",
+ &Telemetry{
+ Environment: "development",
+ Endpoint: "localhost:4318",
+ Protocol: ProtocolHTTP,
+ },
+ false,
+ "Should create HTTP exporter",
+ },
+ {
+ "Auto protocol with gRPC port",
+ &Telemetry{
+ Environment: "development",
+ Endpoint: "localhost:4317",
+ Protocol: ProtocolAuto,
+ },
+ false,
+ "Should auto-detect gRPC",
+ },
+ {
+ "Auto protocol with HTTP port",
+ &Telemetry{
+ Environment: "development",
+ Endpoint: "localhost:4318",
+ Protocol: ProtocolAuto,
+ },
+ false,
+ "Should auto-detect HTTP",
+ },
+ {
+ "gRPC protocol with insecure",
+ &Telemetry{
+ Environment: "production",
+ Endpoint: "localhost:4317",
+ Protocol: ProtocolGRPC,
+ Insecure: true,
+ },
+ false,
+ "Should create gRPC exporter with insecure",
+ },
+ {
+ "HTTP protocol with insecure",
+ &Telemetry{
+ Environment: "production",
+ Endpoint: "localhost:4318",
+ Protocol: ProtocolHTTP,
+ Insecure: true,
+ },
+ false,
+ "Should create HTTP exporter with insecure",
+ },
+ {
+ "Invalid protocol",
+ &Telemetry{
+ Environment: "development",
+ Endpoint: "localhost:1234",
+ Protocol: ProtocolInvalid,
+ },
+ true,
+ "Should return error for invalid protocol",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ resetConfig()
+ exporter, err := createExporter(ctx, tt.config)
+ if tt.shouldError {
+ require.Error(t, err, tt.description)
+ assert.Nil(t, exporter, tt.description)
+ } else {
+ require.NoError(t, err, tt.description)
+ assert.NotNil(t, exporter, tt.description)
+ err = exporter.Shutdown(ctx)
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/pkg/argo/argo.go b/pkg/argo/argo.go
index 566a4a0..a24e978 100644
--- a/pkg/argo/argo.go
+++ b/pkg/argo/argo.go
@@ -13,6 +13,8 @@ import (
"strings"
"time"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/kagent-dev/tools/pkg/utils"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -20,8 +22,6 @@ import (
// Argo Rollouts tools
-var kubeConfig = ""
-
func handleVerifyArgoRolloutsControllerInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
ns := mcp.ParseString(request, "namespace", "argo-rollouts")
label := mcp.ParseString(request, "label", "app.kubernetes.io/component=rollouts-controller")
@@ -76,10 +76,11 @@ func handleVerifyKubectlPluginInstall(ctx context.Context, request mcp.CallToolR
}
func runArgoRolloutCommand(ctx context.Context, args []string) (string, error) {
- if kubeConfig != "" {
- args = append(args, "--kubeconfig", kubeConfig)
- }
- return utils.RunCommandWithContext(ctx, "kubectl", args)
+ kubeconfigPath := utils.GetKubeconfig()
+ return commands.NewCommandBuilder("kubectl").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
}
func handlePromoteRollout(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -198,9 +199,13 @@ func getSystemArchitecture() (string, error) {
}
}
-func getLatestVersion() string {
+func getLatestVersion(ctx context.Context) string {
client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Get("https://api.github.com/repos/argoproj-labs/rollouts-plugin-trafficrouter-gatewayapi/releases/latest")
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/argoproj-labs/rollouts-plugin-trafficrouter-gatewayapi/releases/latest", nil)
+ if err != nil {
+ return "0.5.0" // Default version
+ }
+ resp, err := client.Do(req)
if err != nil {
return "0.5.0" // Default version
}
@@ -220,7 +225,7 @@ func getLatestVersion() string {
return "0.5.0"
}
-func configureGatewayPlugin(version, namespace string) GatewayPluginStatus {
+func configureGatewayPlugin(ctx context.Context, version, namespace string) GatewayPluginStatus {
arch, err := getSystemArchitecture()
if err != nil {
return GatewayPluginStatus{
@@ -230,7 +235,7 @@ func configureGatewayPlugin(version, namespace string) GatewayPluginStatus {
}
if version == "" {
- version = getLatestVersion()
+ version = getLatestVersion(ctx)
}
configMap := fmt.Sprintf(`apiVersion: v1
@@ -263,11 +268,12 @@ data:
tmpFile.Close()
// Apply the ConfigMap
- _, err = utils.RunCommandWithContext(context.Background(), "kubectl", []string{"apply", "-f", tmpFile.Name()})
+ cmdArgs := []string{"apply", "-f", tmpFile.Name()}
+ output, err := runArgoRolloutCommand(ctx, cmdArgs)
if err != nil {
return GatewayPluginStatus{
Installed: false,
- ErrorMessage: fmt.Sprintf("Failed to configure Gateway API plugin: %s", err.Error()),
+ ErrorMessage: fmt.Sprintf("Error applying Gateway API plugin config: %s. Output: %s", err.Error(), output),
}
}
@@ -304,7 +310,7 @@ func handleVerifyGatewayPlugin(ctx context.Context, request mcp.CallToolRequest)
}
// Configure plugin
- status := configureGatewayPlugin(version, namespace)
+ status := configureGatewayPlugin(ctx, version, namespace)
return mcp.NewToolResultText(status.String()), nil
}
@@ -348,48 +354,47 @@ func handleCheckPluginLogs(ctx context.Context, request mcp.CallToolRequest) (*m
return mcp.NewToolResultText(status.String()), nil
}
-func RegisterArgoTools(s *server.MCPServer, kubeconfig string) {
- kubeConfig = kubeconfig
+func RegisterTools(s *server.MCPServer) {
s.AddTool(mcp.NewTool("argo_verify_argo_rollouts_controller_install",
mcp.WithDescription("Verify that the Argo Rollouts controller is installed and running"),
mcp.WithString("namespace", mcp.Description("The namespace where Argo Rollouts is installed")),
mcp.WithString("label", mcp.Description("The label of the Argo Rollouts controller pods")),
- ), handleVerifyArgoRolloutsControllerInstall)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_argo_rollouts_controller_install", handleVerifyArgoRolloutsControllerInstall)))
s.AddTool(mcp.NewTool("argo_verify_kubectl_plugin_install",
mcp.WithDescription("Verify that the kubectl Argo Rollouts plugin is installed"),
- ), handleVerifyKubectlPluginInstall)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_kubectl_plugin_install", handleVerifyKubectlPluginInstall)))
s.AddTool(mcp.NewTool("argo_promote_rollout",
mcp.WithDescription("Promote a paused rollout to the next step"),
mcp.WithString("rollout_name", mcp.Description("The name of the rollout to promote"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the rollout")),
mcp.WithString("full", mcp.Description("Promote the rollout to the final step")),
- ), handlePromoteRollout)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_promote_rollout", handlePromoteRollout)))
s.AddTool(mcp.NewTool("argo_pause_rollout",
mcp.WithDescription("Pause a rollout"),
mcp.WithString("rollout_name", mcp.Description("The name of the rollout to pause"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the rollout")),
- ), handlePauseRollout)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_pause_rollout", handlePauseRollout)))
s.AddTool(mcp.NewTool("argo_set_rollout_image",
mcp.WithDescription("Set the image of a rollout"),
mcp.WithString("rollout_name", mcp.Description("The name of the rollout to set the image for"), mcp.Required()),
mcp.WithString("container_image", mcp.Description("The container image to set for the rollout"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the rollout")),
- ), handleSetRolloutImage)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_set_rollout_image", handleSetRolloutImage)))
s.AddTool(mcp.NewTool("argo_verify_gateway_plugin",
mcp.WithDescription("Verify the installation status of the Argo Rollouts Gateway API plugin"),
mcp.WithString("version", mcp.Description("The version of the plugin to check")),
mcp.WithString("namespace", mcp.Description("The namespace for the plugin resources")),
mcp.WithString("should_install", mcp.Description("Whether to install the plugin if not found")),
- ), handleVerifyGatewayPlugin)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_gateway_plugin", handleVerifyGatewayPlugin)))
s.AddTool(mcp.NewTool("argo_check_plugin_logs",
mcp.WithDescription("Check the logs of the Argo Rollouts Gateway API plugin"),
mcp.WithString("namespace", mcp.Description("The namespace of the plugin resources")),
mcp.WithString("timeout", mcp.Description("Timeout for log collection in seconds")),
- ), handleCheckPluginLogs)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_check_plugin_logs", handleCheckPluginLogs)))
}
diff --git a/pkg/argo/argo_test.go b/pkg/argo/argo_test.go
index 4a80823..0f90c39 100644
--- a/pkg/argo/argo_test.go
+++ b/pkg/argo/argo_test.go
@@ -5,7 +5,7 @@ import (
"strings"
"testing"
- "github.com/kagent-dev/tools/pkg/utils"
+ "github.com/kagent-dev/tools/internal/cmd"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -27,11 +27,11 @@ func getResultText(result *mcp.CallToolResult) string {
// Test Argo Rollouts Promote
func TestHandlePromoteRollout(t *testing.T) {
t.Run("promote rollout basic", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" promoted`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -52,15 +52,15 @@ func TestHandlePromoteRollout(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("promote rollout with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" promoted`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "-n", "production", "myapp"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "-n", "production", "myapp", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -77,15 +77,15 @@ func TestHandlePromoteRollout(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "promote", "-n", "production", "myapp"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "promote", "-n", "production", "myapp", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("promote rollout with full flag", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" fully promoted`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--full"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--full", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -102,12 +102,12 @@ func TestHandlePromoteRollout(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--full"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "promote", "myapp", "--full", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -125,9 +125,9 @@ func TestHandlePromoteRollout(t *testing.T) {
})
t.Run("kubectl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--timeout", "30s"}, "", assert.AnError)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -145,11 +145,11 @@ func TestHandlePromoteRollout(t *testing.T) {
// Test Argo Rollouts Pause
func TestHandlePauseRollout(t *testing.T) {
t.Run("pause rollout basic", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" paused`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "myapp"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "myapp", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -170,15 +170,15 @@ func TestHandlePauseRollout(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "pause", "myapp"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "pause", "myapp", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("pause rollout with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" paused`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "-n", "production", "myapp"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "-n", "production", "myapp", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -195,12 +195,12 @@ func TestHandlePauseRollout(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "pause", "-n", "production", "myapp"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "pause", "-n", "production", "myapp", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -221,11 +221,11 @@ func TestHandlePauseRollout(t *testing.T) {
// Test Argo Rollouts Set Image
func TestHandleSetRolloutImage(t *testing.T) {
t.Run("set rollout image basic", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" image updated`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -247,15 +247,15 @@ func TestHandleSetRolloutImage(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("set rollout image with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `rollout "myapp" image updated`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -273,12 +273,12 @@ func TestHandleSetRolloutImage(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("missing rollout_name parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -297,8 +297,8 @@ func TestHandleSetRolloutImage(t *testing.T) {
})
t.Run("missing container_image parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -334,7 +334,7 @@ func TestGetSystemArchitecture(t *testing.T) {
}
func TestGetLatestVersion(t *testing.T) {
- version := getLatestVersion()
+ version := getLatestVersion(context.Background())
if version == "" {
t.Error("Expected non-empty version")
}
@@ -367,11 +367,11 @@ func TestGatewayPluginStatus(t *testing.T) {
// Test Verify Gateway Plugin
func TestHandleVerifyGatewayPlugin(t *testing.T) {
t.Run("verify gateway plugin without install", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `gateway-api-plugin not found`
- mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -394,11 +394,11 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) {
})
t.Run("verify gateway plugin with custom namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `gateway-api-plugin-abc123`
- mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "custom-namespace", "-o", "yaml"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "custom-namespace", "-o", "yaml", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -423,11 +423,11 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) {
// Test Verify Argo Rollouts Controller Install
func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) {
t.Run("verify controller install", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `argo-rollouts-controller-manager-abc123`
- mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleVerifyArgoRolloutsControllerInstall(ctx, request)
@@ -444,11 +444,11 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) {
})
t.Run("verify controller install with custom namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `argo-rollouts-controller-manager-abc123`
- mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "custom-argo", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "custom-argo", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -469,11 +469,11 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) {
})
t.Run("verify controller install with custom label", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `argo-rollouts-controller-manager-abc123`
- mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app=custom-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app=custom-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -497,11 +497,11 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) {
// Test Verify Kubectl Plugin Install
func TestHandleVerifyKubectlPluginInstall(t *testing.T) {
t.Run("verify kubectl plugin install", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `kubectl-argo-rollouts`
- mock.AddCommandString("kubectl", []string{"argo", "rollouts", "version"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("kubectl", []string{"argo", "rollouts", "version", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleVerifyKubectlPluginInstall(ctx, request)
@@ -513,13 +513,13 @@ func TestHandleVerifyKubectlPluginInstall(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"argo", "rollouts", "version"}, callLog[0].Args)
+ assert.Equal(t, []string{"argo", "rollouts", "version", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("kubectl plugin command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("kubectl", []string{"plugin", "list"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("kubectl", []string{"plugin", "list", "--timeout", "30s"}, "", assert.AnError)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleVerifyKubectlPluginInstall(ctx, request)
diff --git a/pkg/cilium/cilium.go b/pkg/cilium/cilium.go
index a84cae3..6ad576c 100644
--- a/pkg/cilium/cilium.go
+++ b/pkg/cilium/cilium.go
@@ -3,21 +3,21 @@ package cilium
import (
"context"
"fmt"
- "strings"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/kagent-dev/tools/pkg/utils"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
-var kubeConfig = ""
-
func runCiliumCliWithContext(ctx context.Context, args ...string) (string, error) {
- if kubeConfig != "" {
- args = append([]string{"--kubeconfig", kubeConfig}, args...)
- }
- return utils.RunCommandWithContext(ctx, "cilium", args)
+ kubeconfigPath := utils.GetKubeconfig()
+ return commands.NewCommandBuilder("cilium").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
}
func handleCiliumStatusAndVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -200,70 +200,66 @@ func handleToggleClusterMesh(ctx context.Context, request mcp.CallToolRequest) (
return mcp.NewToolResultText(output), nil
}
-func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) {
- kubeConfig = kubeconfig
+func RegisterTools(s *server.MCPServer) {
- // Register debug tools
- RegisterCiliumDbgTools(s)
-
- // Register main Cilium tools
+ // Register all Cilium tools (main and debug)
s.AddTool(mcp.NewTool("cilium_status_and_version",
mcp.WithDescription("Get the status and version of Cilium installation"),
- ), handleCiliumStatusAndVersion)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_status_and_version", handleCiliumStatusAndVersion)))
s.AddTool(mcp.NewTool("cilium_upgrade_cilium",
mcp.WithDescription("Upgrade Cilium on the cluster"),
mcp.WithString("cluster_name", mcp.Description("The name of the cluster to upgrade Cilium on")),
mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")),
- ), handleUpgradeCilium)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_upgrade_cilium", handleUpgradeCilium)))
s.AddTool(mcp.NewTool("cilium_install_cilium",
mcp.WithDescription("Install Cilium on the cluster"),
mcp.WithString("cluster_name", mcp.Description("The name of the cluster to install Cilium on")),
mcp.WithString("cluster_id", mcp.Description("The ID of the cluster to install Cilium on")),
mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")),
- ), handleInstallCilium)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_install_cilium", handleInstallCilium)))
s.AddTool(mcp.NewTool("cilium_uninstall_cilium",
mcp.WithDescription("Uninstall Cilium from the cluster"),
- ), handleUninstallCilium)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_uninstall_cilium", handleUninstallCilium)))
s.AddTool(mcp.NewTool("cilium_connect_to_remote_cluster",
mcp.WithDescription("Connect to a remote cluster for cluster mesh"),
mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()),
mcp.WithString("context", mcp.Description("The kubectl context for the destination cluster")),
- ), handleConnectToRemoteCluster)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_connect_to_remote_cluster", handleConnectToRemoteCluster)))
s.AddTool(mcp.NewTool("cilium_disconnect_remote_cluster",
mcp.WithDescription("Disconnect from a remote cluster"),
mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()),
- ), handleDisconnectRemoteCluster)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_remote_cluster", handleDisconnectRemoteCluster)))
s.AddTool(mcp.NewTool("cilium_list_bgp_peers",
mcp.WithDescription("List BGP peers"),
- ), handleListBGPPeers)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_peers", handleListBGPPeers)))
s.AddTool(mcp.NewTool("cilium_list_bgp_routes",
mcp.WithDescription("List BGP routes"),
- ), handleListBGPRoutes)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_routes", handleListBGPRoutes)))
s.AddTool(mcp.NewTool("cilium_show_cluster_mesh_status",
mcp.WithDescription("Show cluster mesh status"),
- ), handleShowClusterMeshStatus)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_cluster_mesh_status", handleShowClusterMeshStatus)))
s.AddTool(mcp.NewTool("cilium_show_features_status",
mcp.WithDescription("Show Cilium features status"),
- ), handleShowFeaturesStatus)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_features_status", handleShowFeaturesStatus)))
s.AddTool(mcp.NewTool("cilium_toggle_hubble",
mcp.WithDescription("Enable or disable Hubble"),
mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")),
- ), handleToggleHubble)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_hubble", handleToggleHubble)))
s.AddTool(mcp.NewTool("cilium_toggle_cluster_mesh",
mcp.WithDescription("Enable or disable cluster mesh"),
mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")),
- ), handleToggleClusterMesh)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_cluster_mesh", handleToggleClusterMesh)))
// Add tools that are also needed by cilium-manager agent
s.AddTool(mcp.NewTool("cilium_get_daemon_status",
@@ -276,12 +272,12 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("show_all_redirects", mcp.Description("Whether to show all redirects")),
mcp.WithString("brief", mcp.Description("Whether to show a brief status")),
mcp.WithString("node_name", mcp.Description("The name of the node to get the daemon status for")),
- ), handleGetDaemonStatus)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_daemon_status", handleGetDaemonStatus)))
s.AddTool(mcp.NewTool("cilium_get_endpoints_list",
mcp.WithDescription("Get the list of all endpoints in the cluster"),
mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoints list for")),
- ), handleGetEndpointsList)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoints_list", handleGetEndpointsList)))
s.AddTool(mcp.NewTool("cilium_get_endpoint_details",
mcp.WithDescription("List the details of an endpoint in the cluster"),
@@ -289,7 +285,7 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")),
mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")),
mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")),
- ), handleGetEndpointDetails)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails)))
s.AddTool(mcp.NewTool("cilium_show_configuration_options",
mcp.WithDescription("Show Cilium configuration options"),
@@ -297,26 +293,26 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("list_read_only", mcp.Description("Whether to list read-only configuration options")),
mcp.WithString("list_options", mcp.Description("Whether to list options")),
mcp.WithString("node_name", mcp.Description("The name of the node to show the configuration options for")),
- ), handleShowConfigurationOptions)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_configuration_options", handleShowConfigurationOptions)))
s.AddTool(mcp.NewTool("cilium_toggle_configuration_option",
mcp.WithDescription("Toggle a Cilium configuration option"),
mcp.WithString("option", mcp.Description("The option to toggle"), mcp.Required()),
mcp.WithString("value", mcp.Description("The value to set the option to (true/false)"), mcp.Required()),
mcp.WithString("node_name", mcp.Description("The name of the node to toggle the configuration option for")),
- ), handleToggleConfigurationOption)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_configuration_option", handleToggleConfigurationOption)))
s.AddTool(mcp.NewTool("cilium_list_services",
mcp.WithDescription("List services for the cluster"),
mcp.WithString("show_cluster_mesh_affinity", mcp.Description("Whether to show cluster mesh affinity")),
mcp.WithString("node_name", mcp.Description("The name of the node to get the services for")),
- ), handleListServices)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_services", handleListServices)))
s.AddTool(mcp.NewTool("cilium_get_service_information",
mcp.WithDescription("Get information about a service in the cluster"),
mcp.WithString("service_id", mcp.Description("The ID of the service to get information about"), mcp.Required()),
mcp.WithString("node_name", mcp.Description("The name of the node to get the service information for")),
- ), handleGetServiceInformation)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_service_information", handleGetServiceInformation)))
s.AddTool(mcp.NewTool("cilium_update_service",
mcp.WithDescription("Update a service in the cluster"),
@@ -335,35 +331,258 @@ func RegisterCiliumTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("protocol", mcp.Description("The protocol to update the service with")),
mcp.WithString("states", mcp.Description("The states to update the service with")),
mcp.WithString("node_name", mcp.Description("The name of the node to update the service on")),
- ), handleUpdateService)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_service", handleUpdateService)))
s.AddTool(mcp.NewTool("cilium_delete_service",
mcp.WithDescription("Delete a service from the cluster"),
mcp.WithString("service_id", mcp.Description("The ID of the service to delete")),
mcp.WithString("all", mcp.Description("Whether to delete all services (true/false)")),
mcp.WithString("node_name", mcp.Description("The name of the node to delete the service from")),
- ), handleDeleteService)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_service", handleDeleteService)))
+
+ // Debug tools (previously in RegisterCiliumDbgTools)
+ s.AddTool(mcp.NewTool("cilium_get_endpoint_details",
+ mcp.WithDescription("List the details of an endpoint in the cluster"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")),
+ mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")),
+ mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails)))
+
+ s.AddTool(mcp.NewTool("cilium_get_endpoint_logs",
+ mcp.WithDescription("Get the logs of an endpoint in the cluster"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get logs for"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint logs for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_logs", handleGetEndpointLogs)))
+
+ s.AddTool(mcp.NewTool("cilium_get_endpoint_health",
+ mcp.WithDescription("Get the health of an endpoint in the cluster"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get health for"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint health for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_health", handleGetEndpointHealth)))
+
+ s.AddTool(mcp.NewTool("cilium_manage_endpoint_labels",
+ mcp.WithDescription("Manage the labels (add or delete) of an endpoint in the cluster"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage labels for"), mcp.Required()),
+ mcp.WithString("labels", mcp.Description("Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"), mcp.Required()),
+ mcp.WithString("action", mcp.Description("The action to perform on the labels (add or delete)"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint labels on")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_labels", handleManageEndpointLabels)))
+
+ s.AddTool(mcp.NewTool("cilium_manage_endpoint_config",
+ mcp.WithDescription("Manage the configuration of an endpoint in the cluster"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage configuration for"), mcp.Required()),
+ mcp.WithString("config", mcp.Description("The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint configuration on")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_config", handleManageEndpointConfiguration)))
+
+ s.AddTool(mcp.NewTool("cilium_disconnect_endpoint",
+ mcp.WithDescription("Disconnect an endpoint from the network"),
+ mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to disconnect"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to disconnect the endpoint from")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_endpoint", handleDisconnectEndpoint)))
+
+ s.AddTool(mcp.NewTool("cilium_list_identities",
+ mcp.WithDescription("List all identities in the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to list the identities for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_identities", handleListIdentities)))
+
+ s.AddTool(mcp.NewTool("cilium_get_identity_details",
+ mcp.WithDescription("Get the details of an identity in the cluster"),
+ mcp.WithString("identity_id", mcp.Description("The ID of the identity to get details for"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the identity details for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_identity_details", handleGetIdentityDetails)))
+
+ s.AddTool(mcp.NewTool("cilium_request_debugging_information",
+ mcp.WithDescription("Request debugging information for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the debugging information for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_request_debugging_information", handleRequestDebuggingInformation)))
+
+ s.AddTool(mcp.NewTool("cilium_display_encryption_state",
+ mcp.WithDescription("Display the encryption state for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the encryption state for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_encryption_state", handleDisplayEncryptionState)))
+
+ s.AddTool(mcp.NewTool("cilium_flush_ipsec_state",
+ mcp.WithDescription("Flush the IPsec state for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to flush the IPsec state for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_flush_ipsec_state", handleFlushIPsecState)))
+
+ s.AddTool(mcp.NewTool("cilium_list_envoy_config",
+ mcp.WithDescription("List the Envoy configuration for a resource in the cluster"),
+ mcp.WithString("resource_name", mcp.Description("The name of the resource to get the Envoy configuration for"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the Envoy configuration for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_envoy_config", handleListEnvoyConfig)))
+
+ s.AddTool(mcp.NewTool("cilium_fqdn_cache",
+ mcp.WithDescription("Manage the FQDN cache for the cluster"),
+ mcp.WithString("command", mcp.Description("The command to perform on the FQDN cache (list, clean, or a specific command)"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to manage the FQDN cache for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_fqdn_cache", handleFQDNCache)))
+
+ s.AddTool(mcp.NewTool("cilium_show_dns_names",
+ mcp.WithDescription("Show the DNS names for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the DNS names for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_dns_names", handleShowDNSNames)))
+
+ s.AddTool(mcp.NewTool("cilium_list_ip_addresses",
+ mcp.WithDescription("List the IP addresses for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the IP addresses for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_ip_addresses", handleListIPAddresses)))
+
+ s.AddTool(mcp.NewTool("cilium_show_ip_cache_information",
+ mcp.WithDescription("Show the IP cache information for the cluster"),
+ mcp.WithString("cidr", mcp.Description("The CIDR of the IP to get cache information for")),
+ mcp.WithString("labels", mcp.Description("The labels of the IP to get cache information for")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the IP cache information for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_ip_cache_information", handleShowIPCacheInformation)))
+
+ s.AddTool(mcp.NewTool("cilium_delete_key_from_kv_store",
+ mcp.WithDescription("Delete a key from the kvstore for the cluster"),
+ mcp.WithString("key", mcp.Description("The key to delete from the kvstore"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to delete the key from")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_key_from_kv_store", handleDeleteKeyFromKVStore)))
+
+ s.AddTool(mcp.NewTool("cilium_get_kv_store_key",
+ mcp.WithDescription("Get a key from the kvstore for the cluster"),
+ mcp.WithString("key", mcp.Description("The key to get from the kvstore"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the key from")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_kv_store_key", handleGetKVStoreKey)))
+
+ s.AddTool(mcp.NewTool("cilium_set_kv_store_key",
+ mcp.WithDescription("Set a key in the kvstore for the cluster"),
+ mcp.WithString("key", mcp.Description("The key to set in the kvstore"), mcp.Required()),
+ mcp.WithString("value", mcp.Description("The value to set in the kvstore"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to set the key in")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_set_kv_store_key", handleSetKVStoreKey)))
+
+ s.AddTool(mcp.NewTool("cilium_show_load_information",
+ mcp.WithDescription("Show load information for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the load information for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_load_information", handleShowLoadInformation)))
+
+ s.AddTool(mcp.NewTool("cilium_list_local_redirect_policies",
+ mcp.WithDescription("List local redirect policies for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the local redirect policies for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_local_redirect_policies", handleListLocalRedirectPolicies)))
+
+ s.AddTool(mcp.NewTool("cilium_list_bpf_map_events",
+ mcp.WithDescription("List BPF map events for the cluster"),
+ mcp.WithString("map_name", mcp.Description("The name of the BPF map to get events for"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map events for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_map_events", handleListBPFMapEvents)))
+
+ s.AddTool(mcp.NewTool("cilium_get_bpf_map",
+ mcp.WithDescription("Get BPF map for the cluster"),
+ mcp.WithString("map_name", mcp.Description("The name of the BPF map to get"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_bpf_map", handleGetBPFMap)))
+
+ s.AddTool(mcp.NewTool("cilium_list_bpf_maps",
+ mcp.WithDescription("List BPF maps for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF maps for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_maps", handleListBPFMaps)))
+
+ s.AddTool(mcp.NewTool("cilium_list_metrics",
+ mcp.WithDescription("List metrics for the cluster"),
+ mcp.WithString("match_pattern", mcp.Description("The match pattern to filter metrics by")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the metrics for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_metrics", handleListMetrics)))
+
+ s.AddTool(mcp.NewTool("cilium_list_cluster_nodes",
+ mcp.WithDescription("List cluster nodes for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the cluster nodes for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_cluster_nodes", handleListClusterNodes)))
+
+ s.AddTool(mcp.NewTool("cilium_list_node_ids",
+ mcp.WithDescription("List node IDs for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the node IDs for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_node_ids", handleListNodeIds)))
+
+ s.AddTool(mcp.NewTool("cilium_display_policy_node_information",
+ mcp.WithDescription("Display policy node information for the cluster"),
+ mcp.WithString("labels", mcp.Description("The labels to get policy node information for")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get policy node information for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_policy_node_information", handleDisplayPolicyNodeInformation)))
+
+ s.AddTool(mcp.NewTool("cilium_delete_policy_rules",
+ mcp.WithDescription("Delete policy rules for the cluster"),
+ mcp.WithString("labels", mcp.Description("The labels to delete policy rules for")),
+ mcp.WithString("all", mcp.Description("Whether to delete all policy rules")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to delete policy rules for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_policy_rules", handleDeletePolicyRules)))
+
+ s.AddTool(mcp.NewTool("cilium_display_selectors",
+ mcp.WithDescription("Display selectors for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get selectors for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_selectors", handleDisplaySelectors)))
+
+ s.AddTool(mcp.NewTool("cilium_list_xdp_cidr_filters",
+ mcp.WithDescription("List XDP CIDR filters for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the XDP CIDR filters for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_xdp_cidr_filters", handleListXDPCIDRFilters)))
+
+ s.AddTool(mcp.NewTool("cilium_update_xdp_cidr_filters",
+ mcp.WithDescription("Update XDP CIDR filters for the cluster"),
+ mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to update the XDP filters for"), mcp.Required()),
+ mcp.WithString("revision", mcp.Description("The revision of the XDP filters to update")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to update the XDP filters for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_xdp_cidr_filters", handleUpdateXDPCIDRFilters)))
+
+ s.AddTool(mcp.NewTool("cilium_delete_xdp_cidr_filters",
+ mcp.WithDescription("Delete XDP CIDR filters for the cluster"),
+ mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to delete the XDP filters for"), mcp.Required()),
+ mcp.WithString("revision", mcp.Description("The revision of the XDP filters to delete")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to delete the XDP filters for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_xdp_cidr_filters", handleDeleteXDPCIDRFilters)))
+
+ s.AddTool(mcp.NewTool("cilium_validate_cilium_network_policies",
+ mcp.WithDescription("Validate Cilium network policies for the cluster"),
+ mcp.WithString("enable_k8s", mcp.Description("Whether to enable k8s API discovery")),
+ mcp.WithString("enable_k8s_api_discovery", mcp.Description("Whether to enable k8s API discovery")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to validate the Cilium network policies for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_validate_cilium_network_policies", handleValidateCiliumNetworkPolicies)))
+
+ s.AddTool(mcp.NewTool("cilium_list_pcap_recorders",
+ mcp.WithDescription("List PCAP recorders for the cluster"),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorders for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_pcap_recorders", handleListPCAPRecorders)))
+
+ s.AddTool(mcp.NewTool("cilium_get_pcap_recorder",
+ mcp.WithDescription("Get a PCAP recorder for the cluster"),
+ mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to get"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorder for")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_pcap_recorder", handleGetPCAPRecorder)))
+
+ s.AddTool(mcp.NewTool("cilium_delete_pcap_recorder",
+ mcp.WithDescription("Delete a PCAP recorder for the cluster"),
+ mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to delete"), mcp.Required()),
+ mcp.WithString("node_name", mcp.Description("The name of the node to delete the PCAP recorder from")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_pcap_recorder", handleDeletePCAPRecorder)))
+
+ s.AddTool(mcp.NewTool("cilium_update_pcap_recorder",
+ mcp.WithDescription("Update a PCAP recorder for the cluster"),
+ mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to update"), mcp.Required()),
+ mcp.WithString("filters", mcp.Description("The filters to update the PCAP recorder with"), mcp.Required()),
+ mcp.WithString("caplen", mcp.Description("The caplen to update the PCAP recorder with")),
+ mcp.WithString("id", mcp.Description("The id to update the PCAP recorder with")),
+ mcp.WithString("node_name", mcp.Description("The name of the node to update the PCAP recorder on")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_pcap_recorder", handleUpdatePCAPRecorder)))
}
// -- Debug Tools --
func getCiliumPodNameWithContext(ctx context.Context, nodeName string) (string, error) {
- args := []string{"get", "pod", "-l", "k8s-app=cilium", "-o", "name", "-n", "kube-system"}
- if nodeName != "" {
- args = append(args, "--field-selector", "spec.nodeName="+nodeName)
- }
- podName, err := utils.RunCommandWithContext(ctx, "kubectl", args)
- if err != nil {
- return "", fmt.Errorf("failed to get cilium pod name: %v", err)
- }
- if podName == "" {
- return "", fmt.Errorf("no cilium pod found")
- }
- return strings.TrimSpace(podName), nil
+ args := []string{"get", "pods", "-n", "kube-system", "--selector=k8s-app=cilium", fmt.Sprintf("--field-selector=spec.nodeName=%s", nodeName), "-o", "jsonpath={.items[0].metadata.name}"}
+ kubeconfigPath := utils.GetKubeconfig()
+ return commands.NewCommandBuilder("kubectl").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
}
-func runCiliumDbgCommand(command, nodeName string) (string, error) {
- return runCiliumDbgCommandWithContext(context.Background(), command, nodeName)
+func runCiliumDbgCommand(ctx context.Context, command, nodeName string) (string, error) {
+ return runCiliumDbgCommandWithContext(ctx, command, nodeName)
}
func runCiliumDbgCommandWithContext(ctx context.Context, command, nodeName string) (string, error) {
@@ -371,10 +590,12 @@ func runCiliumDbgCommandWithContext(ctx context.Context, command, nodeName strin
if err != nil {
return "", err
}
- cmdParts := strings.Fields(command)
- args := []string{"exec", "-it", podName, "-n", "kube-system", "--", "cilium-dbg"}
- args = append(args, cmdParts...)
- return utils.RunCommandWithContext(ctx, "kubectl", args)
+ args := []string{"exec", "-it", podName, "--", "cilium-dbg", command}
+ kubeconfigPath := utils.GetKubeconfig()
+ return commands.NewCommandBuilder("kubectl").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
}
func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -392,7 +613,7 @@ func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest)
return mcp.NewToolResultError("either endpoint_id or labels must be provided"), nil
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint details: %v", err)), nil
}
@@ -408,7 +629,7 @@ func handleGetEndpointLogs(ctx context.Context, request mcp.CallToolRequest) (*m
}
cmd := fmt.Sprintf("endpoint logs %s", endpointID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint logs: %v", err)), nil
}
@@ -424,7 +645,7 @@ func handleGetEndpointHealth(ctx context.Context, request mcp.CallToolRequest) (
}
cmd := fmt.Sprintf("endpoint health %s", endpointID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint health: %v", err)), nil
}
@@ -442,7 +663,7 @@ func handleManageEndpointLabels(ctx context.Context, request mcp.CallToolRequest
}
cmd := fmt.Sprintf("endpoint labels %s --%s %s", endpointID, action, labels)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to manage endpoint labels: %v", err)), nil
}
@@ -462,7 +683,7 @@ func handleManageEndpointConfiguration(ctx context.Context, request mcp.CallTool
}
command := fmt.Sprintf("endpoint config %s %s", endpointID, config)
- output, err := runCiliumDbgCommand(command, nodeName)
+ output, err := runCiliumDbgCommand(ctx, command, nodeName)
if err != nil {
return mcp.NewToolResultError("Error managing endpoint configuration: " + err.Error()), nil
}
@@ -479,7 +700,7 @@ func handleDisconnectEndpoint(ctx context.Context, request mcp.CallToolRequest)
}
cmd := fmt.Sprintf("endpoint disconnect %s", endpointID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to disconnect endpoint: %v", err)), nil
}
@@ -489,7 +710,7 @@ func handleDisconnectEndpoint(ctx context.Context, request mcp.CallToolRequest)
func handleGetEndpointsList(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("endpoint list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "endpoint list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoints list: %v", err)), nil
}
@@ -499,7 +720,7 @@ func handleGetEndpointsList(ctx context.Context, request mcp.CallToolRequest) (*
func handleListIdentities(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("identity list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "identity list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list identities: %v", err)), nil
}
@@ -515,7 +736,7 @@ func handleGetIdentityDetails(ctx context.Context, request mcp.CallToolRequest)
}
cmd := fmt.Sprintf("identity get %s", identityID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get identity details: %v", err)), nil
}
@@ -539,7 +760,7 @@ func handleShowConfigurationOptions(ctx context.Context, request mcp.CallToolReq
cmd = "endpoint config"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to show configuration options: %v", err)), nil
}
@@ -561,7 +782,7 @@ func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRe
}
cmd := fmt.Sprintf("endpoint config %s=%s", option, valueStr)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to toggle configuration option: %v", err)), nil
}
@@ -571,7 +792,7 @@ func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRe
func handleRequestDebuggingInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("debuginfo", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "debuginfo", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to request debugging information: %v", err)), nil
}
@@ -581,7 +802,7 @@ func handleRequestDebuggingInformation(ctx context.Context, request mcp.CallTool
func handleDisplayEncryptionState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("encrypt status", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "encrypt status", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to display encryption state: %v", err)), nil
}
@@ -591,7 +812,7 @@ func handleDisplayEncryptionState(ctx context.Context, request mcp.CallToolReque
func handleFlushIPsecState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("encrypt flush -f", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "encrypt flush -f", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to flush IPsec state: %v", err)), nil
}
@@ -607,7 +828,7 @@ func handleListEnvoyConfig(ctx context.Context, request mcp.CallToolRequest) (*m
}
cmd := fmt.Sprintf("envoy admin %s", resourceName)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list Envoy config: %v", err)), nil
}
@@ -620,12 +841,12 @@ func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal
var cmd string
if command == "clean" {
- cmd = "fqdn cache clean -f"
+ cmd = "fqdn cache clean"
} else {
- cmd = fmt.Sprintf("fqdn cache %s", command)
+ cmd = "fqdn cache list"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to manage FQDN cache: %v", err)), nil
}
@@ -635,7 +856,7 @@ func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal
func handleShowDNSNames(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("dns names", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "dns names", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to show DNS names: %v", err)), nil
}
@@ -645,7 +866,7 @@ func handleShowDNSNames(ctx context.Context, request mcp.CallToolRequest) (*mcp.
func handleListIPAddresses(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("ip list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "ip list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list IP addresses: %v", err)), nil
}
@@ -666,7 +887,7 @@ func handleShowIPCacheInformation(ctx context.Context, request mcp.CallToolReque
return mcp.NewToolResultError("either cidr or labels must be provided"), nil
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to show IP cache information: %v", err)), nil
}
@@ -682,7 +903,7 @@ func handleDeleteKeyFromKVStore(ctx context.Context, request mcp.CallToolRequest
}
cmd := fmt.Sprintf("kvstore delete %s", key)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to delete key from kvstore: %v", err)), nil
}
@@ -698,7 +919,7 @@ func handleGetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp
}
cmd := fmt.Sprintf("kvstore get %s", key)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get key from kvstore: %v", err)), nil
}
@@ -715,7 +936,7 @@ func handleSetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp
}
cmd := fmt.Sprintf("kvstore set %s=%s", key, value)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to set key in kvstore: %v", err)), nil
}
@@ -725,7 +946,7 @@ func handleSetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp
func handleShowLoadInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("loadinfo", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "loadinfo", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to show load information: %v", err)), nil
}
@@ -735,7 +956,7 @@ func handleShowLoadInformation(ctx context.Context, request mcp.CallToolRequest)
func handleListLocalRedirectPolicies(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("lrp list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "lrp list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list local redirect policies: %v", err)), nil
}
@@ -751,7 +972,7 @@ func handleListBPFMapEvents(ctx context.Context, request mcp.CallToolRequest) (*
}
cmd := fmt.Sprintf("bpf map events %s", mapName)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF map events: %v", err)), nil
}
@@ -767,7 +988,7 @@ func handleGetBPFMap(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal
}
cmd := fmt.Sprintf("bpf map get %s", mapName)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get BPF map: %v", err)), nil
}
@@ -777,7 +998,7 @@ func handleGetBPFMap(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal
func handleListBPFMaps(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("bpf map list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "bpf map list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF maps: %v", err)), nil
}
@@ -795,7 +1016,7 @@ func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.C
cmd = "metrics list"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list metrics: %v", err)), nil
}
@@ -805,7 +1026,7 @@ func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.C
func handleListClusterNodes(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("nodes list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "nodes list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list cluster nodes: %v", err)), nil
}
@@ -815,7 +1036,7 @@ func handleListClusterNodes(ctx context.Context, request mcp.CallToolRequest) (*
func handleListNodeIds(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("nodeid list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "nodeid list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list node IDs: %v", err)), nil
}
@@ -833,7 +1054,7 @@ func handleDisplayPolicyNodeInformation(ctx context.Context, request mcp.CallToo
cmd = "policy get"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to display policy node information: %v", err)), nil
}
@@ -854,7 +1075,7 @@ func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) (
return mcp.NewToolResultError("either labels or all=true must be provided"), nil
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to delete policy rules: %v", err)), nil
}
@@ -864,7 +1085,7 @@ func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) (
func handleDisplaySelectors(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("policy selectors", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "policy selectors", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to display selectors: %v", err)), nil
}
@@ -874,7 +1095,7 @@ func handleDisplaySelectors(ctx context.Context, request mcp.CallToolRequest) (*
func handleListXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("prefilter list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "prefilter list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list XDP CIDR filters: %v", err)), nil
}
@@ -897,7 +1118,7 @@ func handleUpdateXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest
cmd = fmt.Sprintf("prefilter update --cidr %s", cidrPrefixes)
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to update XDP CIDR filters: %v", err)), nil
}
@@ -920,7 +1141,7 @@ func handleDeleteXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest
cmd = fmt.Sprintf("prefilter delete --cidr %s", cidrPrefixes)
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to delete XDP CIDR filters: %v", err)), nil
}
@@ -940,7 +1161,7 @@ func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallTo
cmd += " --enable-k8s-api-discovery"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to validate Cilium network policies: %v", err)), nil
}
@@ -950,7 +1171,7 @@ func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallTo
func handleListPCAPRecorders(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
nodeName := mcp.ParseString(request, "node_name", "")
- output, err := runCiliumDbgCommand("recorder list", nodeName)
+ output, err := runCiliumDbgCommand(ctx, "recorder list", nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list PCAP recorders: %v", err)), nil
}
@@ -966,7 +1187,7 @@ func handleGetPCAPRecorder(ctx context.Context, request mcp.CallToolRequest) (*m
}
cmd := fmt.Sprintf("recorder get %s", recorderID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get PCAP recorder: %v", err)), nil
}
@@ -982,7 +1203,7 @@ func handleDeletePCAPRecorder(ctx context.Context, request mcp.CallToolRequest)
}
cmd := fmt.Sprintf("recorder delete %s", recorderID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to delete PCAP recorder: %v", err)), nil
}
@@ -1001,7 +1222,7 @@ func handleUpdatePCAPRecorder(ctx context.Context, request mcp.CallToolRequest)
}
cmd := fmt.Sprintf("recorder update %s --filters %s --caplen %s --id %s", recorderID, filters, caplen, id)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to update PCAP recorder: %v", err)), nil
}
@@ -1019,7 +1240,7 @@ func handleListServices(ctx context.Context, request mcp.CallToolRequest) (*mcp.
cmd = "service list"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to list services: %v", err)), nil
}
@@ -1035,7 +1256,7 @@ func handleGetServiceInformation(ctx context.Context, request mcp.CallToolReques
}
cmd := fmt.Sprintf("service get %s", serviceID)
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get service information: %v", err)), nil
}
@@ -1056,7 +1277,7 @@ func handleDeleteService(ctx context.Context, request mcp.CallToolRequest) (*mcp
return mcp.NewToolResultError("either service_id or all=true must be provided"), nil
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to delete service: %v", err)), nil
}
@@ -1115,7 +1336,7 @@ func handleUpdateService(ctx context.Context, request mcp.CallToolRequest) (*mcp
cmd += " --local-redirect"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to update service: %v", err)), nil
}
@@ -1155,239 +1376,9 @@ func handleGetDaemonStatus(ctx context.Context, request mcp.CallToolRequest) (*m
cmd += " --brief"
}
- output, err := runCiliumDbgCommand(cmd, nodeName)
+ output, err := runCiliumDbgCommand(ctx, cmd, nodeName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get daemon status: %v", err)), nil
}
return mcp.NewToolResultText(output), nil
}
-
-func RegisterCiliumDbgTools(s *server.MCPServer) {
- s.AddTool(mcp.NewTool("cilium_get_endpoint_details",
- mcp.WithDescription("List the details of an endpoint in the cluster"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")),
- mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")),
- mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")),
- ), handleGetEndpointDetails)
-
- s.AddTool(mcp.NewTool("cilium_get_endpoint_logs",
- mcp.WithDescription("Get the logs of an endpoint in the cluster"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get logs for"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint logs for")),
- ), handleGetEndpointLogs)
-
- s.AddTool(mcp.NewTool("cilium_get_endpoint_health",
- mcp.WithDescription("Get the health of an endpoint in the cluster"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get health for"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint health for")),
- ), handleGetEndpointHealth)
-
- s.AddTool(mcp.NewTool("cilium_manage_endpoint_labels",
- mcp.WithDescription("Manage the labels (add or delete) of an endpoint in the cluster"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage labels for"), mcp.Required()),
- mcp.WithString("labels", mcp.Description("Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"), mcp.Required()),
- mcp.WithString("action", mcp.Description("The action to perform on the labels (add or delete)"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint labels on")),
- ), handleManageEndpointLabels)
-
- s.AddTool(mcp.NewTool("cilium_manage_endpoint_config",
- mcp.WithDescription("Manage the configuration of an endpoint in the cluster"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage configuration for"), mcp.Required()),
- mcp.WithString("config", mcp.Description("The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint configuration on")),
- ), handleManageEndpointConfiguration)
-
- s.AddTool(mcp.NewTool("cilium_disconnect_endpoint",
- mcp.WithDescription("Disconnect an endpoint from the network"),
- mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to disconnect"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to disconnect the endpoint from")),
- ), handleDisconnectEndpoint)
-
- s.AddTool(mcp.NewTool("cilium_list_identities",
- mcp.WithDescription("List all identities in the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to list the identities for")),
- ), handleListIdentities)
-
- s.AddTool(mcp.NewTool("cilium_get_identity_details",
- mcp.WithDescription("Get the details of an identity in the cluster"),
- mcp.WithString("identity_id", mcp.Description("The ID of the identity to get details for"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the identity details for")),
- ), handleGetIdentityDetails)
-
- s.AddTool(mcp.NewTool("cilium_request_debugging_information",
- mcp.WithDescription("Request debugging information for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the debugging information for")),
- ), handleRequestDebuggingInformation)
-
- s.AddTool(mcp.NewTool("cilium_display_encryption_state",
- mcp.WithDescription("Display the encryption state for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the encryption state for")),
- ), handleDisplayEncryptionState)
-
- s.AddTool(mcp.NewTool("cilium_flush_ipsec_state",
- mcp.WithDescription("Flush the IPsec state for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to flush the IPsec state for")),
- ), handleFlushIPsecState)
-
- s.AddTool(mcp.NewTool("cilium_list_envoy_config",
- mcp.WithDescription("List the Envoy configuration for a resource in the cluster"),
- mcp.WithString("resource_name", mcp.Description("The name of the resource to get the Envoy configuration for"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the Envoy configuration for")),
- ), handleListEnvoyConfig)
-
- s.AddTool(mcp.NewTool("cilium_fqdn_cache",
- mcp.WithDescription("Manage the FQDN cache for the cluster"),
- mcp.WithString("command", mcp.Description("The command to perform on the FQDN cache (list, clean, or a specific command)"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to manage the FQDN cache for")),
- ), handleFQDNCache)
-
- s.AddTool(mcp.NewTool("cilium_show_dns_names",
- mcp.WithDescription("Show the DNS names for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the DNS names for")),
- ), handleShowDNSNames)
-
- s.AddTool(mcp.NewTool("cilium_list_ip_addresses",
- mcp.WithDescription("List the IP addresses for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the IP addresses for")),
- ), handleListIPAddresses)
-
- s.AddTool(mcp.NewTool("cilium_show_ip_cache_information",
- mcp.WithDescription("Show the IP cache information for the cluster"),
- mcp.WithString("cidr", mcp.Description("The CIDR of the IP to get cache information for")),
- mcp.WithString("labels", mcp.Description("The labels of the IP to get cache information for")),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the IP cache information for")),
- ), handleShowIPCacheInformation)
-
- s.AddTool(mcp.NewTool("cilium_delete_key_from_kv_store",
- mcp.WithDescription("Delete a key from the kvstore for the cluster"),
- mcp.WithString("key", mcp.Description("The key to delete from the kvstore"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to delete the key from")),
- ), handleDeleteKeyFromKVStore)
-
- s.AddTool(mcp.NewTool("cilium_get_kv_store_key",
- mcp.WithDescription("Get a key from the kvstore for the cluster"),
- mcp.WithString("key", mcp.Description("The key to get from the kvstore"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the key from")),
- ), handleGetKVStoreKey)
-
- s.AddTool(mcp.NewTool("cilium_set_kv_store_key",
- mcp.WithDescription("Set a key in the kvstore for the cluster"),
- mcp.WithString("key", mcp.Description("The key to set in the kvstore"), mcp.Required()),
- mcp.WithString("value", mcp.Description("The value to set in the kvstore"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to set the key in")),
- ), handleSetKVStoreKey)
-
- s.AddTool(mcp.NewTool("cilium_show_load_information",
- mcp.WithDescription("Show load information for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the load information for")),
- ), handleShowLoadInformation)
-
- s.AddTool(mcp.NewTool("cilium_list_local_redirect_policies",
- mcp.WithDescription("List local redirect policies for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the local redirect policies for")),
- ), handleListLocalRedirectPolicies)
-
- s.AddTool(mcp.NewTool("cilium_list_bpf_map_events",
- mcp.WithDescription("List BPF map events for the cluster"),
- mcp.WithString("map_name", mcp.Description("The name of the BPF map to get events for"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map events for")),
- ), handleListBPFMapEvents)
-
- s.AddTool(mcp.NewTool("cilium_get_bpf_map",
- mcp.WithDescription("Get BPF map for the cluster"),
- mcp.WithString("map_name", mcp.Description("The name of the BPF map to get"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map for")),
- ), handleGetBPFMap)
-
- s.AddTool(mcp.NewTool("cilium_list_bpf_maps",
- mcp.WithDescription("List BPF maps for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF maps for")),
- ), handleListBPFMaps)
-
- s.AddTool(mcp.NewTool("cilium_list_metrics",
- mcp.WithDescription("List metrics for the cluster"),
- mcp.WithString("match_pattern", mcp.Description("The match pattern to filter metrics by")),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the metrics for")),
- ), handleListMetrics)
-
- s.AddTool(mcp.NewTool("cilium_list_cluster_nodes",
- mcp.WithDescription("List cluster nodes for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the cluster nodes for")),
- ), handleListClusterNodes)
-
- s.AddTool(mcp.NewTool("cilium_list_node_ids",
- mcp.WithDescription("List node IDs for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the node IDs for")),
- ), handleListNodeIds)
-
- s.AddTool(mcp.NewTool("cilium_display_policy_node_information",
- mcp.WithDescription("Display policy node information for the cluster"),
- mcp.WithString("labels", mcp.Description("The labels to get policy node information for")),
- mcp.WithString("node_name", mcp.Description("The name of the node to get policy node information for")),
- ), handleDisplayPolicyNodeInformation)
-
- s.AddTool(mcp.NewTool("cilium_delete_policy_rules",
- mcp.WithDescription("Delete policy rules for the cluster"),
- mcp.WithString("labels", mcp.Description("The labels to delete policy rules for")),
- mcp.WithString("all", mcp.Description("Whether to delete all policy rules")),
- mcp.WithString("node_name", mcp.Description("The name of the node to delete policy rules for")),
- ), handleDeletePolicyRules)
-
- s.AddTool(mcp.NewTool("cilium_display_selectors",
- mcp.WithDescription("Display selectors for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get selectors for")),
- ), handleDisplaySelectors)
-
- s.AddTool(mcp.NewTool("cilium_list_xdp_cidr_filters",
- mcp.WithDescription("List XDP CIDR filters for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the XDP CIDR filters for")),
- ), handleListXDPCIDRFilters)
-
- s.AddTool(mcp.NewTool("cilium_update_xdp_cidr_filters",
- mcp.WithDescription("Update XDP CIDR filters for the cluster"),
- mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to update the XDP filters for"), mcp.Required()),
- mcp.WithString("revision", mcp.Description("The revision of the XDP filters to update")),
- mcp.WithString("node_name", mcp.Description("The name of the node to update the XDP filters for")),
- ), handleUpdateXDPCIDRFilters)
-
- s.AddTool(mcp.NewTool("cilium_delete_xdp_cidr_filters",
- mcp.WithDescription("Delete XDP CIDR filters for the cluster"),
- mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to delete the XDP filters for"), mcp.Required()),
- mcp.WithString("revision", mcp.Description("The revision of the XDP filters to delete")),
- mcp.WithString("node_name", mcp.Description("The name of the node to delete the XDP filters for")),
- ), handleDeleteXDPCIDRFilters)
-
- s.AddTool(mcp.NewTool("cilium_validate_cilium_network_policies",
- mcp.WithDescription("Validate Cilium network policies for the cluster"),
- mcp.WithString("enable_k8s", mcp.Description("Whether to enable k8s API discovery")),
- mcp.WithString("enable_k8s_api_discovery", mcp.Description("Whether to enable k8s API discovery")),
- mcp.WithString("node_name", mcp.Description("The name of the node to validate the Cilium network policies for")),
- ), handleValidateCiliumNetworkPolicies)
-
- s.AddTool(mcp.NewTool("cilium_list_pcap_recorders",
- mcp.WithDescription("List PCAP recorders for the cluster"),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorders for")),
- ), handleListPCAPRecorders)
-
- s.AddTool(mcp.NewTool("cilium_get_pcap_recorder",
- mcp.WithDescription("Get a PCAP recorder for the cluster"),
- mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to get"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorder for")),
- ), handleGetPCAPRecorder)
-
- s.AddTool(mcp.NewTool("cilium_delete_pcap_recorder",
- mcp.WithDescription("Delete a PCAP recorder for the cluster"),
- mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to delete"), mcp.Required()),
- mcp.WithString("node_name", mcp.Description("The name of the node to delete the PCAP recorder from")),
- ), handleDeletePCAPRecorder)
-
- s.AddTool(mcp.NewTool("cilium_update_pcap_recorder",
- mcp.WithDescription("Update a PCAP recorder for the cluster"),
- mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to update"), mcp.Required()),
- mcp.WithString("filters", mcp.Description("The filters to update the PCAP recorder with"), mcp.Required()),
- mcp.WithString("caplen", mcp.Description("The caplen to update the PCAP recorder with")),
- mcp.WithString("id", mcp.Description("The id to update the PCAP recorder with")),
- mcp.WithString("node_name", mcp.Description("The name of the node to update the PCAP recorder on")),
- ), handleUpdatePCAPRecorder)
-}
diff --git a/pkg/cilium/cilium_test.go b/pkg/cilium/cilium_test.go
index 5b01846..866de5d 100644
--- a/pkg/cilium/cilium_test.go
+++ b/pkg/cilium/cilium_test.go
@@ -1,99 +1,269 @@
package cilium
import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
"testing"
+ "github.com/kagent-dev/tools/internal/cmd"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-// Basic command construction tests for Cilium CLI commands
-// Note: MCP handler tests are in cilium_mcp_test.go
+func TestRegisterCiliumTools(t *testing.T) {
+ s := server.NewMCPServer("test-server", "v0.0.1")
+ RegisterTools(s)
+ // We can't directly check the tools, but we can ensure the call doesn't panic
+}
-func TestCiliumCommandConstruction(t *testing.T) {
- t.Run("basic command construction patterns", func(t *testing.T) {
- // Test that we can construct basic cilium commands
- args := []string{"status"}
- assert.Equal(t, "status", args[0])
+func TestHandleCiliumStatusAndVersion(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"status", "--timeout", "30s"}, "Cilium status: OK", nil)
+ mock.AddCommandString("cilium", []string{"version", "--timeout", "30s"}, "cilium version 1.14.0", nil)
- // Test upgrade command with parameters
- upgradeArgs := []string{"upgrade"}
- if clusterName := "test-cluster"; clusterName != "" {
- upgradeArgs = append(upgradeArgs, "--cluster-name", clusterName)
- }
- if datapathMode := "tunnel"; datapathMode != "" {
- upgradeArgs = append(upgradeArgs, "--datapath-mode", datapathMode)
- }
+ ctx = cmd.WithShellExecutor(ctx, mock)
- expected := []string{"upgrade", "--cluster-name", "test-cluster", "--datapath-mode", "tunnel"}
- assert.Equal(t, expected, upgradeArgs)
- })
+ result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
- t.Run("install command with parameters", func(t *testing.T) {
- args := []string{"install"}
- if clusterName := "test-cluster"; clusterName != "" {
- args = append(args, "--set", "cluster.name="+clusterName)
- }
- if clusterID := "123"; clusterID != "" {
- args = append(args, "--set", "cluster.id="+clusterID)
- }
- if datapathMode := "tunnel"; datapathMode != "" {
- args = append(args, "--datapath-mode", datapathMode)
+ var textContent mcp.TextContent
+ var ok bool
+ for _, content := range result.Content {
+ if textContent, ok = content.(mcp.TextContent); ok {
+ break
}
+ }
+ require.True(t, ok, "no text content in result")
- expected := []string{"install", "--set", "cluster.name=test-cluster", "--set", "cluster.id=123", "--datapath-mode", "tunnel"}
- assert.Equal(t, expected, args)
- })
+ assert.Contains(t, textContent.Text, "Cilium status: OK")
+ assert.Contains(t, textContent.Text, "cilium version 1.14.0")
+}
+
+func TestHandleCiliumStatusAndVersionError(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"status", "--timeout", "30s"}, "", errors.New("command failed"))
+ mock.AddCommandString("cilium", []string{"version", "--timeout", "30s"}, "cilium version 1.14.0", nil)
+
+ ctx = cmd.WithShellExecutor(ctx, mock)
+
+ result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.True(t, result.IsError)
+ assert.Contains(t, getResultText(result), "Error getting Cilium status")
+}
+
+func TestHandleInstallCilium(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"install", "--timeout", "30s"}, "✓ Cilium was successfully installed!", nil)
+
+ ctx = cmd.WithShellExecutor(ctx, mock)
+
+ result, err := handleInstallCilium(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Cilium was successfully installed!")
+}
+
+func TestHandleUninstallCilium(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"uninstall", "--timeout", "30s"}, "✓ Cilium was successfully uninstalled!", nil)
+
+ ctx = cmd.WithShellExecutor(ctx, mock)
+
+ result, err := handleUninstallCilium(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Cilium was successfully uninstalled!")
+}
+
+func TestHandleUpgradeCilium(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"upgrade", "--timeout", "30s"}, "✓ Cilium was successfully upgraded!", nil)
- t.Run("clustermesh connect command", func(t *testing.T) {
- clusterName := "remote-cluster"
- context := "remote-context"
+ ctx = cmd.WithShellExecutor(ctx, mock)
- args := []string{"clustermesh", "connect", "--destination-cluster", clusterName}
- if context != "" {
- args = append(args, "--destination-context", context)
+ result, err := handleUpgradeCilium(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Cilium was successfully upgraded!")
+}
+
+func TestHandleConnectToRemoteCluster(t *testing.T) {
+ ctx := context.Background()
+
+ t.Run("success", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"clustermesh", "connect", "--destination-cluster", "my-cluster", "--timeout", "30s"}, "✓ Connected to cluster my-cluster!", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "cluster_name": "my-cluster",
+ },
+ },
}
- expected := []string{"clustermesh", "connect", "--destination-cluster", "remote-cluster", "--destination-context", "remote-context"}
- assert.Equal(t, expected, args)
+ result, err := handleConnectToRemoteCluster(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Connected to cluster my-cluster!")
})
- t.Run("bgp commands", func(t *testing.T) {
- peersArgs := []string{"bgp", "peers"}
- routesArgs := []string{"bgp", "routes"}
-
- assert.Equal(t, []string{"bgp", "peers"}, peersArgs)
- assert.Equal(t, []string{"bgp", "routes"}, routesArgs)
+ t.Run("missing cluster_name", func(t *testing.T) {
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{},
+ },
+ }
+ result, err := handleConnectToRemoteCluster(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.True(t, result.IsError)
+ assert.Contains(t, getResultText(result), "cluster_name parameter is required")
})
}
-func TestCiliumParameterValidation(t *testing.T) {
- t.Run("cluster name validation", func(t *testing.T) {
- clusterName := ""
- if clusterName == "" {
- assert.True(t, true, "cluster_name parameter should be required for connect operations")
+func TestHandleDisconnectFromRemoteCluster(t *testing.T) {
+ ctx := context.Background()
+
+ t.Run("success", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"clustermesh", "disconnect", "--destination-cluster", "my-cluster", "--timeout", "30s"}, "✓ Disconnected from cluster my-cluster!", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "cluster_name": "my-cluster",
+ },
+ },
}
- clusterName = "valid-cluster"
- if clusterName != "" {
- assert.True(t, true, "valid cluster name should be accepted")
+ result, err := handleDisconnectRemoteCluster(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Disconnected from cluster my-cluster!")
+ })
+
+ t.Run("missing cluster_name", func(t *testing.T) {
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{},
+ },
}
+ result, err := handleDisconnectRemoteCluster(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.True(t, result.IsError)
+ assert.Contains(t, getResultText(result), "cluster_name parameter is required")
})
+}
+
+func TestHandleEnableHubble(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"hubble", "enable", "--timeout", "30s"}, "✓ Hubble was successfully enabled!", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "enable": true,
+ },
+ },
+ }
- t.Run("boolean parameter handling", func(t *testing.T) {
- enableStr := "true"
- enable := enableStr == "true"
- assert.True(t, enable)
+ result, err := handleToggleHubble(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Hubble was successfully enabled!")
+}
- enableStr = "false"
- enable = enableStr == "true"
- assert.False(t, enable)
+func TestHandleDisableHubble(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"hubble", "disable", "--timeout", "30s"}, "✓ Hubble was successfully disabled!", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "enable": false,
+ },
+ },
+ }
+ result, err := handleToggleHubble(ctx, req)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "✓ Hubble was successfully disabled!")
+}
- // Default value handling
- enableStr = ""
- if enableStr == "" {
- enableStr = "true" // default
- }
- enable = enableStr == "true"
- assert.True(t, enable)
+func TestHandleListBGPPeers(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"bgp", "peers", "--timeout", "30s"}, "listing BGP peers", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ result, err := handleListBGPPeers(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "listing BGP peers")
+}
+
+func TestHandleListBGPRoutes(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"bgp", "routes", "--timeout", "30s"}, "listing BGP routes", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ result, err := handleListBGPRoutes(ctx, mcp.CallToolRequest{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+ assert.Contains(t, getResultText(result), "listing BGP routes")
+}
+
+func TestRunCiliumCliWithContext(t *testing.T) {
+ ctx := context.Background()
+ t.Run("success", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"test", "--timeout", "30s"}, "success", nil)
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ result, err := runCiliumCliWithContext(ctx, "test")
+ require.NoError(t, err)
+ assert.Equal(t, "success", result)
+ })
+ t.Run("error", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("cilium", []string{"test", "--timeout", "30s"}, "", fmt.Errorf("test error"))
+ ctx = cmd.WithShellExecutor(ctx, mock)
+ _, err := runCiliumCliWithContext(ctx, "test")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "test error")
})
}
+
+func getResultText(r *mcp.CallToolResult) string {
+ if r == nil || len(r.Content) == 0 {
+ return ""
+ }
+ if textContent, ok := r.Content[0].(mcp.TextContent); ok {
+ return strings.TrimSpace(textContent.Text)
+ }
+ return ""
+}
diff --git a/pkg/helm/helm.go b/pkg/helm/helm.go
index b3a65e3..06a9ac8 100644
--- a/pkg/helm/helm.go
+++ b/pkg/helm/helm.go
@@ -5,13 +5,15 @@ import (
"fmt"
"strings"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/errors"
+ "github.com/kagent-dev/tools/internal/security"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/kagent-dev/tools/pkg/utils"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
-var kubeConfig = "" // Global variable to hold kubeconfig path
-
// Helm list releases
func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
namespace := mcp.ParseString(request, "namespace", "")
@@ -69,6 +71,15 @@ func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (*
result, err := runHelmCommand(ctx, args)
if err != nil {
+ // Check if it's a structured error
+ if toolErr, ok := err.(*errors.ToolError); ok {
+ // Add namespace context if provided
+ if namespace != "" {
+ toolErr = toolErr.WithContext("namespace", namespace)
+ }
+ return toolErr.ToMCPResult(), nil
+ }
+ // Fallback for non-structured errors
return mcp.NewToolResultError(fmt.Sprintf("Helm list command failed: %v", err)), nil
}
@@ -76,10 +87,24 @@ func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (*
}
func runHelmCommand(ctx context.Context, args []string) (string, error) {
- if kubeConfig != "" {
- args = append(args, "--kubeconfig", kubeConfig)
+ kubeconfigPath := utils.GetKubeconfig()
+ result, err := commands.NewCommandBuilder("helm").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
+
+ if err != nil {
+ if toolErr, ok := err.(*errors.ToolError); ok {
+ if len(args) > 0 {
+ toolErr = toolErr.WithContext("helm_operation", args[0])
+ }
+ toolErr = toolErr.WithContext("helm_args", args)
+ return "", toolErr
+ }
+ return "", err
}
- return utils.RunCommandWithContext(ctx, "helm", args)
+
+ return result, nil
}
// Helm get release
@@ -98,7 +123,7 @@ func handleHelmGetRelease(ctx context.Context, request mcp.CallToolRequest) (*mc
args := []string{"get", resource, name, "-n", namespace}
- result, err := utils.RunCommandWithContext(ctx, "helm", args)
+ result, err := runHelmCommand(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Helm get command failed: %v", err)), nil
}
@@ -122,6 +147,25 @@ func handleHelmUpgradeRelease(ctx context.Context, request mcp.CallToolRequest)
return mcp.NewToolResultError("name and chart parameters are required"), nil
}
+ // Validate release name
+ if err := security.ValidateHelmReleaseName(name); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid release name: %v", err)), nil
+ }
+
+ // Validate namespace if provided
+ if namespace != "" {
+ if err := security.ValidateNamespace(namespace); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil
+ }
+ }
+
+ // Validate values file path if provided
+ if values != "" {
+ if err := security.ValidateFilePath(values); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid values file path: %v", err)), nil
+ }
+ }
+
args := []string{"upgrade", name, chart}
if namespace != "" {
@@ -156,7 +200,7 @@ func handleHelmUpgradeRelease(ctx context.Context, request mcp.CallToolRequest)
args = append(args, "--wait")
}
- result, err := utils.RunCommandWithContext(ctx, "helm", args)
+ result, err := runHelmCommand(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Helm upgrade command failed: %v", err)), nil
}
@@ -185,7 +229,7 @@ func handleHelmUninstall(ctx context.Context, request mcp.CallToolRequest) (*mcp
args = append(args, "--wait")
}
- result, err := utils.RunCommandWithContext(ctx, "helm", args)
+ result, err := runHelmCommand(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Helm uninstall command failed: %v", err)), nil
}
@@ -202,9 +246,19 @@ func handleHelmRepoAdd(ctx context.Context, request mcp.CallToolRequest) (*mcp.C
return mcp.NewToolResultError("name and url parameters are required"), nil
}
+ // Validate repository name
+ if err := security.ValidateHelmReleaseName(name); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid repository name: %v", err)), nil
+ }
+
+ // Validate repository URL
+ if err := security.ValidateURL(url); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid repository URL: %v", err)), nil
+ }
+
args := []string{"repo", "add", name, url}
- result, err := utils.RunCommandWithContext(ctx, "helm", args)
+ result, err := runHelmCommand(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Helm repo add command failed: %v", err)), nil
}
@@ -216,7 +270,7 @@ func handleHelmRepoAdd(ctx context.Context, request mcp.CallToolRequest) (*mcp.C
func handleHelmRepoUpdate(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := []string{"repo", "update"}
- result, err := utils.RunCommandWithContext(ctx, "helm", args)
+ result, err := runHelmCommand(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Helm repo update command failed: %v", err)), nil
}
@@ -225,8 +279,7 @@ func handleHelmRepoUpdate(ctx context.Context, request mcp.CallToolRequest) (*mc
}
// Register Helm tools
-func RegisterHelmTools(s *server.MCPServer, kubeconfig string) {
- kubeConfig = kubeconfig
+func RegisterTools(s *server.MCPServer) {
s.AddTool(mcp.NewTool("helm_list_releases",
mcp.WithDescription("List Helm releases in a namespace"),
@@ -240,14 +293,14 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("pending", mcp.Description("List pending releases")),
mcp.WithString("filter", mcp.Description("A regular expression to filter releases by")),
mcp.WithString("output", mcp.Description("The output format (e.g., 'json', 'yaml', 'table')")),
- ), handleHelmListReleases)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_list_releases", handleHelmListReleases)))
s.AddTool(mcp.NewTool("helm_get_release",
mcp.WithDescription("Get extended information about a Helm release"),
mcp.WithString("name", mcp.Description("The name of the release"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()),
mcp.WithString("resource", mcp.Description("The resource to get (all, hooks, manifest, notes, values)")),
- ), handleHelmGetRelease)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_get_release", handleHelmGetRelease)))
s.AddTool(mcp.NewTool("helm_upgrade",
mcp.WithDescription("Upgrade or install a Helm release"),
@@ -260,7 +313,7 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("install", mcp.Description("Run an install if the release is not present")),
mcp.WithString("dry_run", mcp.Description("Simulate an upgrade")),
mcp.WithString("wait", mcp.Description("Wait for the upgrade to complete")),
- ), handleHelmUpgradeRelease)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_upgrade", handleHelmUpgradeRelease)))
s.AddTool(mcp.NewTool("helm_uninstall",
mcp.WithDescription("Uninstall a Helm release"),
@@ -268,15 +321,15 @@ func RegisterHelmTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()),
mcp.WithString("dry_run", mcp.Description("Simulate an uninstall")),
mcp.WithString("wait", mcp.Description("Wait for the uninstall to complete")),
- ), handleHelmUninstall)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_uninstall", handleHelmUninstall)))
s.AddTool(mcp.NewTool("helm_repo_add",
mcp.WithDescription("Add a Helm repository"),
mcp.WithString("name", mcp.Description("The name of the repository"), mcp.Required()),
mcp.WithString("url", mcp.Description("The URL of the repository"), mcp.Required()),
- ), handleHelmRepoAdd)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_add", handleHelmRepoAdd)))
s.AddTool(mcp.NewTool("helm_repo_update",
mcp.WithDescription("Update information of available charts locally from chart repositories"),
- ), handleHelmRepoUpdate)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_update", handleHelmRepoUpdate)))
}
diff --git a/pkg/helm/helm_test.go b/pkg/helm/helm_test.go
index 4a99165..3848de2 100644
--- a/pkg/helm/helm_test.go
+++ b/pkg/helm/helm_test.go
@@ -4,22 +4,28 @@ import (
"context"
"testing"
- "github.com/kagent-dev/tools/pkg/utils"
+ "github.com/kagent-dev/tools/internal/cmd"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+func TestRegisterTools(t *testing.T) {
+ s := server.NewMCPServer("test-server", "v0.0.1")
+ RegisterTools(s)
+}
+
// Test Helm List Releases
func TestHandleHelmListReleases(t *testing.T) {
t.Run("basic list releases", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `NAME NAMESPACE REVISION UPDATED STATUS CHART APP VERSION
app1 default 1 2023-01-01 12:00:00.000000000 +0000 UTC deployed myapp-1.0.0 1.0.0
app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deployed system-2.0.0 2.0.0`
- mock.AddCommandString("helm", []string{"list"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"list", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleHelmListReleases(ctx, request)
@@ -37,13 +43,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"list"}, callLog[0].Args)
+ assert.Equal(t, []string{"list", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("list releases with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("helm", []string{"list", "-n", "production"}, "production releases", nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("helm", []string{"list", "-n", "production", "--timeout", "30s"}, "production releases", nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -59,13 +65,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"list", "-n", "production"}, callLog[0].Args)
+ assert.Equal(t, []string{"list", "-n", "production", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("list releases with all namespaces", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("helm", []string{"list", "-A"}, "all namespaces releases", nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("helm", []string{"list", "-A", "--timeout", "30s"}, "all namespaces releases", nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -81,13 +87,13 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"list", "-A"}, callLog[0].Args)
+ assert.Equal(t, []string{"list", "-A", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("list releases with multiple flags", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("helm", []string{"list", "-A", "-a", "--failed", "-o", "json"}, `[{"name":"failed-app","status":"failed"}]`, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("helm", []string{"list", "-A", "-a", "--failed", "-o", "json", "--timeout", "30s"}, `[{"name":"failed-app","status":"failed"}]`, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -106,35 +112,35 @@ app2 kube-system 2 2023-01-02 12:00:00.000000000 +0000 UTC deplo
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"list", "-A", "-a", "--failed", "-o", "json"}, callLog[0].Args)
+ assert.Equal(t, []string{"list", "-A", "-a", "--failed", "-o", "json", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("helm command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("helm", []string{"list"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("helm", []string{"list", "--timeout", "30s"}, "", assert.AnError)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleHelmListReleases(ctx, request)
assert.NoError(t, err) // MCP handlers should not return Go errors
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "Helm list command failed")
+ assert.Contains(t, getResultText(result), "**Helm Error**")
})
}
// Test Helm Get Release
func TestHandleHelmGetRelease(t *testing.T) {
t.Run("get release all resources", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `REVISION: 1
RELEASED: Mon Jan 01 12:00:00 UTC 2023
CHART: myapp-1.0.0
VALUES:
replicaCount: 3`
- mock.AddCommandString("helm", []string{"get", "all", "myapp", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"get", "all", "myapp", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -152,13 +158,13 @@ replicaCount: 3`
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"get", "all", "myapp", "-n", "default"}, callLog[0].Args)
+ assert.Equal(t, []string{"get", "all", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("get release values only", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("helm", []string{"get", "values", "myapp", "-n", "default"}, "replicaCount: 3", nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("helm", []string{"get", "values", "myapp", "-n", "default", "--timeout", "30s"}, "replicaCount: 3", nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -176,12 +182,12 @@ replicaCount: 3`
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"get", "values", "myapp", "-n", "default"}, callLog[0].Args)
+ assert.Equal(t, []string{"get", "values", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
// Test missing name
request := mcp.CallToolRequest{}
@@ -213,7 +219,7 @@ replicaCount: 3`
// Test Helm Upgrade Release
func TestHandleHelmUpgradeRelease(t *testing.T) {
t.Run("basic upgrade", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `Release "myapp" has been upgraded. Happy Helming!
NAME: myapp
LAST DEPLOYED: Mon Jan 01 12:00:00 UTC 2023
@@ -221,8 +227,8 @@ NAMESPACE: default
STATUS: deployed
REVISION: 2`
- mock.AddCommandString("helm", []string{"upgrade", "myapp", "stable/myapp"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"upgrade", "myapp", "stable/myapp", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -240,11 +246,11 @@ REVISION: 2`
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"upgrade", "myapp", "stable/myapp"}, callLog[0].Args)
+ assert.Equal(t, []string{"upgrade", "myapp", "stable/myapp", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("upgrade with all options", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedArgs := []string{
"upgrade", "myapp", "stable/myapp",
"-n", "production",
@@ -255,9 +261,10 @@ REVISION: 2`
"--install",
"--dry-run",
"--wait",
+ "--timeout", "30s",
}
- mock.AddCommandString("helm", expectedArgs, "dry run output", nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", expectedArgs, "Upgraded with options", nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -284,14 +291,14 @@ REVISION: 2`
assert.Equal(t, expectedArgs, callLog[0].Args)
})
- t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ t.Run("missing required parameters for upgrade", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
+ // Test missing chart
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"name": "myapp",
- // Missing chart
}
result, err := handleHelmUpgradeRelease(ctx, request)
@@ -308,11 +315,11 @@ REVISION: 2`
// Test Helm Uninstall
func TestHandleHelmUninstall(t *testing.T) {
t.Run("basic uninstall", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `release "myapp" uninstalled`
- mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -330,14 +337,14 @@ func TestHandleHelmUninstall(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"uninstall", "myapp", "-n", "default"}, callLog[0].Args)
+ assert.Equal(t, []string{"uninstall", "myapp", "-n", "default", "--timeout", "30s"}, callLog[0].Args)
})
t.Run("uninstall with options", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedArgs := []string{"uninstall", "myapp", "-n", "production", "--dry-run", "--wait"}
+ mock := cmd.NewMockShellExecutor()
+ expectedArgs := []string{"uninstall", "myapp", "-n", "production", "--dry-run", "--wait", "--timeout", "30s"}
mock.AddCommandString("helm", expectedArgs, "dry run uninstall", nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -358,21 +365,51 @@ func TestHandleHelmUninstall(t *testing.T) {
assert.Equal(t, "helm", callLog[0].Command)
assert.Equal(t, expectedArgs, callLog[0].Args)
})
+
+ t.Run("missing required parameters for uninstall", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
+
+ // Test missing name
+ request := mcp.CallToolRequest{}
+ request.Params.Arguments = map[string]interface{}{
+ "namespace": "default",
+ }
+
+ result, err := handleHelmUninstall(ctx, request)
+ assert.NoError(t, err)
+ assert.True(t, result.IsError)
+ assert.Contains(t, getResultText(result), "name and namespace parameters are required")
+
+ // Test missing namespace
+ request.Params.Arguments = map[string]interface{}{
+ "name": "myapp",
+ }
+
+ result, err = handleHelmUninstall(ctx, request)
+ assert.NoError(t, err)
+ assert.True(t, result.IsError)
+ assert.Contains(t, getResultText(result), "name and namespace parameters are required")
+
+ // Verify no commands were executed
+ callLog := mock.GetCallLog()
+ assert.Len(t, callLog, 0)
+ })
}
// Test Helm Repo Add
func TestHandleHelmRepoAdd(t *testing.T) {
- t.Run("add repository", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `"stable" has been added to your repositories`
+ t.Run("basic repo add", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ expectedOutput := `"my-repo" has been added to your repositories`
- mock.AddCommandString("helm", []string{"repo", "add", "stable", "https://charts.helm.sh/stable"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"repo", "add", "my-repo", "https://charts.example.com/", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "name": "stable",
- "url": "https://charts.helm.sh/stable",
+ "name": "my-repo",
+ "url": "https://charts.example.com/",
}
result, err := handleHelmRepoAdd(ctx, request)
@@ -385,17 +422,17 @@ func TestHandleHelmRepoAdd(t *testing.T) {
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"repo", "add", "stable", "https://charts.helm.sh/stable"}, callLog[0].Args)
+ assert.Equal(t, []string{"repo", "add", "my-repo", "https://charts.example.com/", "--timeout", "30s"}, callLog[0].Args)
})
- t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ t.Run("missing required parameters for repo add", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
+ // Test missing name
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "name": "stable",
- // Missing url
+ "url": "https://charts.example.com/",
}
result, err := handleHelmRepoAdd(ctx, request)
@@ -411,27 +448,26 @@ func TestHandleHelmRepoAdd(t *testing.T) {
// Test Helm Repo Update
func TestHandleHelmRepoUpdate(t *testing.T) {
- t.Run("update repositories", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ t.Run("basic repo update", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `Hang tight while we grab the latest from your chart repositories...
-...Successfully got an update from the "stable" chart repository
-Update Complete. ⎈Happy Helming!⎈`
+...Successfully got an update from the "my-repo" chart repository`
- mock.AddCommandString("helm", []string{"repo", "update"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("helm", []string{"repo", "update", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
request := mcp.CallToolRequest{}
result, err := handleHelmRepoUpdate(ctx, request)
assert.NoError(t, err)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "Update Complete")
+ assert.Contains(t, getResultText(result), "Successfully got an update")
// Verify the correct command was called
callLog := mock.GetCallLog()
require.Len(t, callLog, 1)
assert.Equal(t, "helm", callLog[0].Command)
- assert.Equal(t, []string{"repo", "update"}, callLog[0].Args)
+ assert.Equal(t, []string{"repo", "update", "--timeout", "30s"}, callLog[0].Args)
})
}
diff --git a/pkg/istio/istio.go b/pkg/istio/istio.go
index 2f198aa..680d83c 100644
--- a/pkg/istio/istio.go
+++ b/pkg/istio/istio.go
@@ -5,13 +5,13 @@ import (
"fmt"
"strings"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/kagent-dev/tools/pkg/utils"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
-var kubeConfig = "" // Global variable to hold kubeconfig path
-
// Istio proxy status
func handleIstioProxyStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
podName := mcp.ParseString(request, "pod_name", "")
@@ -36,11 +36,11 @@ func handleIstioProxyStatus(ctx context.Context, request mcp.CallToolRequest) (*
}
func runIstioCtl(ctx context.Context, args []string) (string, error) {
- if kubeConfig != "" {
- args = append(args, "--kubeconfig", kubeConfig)
- }
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
- return result, err
+ kubeconfigPath := utils.GetKubeconfig()
+ return commands.NewCommandBuilder("istioctl").
+ WithArgs(args...).
+ WithKubeconfig(kubeconfigPath).
+ Execute(ctx)
}
// Istio proxy config
@@ -61,7 +61,7 @@ func handleIstioProxyConfig(ctx context.Context, request mcp.CallToolRequest) (*
args = append(args, podName)
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-config failed: %v", err)), nil
}
@@ -75,7 +75,7 @@ func handleIstioInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.
args := []string{"install", "--set", fmt.Sprintf("profile=%s", profile), "-y"}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl install failed: %v", err)), nil
}
@@ -89,7 +89,7 @@ func handleIstioGenerateManifest(ctx context.Context, request mcp.CallToolReques
args := []string{"manifest", "generate", "--set", fmt.Sprintf("profile=%s", profile)}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl manifest generate failed: %v", err)), nil
}
@@ -110,7 +110,7 @@ func handleIstioAnalyzeClusterConfiguration(ctx context.Context, request mcp.Cal
args = append(args, "-n", namespace)
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl analyze failed: %v", err)), nil
}
@@ -128,7 +128,7 @@ func handleIstioVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.
args = append(args, "--short")
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl version failed: %v", err)), nil
}
@@ -140,7 +140,7 @@ func handleIstioVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.
func handleIstioRemoteClusters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := []string{"remote-clusters"}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl remote-clusters failed: %v", err)), nil
}
@@ -161,7 +161,7 @@ func handleWaypointList(ctx context.Context, request mcp.CallToolRequest) (*mcp.
args = append(args, "-n", namespace)
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint list failed: %v", err)), nil
}
@@ -191,7 +191,7 @@ func handleWaypointGenerate(ctx context.Context, request mcp.CallToolRequest) (*
args = append(args, "--for", trafficType)
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint generate failed: %v", err)), nil
}
@@ -214,7 +214,7 @@ func handleWaypointApply(ctx context.Context, request mcp.CallToolRequest) (*mcp
args = append(args, "--enroll-namespace")
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint apply failed: %v", err)), nil
}
@@ -245,7 +245,7 @@ func handleWaypointDelete(ctx context.Context, request mcp.CallToolRequest) (*mc
args = append(args, "-n", namespace)
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint delete failed: %v", err)), nil
}
@@ -270,7 +270,7 @@ func handleWaypointStatus(ctx context.Context, request mcp.CallToolRequest) (*mc
args = append(args, "-n", namespace)
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint status failed: %v", err)), nil
}
@@ -283,30 +283,29 @@ func handleZtunnelConfig(ctx context.Context, request mcp.CallToolRequest) (*mcp
namespace := mcp.ParseString(request, "namespace", "")
configType := mcp.ParseString(request, "config_type", "all")
- args := []string{"ztunnel-config", configType}
+ args := []string{"ztunnel", "config", configType}
if namespace != "" {
args = append(args, "-n", namespace)
}
- result, err := utils.RunCommandWithContext(ctx, "istioctl", args)
+ result, err := runIstioCtl(ctx, args)
if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel-config failed: %v", err)), nil
+ return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel config failed: %v", err)), nil
}
return mcp.NewToolResultText(result), nil
}
// Register Istio tools
-func RegisterIstioTools(s *server.MCPServer, kubeconfig string) {
- kubeConfig = kubeconfig
+func RegisterTools(s *server.MCPServer) {
// Istio proxy status
s.AddTool(mcp.NewTool("istio_proxy_status",
mcp.WithDescription("Get Envoy proxy status for pods, retrieves last sent and acknowledged xDS sync from Istiod to each Envoy in the mesh"),
mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy status for")),
mcp.WithString("namespace", mcp.Description("Namespace of the pod")),
- ), handleIstioProxyStatus)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_status", handleIstioProxyStatus)))
// Istio proxy config
s.AddTool(mcp.NewTool("istio_proxy_config",
@@ -314,79 +313,62 @@ func RegisterIstioTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy configuration for"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the pod")),
mcp.WithString("config_type", mcp.Description("Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)")),
- ), handleIstioProxyConfig)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_config", handleIstioProxyConfig)))
// Istio install
s.AddTool(mcp.NewTool("istio_install_istio",
mcp.WithDescription("Install Istio with a specified configuration profile"),
mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")),
- ), handleIstioInstall)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_install_istio", handleIstioInstall)))
// Istio generate manifest
s.AddTool(mcp.NewTool("istio_generate_manifest",
- mcp.WithDescription("Generate an Istio install manifest"),
+ mcp.WithDescription("Generate Istio manifest for a given profile"),
mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")),
- ), handleIstioGenerateManifest)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_manifest", handleIstioGenerateManifest)))
// Istio analyze
s.AddTool(mcp.NewTool("istio_analyze_cluster_configuration",
- mcp.WithDescription("Analyze live cluster configuration for potential issues"),
- mcp.WithString("namespace", mcp.Description("Namespace to analyze")),
- mcp.WithString("all_namespaces", mcp.Description("Analyze all namespaces (true/false)")),
- ), handleIstioAnalyzeClusterConfiguration)
+ mcp.WithDescription("Analyze Istio cluster configuration for issues"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_analyze_cluster_configuration", handleIstioAnalyzeClusterConfiguration)))
// Istio version
s.AddTool(mcp.NewTool("istio_version",
- mcp.WithDescription("Get Istio CLI client version, control plane and data plane versions"),
- mcp.WithString("short", mcp.Description("Show short version format (true/false)")),
- ), handleIstioVersion)
+ mcp.WithDescription("Get Istio version information"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_version", handleIstioVersion)))
// Istio remote clusters
s.AddTool(mcp.NewTool("istio_remote_clusters",
- mcp.WithDescription("List remote clusters each istiod instance is connected to"),
- ), handleIstioRemoteClusters)
+ mcp.WithDescription("List remote clusters registered with Istio"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_remote_clusters", handleIstioRemoteClusters)))
// Waypoint list
s.AddTool(mcp.NewTool("istio_list_waypoints",
- mcp.WithDescription("List managed waypoint configurations in the cluster"),
- mcp.WithString("namespace", mcp.Description("Namespace to list waypoints for")),
- mcp.WithString("all_namespaces", mcp.Description("List waypoints for all namespaces (true/false)")),
- ), handleWaypointList)
+ mcp.WithDescription("List all waypoints in the mesh"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_list_waypoints", handleWaypointList)))
// Waypoint generate
s.AddTool(mcp.NewTool("istio_generate_waypoint",
- mcp.WithDescription("Generate a waypoint configuration as YAML"),
- mcp.WithString("namespace", mcp.Description("Namespace to generate the waypoint for"), mcp.Required()),
- mcp.WithString("name", mcp.Description("Name of the waypoint to generate")),
- mcp.WithString("traffic_type", mcp.Description("Traffic type for the waypoint (all, inbound, outbound)")),
- ), handleWaypointGenerate)
+ mcp.WithDescription("Generate a waypoint resource YAML"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_waypoint", handleWaypointGenerate)))
// Waypoint apply
s.AddTool(mcp.NewTool("istio_apply_waypoint",
- mcp.WithDescription("Apply a waypoint configuration to a cluster"),
- mcp.WithString("namespace", mcp.Description("Namespace to apply the waypoint to"), mcp.Required()),
- mcp.WithString("enroll_namespace", mcp.Description("Label the namespace with the waypoint name (true/false)")),
- ), handleWaypointApply)
+ mcp.WithDescription("Apply a waypoint resource to the cluster"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_apply_waypoint", handleWaypointApply)))
// Waypoint delete
s.AddTool(mcp.NewTool("istio_delete_waypoint",
- mcp.WithDescription("Delete waypoint configurations from a cluster"),
- mcp.WithString("namespace", mcp.Description("Namespace to delete waypoints from"), mcp.Required()),
- mcp.WithString("names", mcp.Description("Comma-separated list of waypoint names to delete")),
- mcp.WithString("all", mcp.Description("Delete all waypoints in the namespace (true/false)")),
- ), handleWaypointDelete)
+ mcp.WithDescription("Delete a waypoint resource from the cluster"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_delete_waypoint", handleWaypointDelete)))
// Waypoint status
s.AddTool(mcp.NewTool("istio_waypoint_status",
- mcp.WithDescription("Get status of a waypoint"),
- mcp.WithString("namespace", mcp.Description("Namespace of the waypoint"), mcp.Required()),
- mcp.WithString("name", mcp.Description("Name of the waypoint to get status for")),
- ), handleWaypointStatus)
+ mcp.WithDescription("Get the status of a waypoint resource"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_waypoint_status", handleWaypointStatus)))
// Ztunnel config
s.AddTool(mcp.NewTool("istio_ztunnel_config",
- mcp.WithDescription("Get ztunnel configuration"),
- mcp.WithString("namespace", mcp.Description("Namespace of the pod")),
- mcp.WithString("config_type", mcp.Description("Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)")),
- ), handleZtunnelConfig)
+ mcp.WithDescription("Get the ztunnel configuration for a namespace"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_ztunnel_config", handleZtunnelConfig)))
}
diff --git a/pkg/istio/istio_test.go b/pkg/istio/istio_test.go
index 2adaf99..02efbee 100644
--- a/pkg/istio/istio_test.go
+++ b/pkg/istio/istio_test.go
@@ -4,299 +4,185 @@ import (
"context"
"testing"
- "github.com/kagent-dev/tools/pkg/utils"
+ "github.com/kagent-dev/tools/internal/cmd"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-// Helper function to extract text content from MCP result
-func getResultText(result *mcp.CallToolResult) string {
- if result == nil || len(result.Content) == 0 {
- return ""
- }
- if textContent, ok := result.Content[0].(mcp.TextContent); ok {
- return textContent.Text
- }
- return ""
+func TestRegisterTools(t *testing.T) {
+ s := server.NewMCPServer("test-server", "v0.0.1")
+ RegisterTools(s)
}
-// Test Istio Proxy Status
func TestHandleIstioProxyStatus(t *testing.T) {
+ ctx := context.Background()
+
t.Run("basic proxy status", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION
-app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0
-app-2 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0`
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-status", "--timeout", "30s"}, "Proxy status output", nil)
- mock.AddCommandString("istioctl", []string{"proxy-status"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
- request := mcp.CallToolRequest{}
- result, err := handleIstioProxyStatus(ctx, request)
+ result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{})
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the expected output
- content := getResultText(result)
- assert.Contains(t, content, "app-1")
- assert.Contains(t, content, "SYNCED")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"proxy-status"}, callLog[0].Args)
})
t.Run("proxy status with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION
-app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0`
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "istio-system", "--timeout", "30s"}, "Proxy status output", nil)
- mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "production"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "namespace": "production",
+ "namespace": "istio-system",
}
result, err := handleIstioProxyStatus(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the correct command was called with namespace
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"proxy-status", "-n", "production"}, callLog[0].Args)
})
- t.Run("proxy status with pod name and namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `NAME CDS LDS EDS RDS ISTIOD VERSION
-app-1 SYNCED SYNCED SYNCED SYNCED istiod-68d5d5b5fc-7vf6n 1.18.0`
+ t.Run("proxy status with pod name", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "default", "test-pod", "--timeout", "30s"}, "Proxy status output", nil)
- mock.AddCommandString("istioctl", []string{"proxy-status", "-n", "production", "app-1"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "namespace": "production",
- "pod_name": "app-1",
+ "pod_name": "test-pod",
+ "namespace": "default",
}
result, err := handleIstioProxyStatus(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"proxy-status", "-n", "production", "app-1"}, callLog[0].Args)
})
+}
- t.Run("istioctl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"proxy-status"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+func TestHandleIstioProxyConfig(t *testing.T) {
+ ctx := context.Background()
- request := mcp.CallToolRequest{}
- result, err := handleIstioProxyStatus(ctx, request)
+ t.Run("missing pod_name parameter", func(t *testing.T) {
+ result, err := handleIstioProxyConfig(ctx, mcp.CallToolRequest{})
- assert.NoError(t, err) // MCP handlers should not return Go errors
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "istioctl proxy-status failed")
})
-}
-// Test Istio Proxy Config
-func TestHandleIstioProxyConfig(t *testing.T) {
- t.Run("proxy config all", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `CLUSTER NAME DIRECTION TYPE DESTINATION RULE
-outbound|80||kubernetes.default.svc.cluster.local outbound EDS
-inbound|80|| inbound EDS`
+ t.Run("proxy config with pod name", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-config", "all", "test-pod", "--timeout", "30s"}, "Proxy config output", nil)
- mock.AddCommandString("istioctl", []string{"proxy-config", "all", "app-1"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "pod_name": "app-1",
+ "pod_name": "test-pod",
}
result, err := handleIstioProxyConfig(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "CLUSTER NAME")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"proxy-config", "all", "app-1"}, callLog[0].Args)
})
- t.Run("proxy config with namespace and config type", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `CLUSTER NAME DIRECTION TYPE DESTINATION RULE
-outbound|80||kubernetes.default.svc.cluster.local outbound EDS`
+ t.Run("proxy config with namespace", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-config", "cluster", "test-pod.default", "--timeout", "30s"}, "Proxy config output", nil)
- mock.AddCommandString("istioctl", []string{"proxy-config", "cluster", "app-1.production"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "pod_name": "app-1",
- "namespace": "production",
+ "pod_name": "test-pod",
+ "namespace": "default",
"config_type": "cluster",
}
result, err := handleIstioProxyConfig(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"proxy-config", "cluster", "app-1.production"}, callLog[0].Args)
- })
-
- t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- // Missing pod_name
- }
-
- result, err := handleIstioProxyConfig(ctx, request)
- assert.NoError(t, err)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "pod_name parameter is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
})
}
-// Test Istio Install
func TestHandleIstioInstall(t *testing.T) {
- t.Run("basic install", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `✔ Istio core installed
-✔ Istiod installed
-✔ Ingress gateways installed
-✔ Installation complete`
+ ctx := context.Background()
- mock.AddCommandString("istioctl", []string{"install", "--set", "profile=default", "-y"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ t.Run("install with default profile", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"install", "--set", "profile=default", "-y", "--timeout", "30s"}, "Install completed", nil)
- request := mcp.CallToolRequest{}
- result, err := handleIstioInstall(ctx, request)
+ ctx = cmd.WithShellExecutor(ctx, mock)
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "Installation complete")
+ result, err := handleIstioInstall(ctx, mcp.CallToolRequest{})
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"install", "--set", "profile=default", "-y"}, callLog[0].Args)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
})
- t.Run("install with profile", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `✔ Istio core installed
-✔ Installation complete`
+ t.Run("install with custom profile", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"install", "--set", "profile=demo", "-y", "--timeout", "30s"}, "Install completed", nil)
- mock.AddCommandString("istioctl", []string{"install", "--set", "profile=minimal", "-y"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "profile": "minimal",
+ "profile": "demo",
}
result, err := handleIstioInstall(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the correct command was called with profile
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"install", "--set", "profile=minimal", "-y"}, callLog[0].Args)
})
}
-// Test Istio Analyze
-func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) {
- t.Run("basic analyze", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `✔ No validation issues found when analyzing namespace: default.`
-
- mock.AddCommandString("istioctl", []string{"analyze"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- result, err := handleIstioAnalyzeClusterConfiguration(ctx, request)
+func TestHandleIstioGenerateManifest(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "No validation issues found")
+ mock.AddCommandString("istioctl", []string{"manifest", "generate", "--set", "profile=minimal", "--timeout", "30s"}, "Generated manifest", nil)
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"analyze"}, callLog[0].Args)
- })
+ ctx = cmd.WithShellExecutor(ctx, mock)
- t.Run("analyze with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `✔ No validation issues found when analyzing namespace: production.`
-
- mock.AddCommandString("istioctl", []string{"analyze", "-n", "production"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "production",
- }
+ request := mcp.CallToolRequest{}
+ request.Params.Arguments = map[string]interface{}{
+ "profile": "minimal",
+ }
- result, err := handleIstioAnalyzeClusterConfiguration(ctx, request)
+ result, err := handleIstioGenerateManifest(ctx, request)
- assert.NoError(t, err)
- assert.False(t, result.IsError)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+}
- // Verify the correct command was called with namespace
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"analyze", "-n", "production"}, callLog[0].Args)
- })
+func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) {
+ ctx := context.Background()
t.Run("analyze all namespaces", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `✔ No validation issues found when analyzing all namespaces.`
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"analyze", "-A", "--timeout", "30s"}, "Analysis output", nil)
- mock.AddCommandString("istioctl", []string{"analyze", "-A"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
@@ -305,565 +191,167 @@ func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) {
result, err := handleIstioAnalyzeClusterConfiguration(ctx, request)
- assert.NoError(t, err)
- assert.False(t, result.IsError)
-
- // Verify the correct command was called with -A flag
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"analyze", "-A"}, callLog[0].Args)
- })
-}
-
-// Test Istio Version
-func TestHandleIstioVersion(t *testing.T) {
- t.Run("version detailed output", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `client version: 1.18.0
-control plane version: 1.18.0
-data plane version: 1.18.0 (2 proxies)`
-
- mock.AddCommandString("istioctl", []string{"version"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- result, err := handleIstioVersion(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "client version: 1.18.0")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"version"}, callLog[0].Args)
- })
-
- t.Run("version short output", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `1.18.0`
-
- mock.AddCommandString("istioctl", []string{"version", "--short"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "short": "true",
- }
-
- result, err := handleIstioVersion(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "1.18.0")
-
- // Verify the correct command was called with --short flag
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"version", "--short"}, callLog[0].Args)
- })
-}
-
-// Test Waypoint List
-func TestHandleWaypointList(t *testing.T) {
- t.Run("list waypoints", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `NAMESPACE NAME TRAFFIC TYPE
-default waypoint ALL
-production waypoint INBOUND`
-
- mock.AddCommandString("istioctl", []string{"waypoint", "list"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- result, err := handleWaypointList(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "NAMESPACE")
- assert.Contains(t, getResultText(result), "waypoint")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "list"}, callLog[0].Args)
- })
-
- t.Run("list waypoints in namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `NAMESPACE NAME TRAFFIC TYPE
-production waypoint INBOUND`
-
- mock.AddCommandString("istioctl", []string{"waypoint", "list", "-n", "production"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "production",
- }
-
- result, err := handleWaypointList(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
-
- // Verify the correct command was called with namespace
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "list", "-n", "production"}, callLog[0].Args)
- })
-}
-
-// Test Waypoint Generate
-func TestHandleWaypointGenerate(t *testing.T) {
- t.Run("generate waypoint", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `apiVersion: gateway.networking.k8s.io/v1beta1
-kind: Gateway
-metadata:
- name: waypoint
- namespace: production
-spec:
- gatewayClassName: istio-waypoint`
-
- mock.AddCommandString("istioctl", []string{"waypoint", "generate", "waypoint", "-n", "production", "--for", "all"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "production",
- }
-
- result, err := handleWaypointGenerate(ctx, request)
-
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "apiVersion: gateway.networking.k8s.io/v1beta1")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "generate", "waypoint", "-n", "production", "--for", "all"}, callLog[0].Args)
- })
-
- t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- // Missing namespace
- }
-
- result, err := handleWaypointGenerate(ctx, request)
- assert.NoError(t, err)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "namespace parameter is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
})
-}
-// Test Waypoint Apply
-func TestHandleWaypointApply(t *testing.T) {
- t.Run("basic waypoint apply", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/waypoint applied`
+ t.Run("analyze specific namespace", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"analyze", "-n", "default", "--timeout", "30s"}, "Analysis output", nil)
- mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"namespace": "default",
}
- result, err := handleWaypointApply(ctx, request)
+ result, err := handleIstioAnalyzeClusterConfiguration(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "applied")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "apply", "-n", "default"}, callLog[0].Args)
- })
-
- t.Run("waypoint apply with enroll namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/waypoint applied
-namespace/default labeled with istio.io/use-waypoint=waypoint`
-
- mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- "enroll_namespace": "true",
- }
-
- result, err := handleWaypointApply(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "applied")
-
- // Verify the correct command was called with --enroll-namespace flag
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, callLog[0].Args)
- })
-
- t.Run("missing namespace parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- // Missing namespace
- }
-
- result, err := handleWaypointApply(ctx, request)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "namespace parameter is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-
- t.Run("istioctl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- }
-
- result, err := handleWaypointApply(ctx, request)
-
- assert.NoError(t, err) // MCP handlers should not return Go errors
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "istioctl waypoint apply failed")
})
}
-// Test Waypoint Delete
-func TestHandleWaypointDelete(t *testing.T) {
- t.Run("delete all waypoints", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/waypoint deleted`
+func TestHandleIstioVersion(t *testing.T) {
+ ctx := context.Background()
- mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ t.Run("version full", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"version", "--timeout", "30s"}, "Version output", nil)
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- "all": "true",
- }
+ ctx = cmd.WithShellExecutor(ctx, mock)
- result, err := handleWaypointDelete(ctx, request)
+ result, err := handleIstioVersion(ctx, mcp.CallToolRequest{})
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "deleted")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "delete", "--all", "-n", "default"}, callLog[0].Args)
})
- t.Run("delete specific waypoints", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/waypoint1 deleted
-waypoint/waypoint2 deleted`
+ t.Run("version short", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"version", "--short", "--timeout", "30s"}, "1.18.0", nil)
- mock.AddCommandString("istioctl", []string{"waypoint", "delete", "waypoint1", "waypoint2", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- "names": "waypoint1,waypoint2",
+ "short": "true",
}
- result, err := handleWaypointDelete(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "deleted")
-
- // Verify the correct command was called with specific names
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "delete", "waypoint1", "waypoint2", "-n", "default"}, callLog[0].Args)
- })
-
- t.Run("missing namespace parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- // Missing namespace
- }
+ result, err := handleIstioVersion(ctx, request)
- result, err := handleWaypointDelete(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.NotNil(t, result)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "namespace parameter is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-
- t.Run("istioctl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- "all": "true",
- }
-
- result, err := handleWaypointDelete(ctx, request)
-
- assert.NoError(t, err) // MCP handlers should not return Go errors
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "istioctl waypoint delete failed")
+ assert.False(t, result.IsError)
})
}
-// Test Waypoint Status
-func TestHandleWaypointStatus(t *testing.T) {
- t.Run("waypoint status", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/waypoint is deployed and ready`
+func TestHandleIstioRemoteClusters(t *testing.T) {
+ ctx := context.Background()
+ mock := cmd.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddCommandString("istioctl", []string{"remote-clusters", "--timeout", "30s"}, "Remote clusters output", nil)
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- }
+ ctx = cmd.WithShellExecutor(ctx, mock)
- result, err := handleWaypointStatus(ctx, request)
+ result, err := handleIstioRemoteClusters(ctx, mcp.CallToolRequest{})
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "waypoint")
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.False(t, result.IsError)
+}
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "status", "-n", "default"}, callLog[0].Args)
- })
+func TestHandleWaypointList(t *testing.T) {
+ ctx := context.Background()
- t.Run("waypoint status with specific name", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `waypoint/test-waypoint is deployed and ready`
+ t.Run("list waypoints in all namespaces", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"waypoint", "list", "-A", "--timeout", "30s"}, "Waypoints list", nil)
- mock.AddCommandString("istioctl", []string{"waypoint", "status", "test-waypoint", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "namespace": "default",
- "name": "test-waypoint",
+ "all_namespaces": "true",
}
- result, err := handleWaypointStatus(ctx, request)
+ result, err := handleWaypointList(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "test-waypoint")
-
- // Verify the correct command was called with specific name
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"waypoint", "status", "test-waypoint", "-n", "default"}, callLog[0].Args)
})
- t.Run("missing namespace parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ t.Run("list waypoints in a specific namespace", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"waypoint", "list", "-n", "default", "--timeout", "30s"}, "Waypoints list", nil)
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- // Missing namespace
- }
-
- result, err := handleWaypointStatus(ctx, request)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "namespace parameter is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-
- t.Run("istioctl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"namespace": "default",
}
- result, err := handleWaypointStatus(ctx, request)
-
- assert.NoError(t, err) // MCP handlers should not return Go errors
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "istioctl waypoint status failed")
- })
-}
-
-// Test Ztunnel Config
-func TestHandleZtunnelConfig(t *testing.T) {
- t.Run("default ztunnel config", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS
-cluster1 EDS 10.0.0.1:15010
-cluster2 STATIC 10.0.0.2:15010`
-
- mock.AddCommandString("istioctl", []string{"ztunnel-config", "all"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- result, err := handleZtunnelConfig(ctx, request)
+ result, err := handleWaypointList(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.NotNil(t, result)
assert.False(t, result.IsError)
- assert.Contains(t, getResultText(result), "CLUSTER_NAME")
-
- // Verify the correct command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"ztunnel-config", "all"}, callLog[0].Args)
})
+}
- t.Run("ztunnel config with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS
-cluster1 EDS 10.0.0.1:15010`
-
- mock.AddCommandString("istioctl", []string{"ztunnel-config", "all", "-n", "istio-system"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "istio-system",
- }
-
- result, err := handleZtunnelConfig(ctx, request)
-
- assert.NoError(t, err)
- assert.False(t, result.IsError)
-
- // Verify the correct command was called with namespace
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"ztunnel-config", "all", "-n", "istio-system"}, callLog[0].Args)
- })
+func TestHandleWaypointGenerate(t *testing.T) {
+ ctx := context.Background()
- t.Run("ztunnel config with specific type", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `CLUSTER_NAME CLUSTER_TYPE ENDPOINTS
-cluster1 EDS 10.0.0.1:15010`
+ t.Run("generate waypoint with namespace", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"waypoint", "generate", "waypoint", "-n", "default", "--for", "all", "--timeout", "30s"}, "Generated waypoint", nil)
- mock.AddCommandString("istioctl", []string{"ztunnel-config", "cluster"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx = cmd.WithShellExecutor(ctx, mock)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
- "config_type": "cluster",
+ "namespace": "default",
+ "name": "waypoint",
+ "traffic_type": "all",
}
- result, err := handleZtunnelConfig(ctx, request)
+ result, err := handleWaypointGenerate(ctx, request)
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.False(t, result.IsError)
-
- // Verify the correct command was called with specific config type
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"ztunnel-config", "cluster"}, callLog[0].Args)
})
+}
- t.Run("ztunnel config with namespace and config type", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `LISTENER_NAME ADDRESS PORT TYPE
-listener1 0.0.0.0 15006 TCP`
-
- mock.AddCommandString("istioctl", []string{"ztunnel-config", "listener", "-n", "istio-system"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]interface{}{
- "namespace": "istio-system",
- "config_type": "listener",
- }
-
- result, err := handleZtunnelConfig(ctx, request)
+func TestRunIstioCtl(t *testing.T) {
+ t.Run("run istioctl with context", func(t *testing.T) {
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"version", "--timeout", "30s"}, "1.18.0", nil)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
- assert.NoError(t, err)
- assert.False(t, result.IsError)
+ result, err := runIstioCtl(ctx, []string{"version"})
- // Verify the correct command was called with both namespace and config type
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "istioctl", callLog[0].Command)
- assert.Equal(t, []string{"ztunnel-config", "listener", "-n", "istio-system"}, callLog[0].Args)
+ require.NoError(t, err)
+ assert.Equal(t, "1.18.0", result)
})
+}
+func TestIstioErrorHandling(t *testing.T) {
t.Run("istioctl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- mock.AddCommandString("istioctl", []string{"ztunnel-config", "all"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ mock.AddCommandString("istioctl", []string{"proxy-status"}, "", assert.AnError)
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
- request := mcp.CallToolRequest{}
- result, err := handleZtunnelConfig(ctx, request)
+ result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{})
- assert.NoError(t, err) // MCP handlers should not return Go errors
+ require.NoError(t, err)
+ assert.NotNil(t, result)
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "istioctl ztunnel-config failed")
})
}
diff --git a/pkg/k8s/k8s.go b/pkg/k8s/k8s.go
index 6bd73ec..6c29c9c 100644
--- a/pkg/k8s/k8s.go
+++ b/pkg/k8s/k8s.go
@@ -10,12 +10,15 @@ import (
"slices"
"strings"
- "github.com/kagent-dev/tools/pkg/logger"
- "github.com/kagent-dev/tools/pkg/utils"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/tmc/langchaingo/llms"
- "github.com/tmc/langchaingo/llms/openai"
+
+ "github.com/kagent-dev/tools/internal/cache"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/logger"
+ "github.com/kagent-dev/tools/internal/security"
+ "github.com/kagent-dev/tools/internal/telemetry"
)
// K8sTool struct to hold the LLM model
@@ -32,6 +35,22 @@ func NewK8sToolWithConfig(kubeconfig string, llmModel llms.Model) *K8sTool {
return &K8sTool{kubeconfig: kubeconfig, llmModel: llmModel}
}
+// runKubectlCommandWithCacheInvalidation runs a kubectl command and invalidates cache if it's a modification operation
+func (k *K8sTool) runKubectlCommandWithCacheInvalidation(ctx context.Context, args ...string) (*mcp.CallToolResult, error) {
+ result, err := k.runKubectlCommand(ctx, args...)
+
+ // If command succeeded and it's a modification command, invalidate cache
+ if err == nil && len(args) > 0 {
+ subcommand := args[0]
+ switch subcommand {
+ case "apply", "delete", "patch", "scale", "annotate", "label", "create", "run", "rollout":
+ cache.InvalidateKubernetesCache()
+ }
+ }
+
+ return result, err
+}
+
// Enhanced kubectl get
func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
resourceType := mcp.ParseString(request, "resource_type", "")
@@ -62,7 +81,7 @@ func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request mcp.Call
args = append(args, "-o", "json")
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Get pod logs
@@ -86,7 +105,7 @@ func (k *K8sTool) handleKubectlLogsEnhanced(ctx context.Context, request mcp.Cal
args = append(args, "--tail", fmt.Sprintf("%d", tailLines))
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Scale deployment
@@ -101,7 +120,7 @@ func (k *K8sTool) handleScaleDeployment(ctx context.Context, request mcp.CallToo
args := []string{"scale", "deployment", deploymentName, "--replicas", fmt.Sprintf("%d", replicas), "-n", namespace}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommandWithCacheInvalidation(ctx, args...)
}
// Patch resource
@@ -115,9 +134,24 @@ func (k *K8sTool) handlePatchResource(ctx context.Context, request mcp.CallToolR
return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil
}
+ // Validate resource name for security
+ if err := security.ValidateK8sResourceName(resourceName); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil
+ }
+
+ // Validate namespace for security
+ if err := security.ValidateNamespace(namespace); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil
+ }
+
+ // Validate patch content as JSON/YAML
+ if err := security.ValidateYAMLContent(patch); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil
+ }
+
args := []string{"patch", resourceType, resourceName, "-p", patch, "-n", namespace}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommandWithCacheInvalidation(ctx, args...)
}
// Apply manifest from content
@@ -128,18 +162,41 @@ func (k *K8sTool) handleApplyManifest(ctx context.Context, request mcp.CallToolR
return mcp.NewToolResultError("manifest parameter is required"), nil
}
- tmpFile, err := os.CreateTemp("", "manifest-*.yaml")
+ // Validate YAML content for security
+ if err := security.ValidateYAMLContent(manifest); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid manifest content: %v", err)), nil
+ }
+
+ // Create temporary file with secure permissions
+ tmpFile, err := os.CreateTemp("", "k8s-manifest-*.yaml")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil
}
- defer os.Remove(tmpFile.Name())
+ // Ensure file is removed regardless of execution path
+ defer func() {
+ if removeErr := os.Remove(tmpFile.Name()); removeErr != nil {
+ logger.Get().Error("Failed to remove temporary file", "error", removeErr, "file", tmpFile.Name())
+ }
+ }()
+
+ // Set secure file permissions (readable/writable by owner only)
+ if err := os.Chmod(tmpFile.Name(), 0600); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Failed to set file permissions: %v", err)), nil
+ }
+
+ // Write manifest content to temporary file
if _, err := tmpFile.WriteString(manifest); err != nil {
+ tmpFile.Close()
return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil
}
- tmpFile.Close()
- return k.runKubectlCommand(ctx, []string{"apply", "-f", tmpFile.Name()})
+ // Close the file before passing to kubectl
+ if err := tmpFile.Close(); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Failed to close temp file: %v", err)), nil
+ }
+
+ return k.runKubectlCommandWithCacheInvalidation(ctx, "apply", "-f", tmpFile.Name())
}
// Delete resource
@@ -154,7 +211,7 @@ func (k *K8sTool) handleDeleteResource(ctx context.Context, request mcp.CallTool
args := []string{"delete", resourceType, resourceName, "-n", namespace}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommandWithCacheInvalidation(ctx, args...)
}
// Check service connectivity
@@ -169,23 +226,23 @@ func (k *K8sTool) handleCheckServiceConnectivity(ctx context.Context, request mc
// Create a temporary curl pod for connectivity check
podName := fmt.Sprintf("curl-test-%d", rand.Intn(10000))
defer func() {
- _, _ = k.runKubectlCommand(ctx, []string{"delete", "pod", podName, "-n", namespace, "--ignore-not-found"})
+ _, _ = k.runKubectlCommand(ctx, "delete", "pod", podName, "-n", namespace, "--ignore-not-found")
}()
// Create the curl pod
- _, err := k.runKubectlCommand(ctx, []string{"run", podName, "--image=curlimages/curl", "-n", namespace, "--restart=Never", "--", "sleep", "3600"})
+ _, err := k.runKubectlCommand(ctx, "run", podName, "--image=curlimages/curl", "-n", namespace, "--restart=Never", "--", "sleep", "3600")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to create curl pod: %v", err)), nil
}
// Wait for pod to be ready
- _, err = k.runKubectlCommand(ctx, []string{"wait", "--for=condition=ready", "pod/" + podName, "-n", namespace, "--timeout=60s"})
+ _, err = k.runKubectlCommand(ctx, "wait", "--for=condition=ready", "pod/"+podName, "-n", namespace, "--timeout=60s")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to wait for curl pod: %v", err)), nil
}
- // Execute curl command
- return k.runKubectlCommand(ctx, []string{"exec", podName, "-n", namespace, "--", "curl", "-s", serviceName})
+ // Execute kubectl command
+ return k.runKubectlCommand(ctx, "exec", podName, "-n", namespace, "--", "curl", "-s", serviceName)
}
// Get cluster events
@@ -199,7 +256,7 @@ func (k *K8sTool) handleGetEvents(ctx context.Context, request mcp.CallToolReque
args = append(args, "--all-namespaces")
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Execute command in pod
@@ -212,14 +269,29 @@ func (k *K8sTool) handleExecCommand(ctx context.Context, request mcp.CallToolReq
return mcp.NewToolResultError("pod_name and command parameters are required"), nil
}
+ // Validate pod name for security
+ if err := security.ValidateK8sResourceName(podName); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid pod name: %v", err)), nil
+ }
+
+ // Validate namespace for security
+ if err := security.ValidateNamespace(namespace); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil
+ }
+
+ // Validate command input for security
+ if err := security.ValidateCommandInput(command); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid command: %v", err)), nil
+ }
+
args := []string{"exec", podName, "-n", namespace, "--", command}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Get available API resources
func (k *K8sTool) handleGetAvailableAPIResources(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- return k.runKubectlCommand(ctx, []string{"api-resources", "-o", "json"})
+ return k.runKubectlCommand(ctx, "api-resources", "-o", "json")
}
// Kubectl describe tool
@@ -237,7 +309,7 @@ func (k *K8sTool) handleKubectlDescribeTool(ctx context.Context, request mcp.Cal
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Rollout operations
@@ -256,12 +328,12 @@ func (k *K8sTool) handleRollout(ctx context.Context, request mcp.CallToolRequest
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Get cluster configuration
func (k *K8sTool) handleGetClusterConfiguration(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- return k.runKubectlCommand(ctx, []string{"config", "view"})
+ return k.runKubectlCommand(ctx, "config", "view", "-o", "json")
}
// Remove annotation
@@ -280,7 +352,7 @@ func (k *K8sTool) handleRemoveAnnotation(ctx context.Context, request mcp.CallTo
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Remove label
@@ -299,7 +371,7 @@ func (k *K8sTool) handleRemoveLabel(ctx context.Context, request mcp.CallToolReq
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Annotate resource
@@ -320,7 +392,7 @@ func (k *K8sTool) handleAnnotateResource(ctx context.Context, request mcp.CallTo
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Label resource
@@ -341,7 +413,7 @@ func (k *K8sTool) handleLabelResource(ctx context.Context, request mcp.CallToolR
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Create resource from URL
@@ -358,7 +430,7 @@ func (k *K8sTool) handleCreateResourceFromURL(ctx context.Context, request mcp.C
args = append(args, "-n", namespace)
}
- return k.runKubectlCommand(ctx, args)
+ return k.runKubectlCommand(ctx, args...)
}
// Resource generation embeddings
@@ -450,30 +522,27 @@ func (k *K8sTool) handleGenerateResource(ctx context.Context, request mcp.CallTo
return mcp.NewToolResultError("empty response from model"), nil
}
c1 := choices[0]
- return mcp.NewToolResultText(c1.Content), nil
+ responseText := c1.Content
+
+ return mcp.NewToolResultText(responseText), nil
}
-// Helper function to run kubectl commands
-func (k *K8sTool) runKubectlCommand(ctx context.Context, args []string) (*mcp.CallToolResult, error) {
- if k.kubeconfig != "" {
- args = append([]string{"--kubeconfig", k.kubeconfig}, args...)
- }
- result, err := utils.RunCommandWithContext(ctx, "kubectl", args)
+// runKubectlCommand is a helper function to execute kubectl commands
+func (k *K8sTool) runKubectlCommand(ctx context.Context, args ...string) (*mcp.CallToolResult, error) {
+ output, err := commands.NewCommandBuilder("kubectl").
+ WithArgs(args...).
+ WithKubeconfig(k.kubeconfig).
+ Execute(ctx)
+
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
- return mcp.NewToolResultText(result), nil
+
+ return mcp.NewToolResultText(output), nil
}
// RegisterK8sTools registers all k8s tools with the MCP server
-func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
- var llm llms.Model
- if openAiClient, err := openai.New(); err == nil {
- llm = openAiClient
- } else {
- logger.Get().Error(err, "Failed to initialize OpenAI LLM, k8s_generate_resource tool will not be available")
- }
-
+func RegisterTools(s *server.MCPServer, llm llms.Model, kubeconfig string) {
k8sTool := NewK8sToolWithConfig(kubeconfig, llm)
s.AddTool(mcp.NewTool("k8s_get_resources",
@@ -483,7 +552,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("namespace", mcp.Description("Namespace to query (optional)")),
mcp.WithString("all_namespaces", mcp.Description("Query all namespaces (true/false)")),
mcp.WithString("output", mcp.Description("Output format (json, yaml, wide, etc.)")),
- ), k8sTool.handleKubectlGetEnhanced)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resources", k8sTool.handleKubectlGetEnhanced)))
s.AddTool(mcp.NewTool("k8s_get_pod_logs",
mcp.WithDescription("Get logs from a Kubernetes pod"),
@@ -491,14 +560,14 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")),
mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")),
mcp.WithNumber("tail_lines", mcp.Description("Number of lines to show from the end (default: 50)")),
- ), k8sTool.handleKubectlLogsEnhanced)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_pod_logs", k8sTool.handleKubectlLogsEnhanced)))
s.AddTool(mcp.NewTool("k8s_scale",
mcp.WithDescription("Scale a Kubernetes deployment"),
mcp.WithString("name", mcp.Description("Name of the deployment"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the deployment (default: default)")),
mcp.WithNumber("replicas", mcp.Description("Number of replicas"), mcp.Required()),
- ), k8sTool.handleScaleDeployment)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_scale", k8sTool.handleScaleDeployment)))
s.AddTool(mcp.NewTool("k8s_patch_resource",
mcp.WithDescription("Patch a Kubernetes resource using strategic merge patch"),
@@ -506,45 +575,46 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()),
mcp.WithString("patch", mcp.Description("JSON patch to apply"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")),
- ), k8sTool.handlePatchResource)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_patch_resource", k8sTool.handlePatchResource)))
s.AddTool(mcp.NewTool("k8s_apply_manifest",
mcp.WithDescription("Apply a YAML manifest to the Kubernetes cluster"),
mcp.WithString("manifest", mcp.Description("YAML manifest content"), mcp.Required()),
- ), k8sTool.handleApplyManifest)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_apply_manifest", k8sTool.handleApplyManifest)))
s.AddTool(mcp.NewTool("k8s_delete_resource",
mcp.WithDescription("Delete a Kubernetes resource"),
mcp.WithString("resource_type", mcp.Description("Type of resource (pod, service, deployment, etc.)"), mcp.Required()),
mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")),
- ), k8sTool.handleDeleteResource)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_delete_resource", k8sTool.handleDeleteResource)))
s.AddTool(mcp.NewTool("k8s_check_service_connectivity",
mcp.WithDescription("Check connectivity to a service using a temporary curl pod"),
mcp.WithString("service_name", mcp.Description("Service name to test (e.g., my-service.my-namespace.svc.cluster.local:80)"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace to run the check from (default: default)")),
- ), k8sTool.handleCheckServiceConnectivity)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_check_service_connectivity", k8sTool.handleCheckServiceConnectivity)))
s.AddTool(mcp.NewTool("k8s_get_events",
- mcp.WithDescription("Get Kubernetes cluster events"),
- mcp.WithString("namespace", mcp.Description("Namespace to query events from (optional, default: all namespaces)")),
- ), k8sTool.handleGetEvents)
+ mcp.WithDescription("Get events from a Kubernetes namespace"),
+ mcp.WithString("namespace", mcp.Description("Namespace to get events from (default: default)")),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_events", k8sTool.handleGetEvents)))
s.AddTool(mcp.NewTool("k8s_execute_command",
- mcp.WithDescription("Execute a command inside a Kubernetes pod"),
- mcp.WithString("pod_name", mcp.Description("Name of the pod"), mcp.Required()),
+ mcp.WithDescription("Execute a command in a Kubernetes pod"),
+ mcp.WithString("pod_name", mcp.Description("Name of the pod to execute in"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")),
+ mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")),
mcp.WithString("command", mcp.Description("Command to execute"), mcp.Required()),
- ), k8sTool.handleExecCommand)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_execute_command", k8sTool.handleExecCommand)))
s.AddTool(mcp.NewTool("k8s_get_available_api_resources",
- mcp.WithDescription("Get all available API resources from the Kubernetes cluster"),
- ), k8sTool.handleGetAvailableAPIResources)
+ mcp.WithDescription("Get available Kubernetes API resources"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_available_api_resources", k8sTool.handleGetAvailableAPIResources)))
s.AddTool(mcp.NewTool("k8s_get_cluster_configuration",
- mcp.WithDescription("Get the current kubectl cluster configuration"),
- ), k8sTool.handleGetClusterConfiguration)
+ mcp.WithDescription("Get cluster configuration details"),
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_cluster_configuration", k8sTool.handleGetClusterConfiguration)))
s.AddTool(mcp.NewTool("k8s_rollout",
mcp.WithDescription("Perform rollout operations on Kubernetes resources (history, pause, restart, resume, status, undo)"),
@@ -552,7 +622,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_type", mcp.Description("The type of resource to rollout (e.g., deployment)"), mcp.Required()),
mcp.WithString("resource_name", mcp.Description("The name of the resource to rollout"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the resource")),
- ), k8sTool.handleRollout)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_rollout", k8sTool.handleRollout)))
s.AddTool(mcp.NewTool("k8s_label_resource",
mcp.WithDescription("Add or update labels on a Kubernetes resource"),
@@ -560,7 +630,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()),
mcp.WithString("labels", mcp.Description("Space-separated key=value pairs for labels"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the resource")),
- ), k8sTool.handleLabelResource)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_label_resource", k8sTool.handleLabelResource)))
s.AddTool(mcp.NewTool("k8s_annotate_resource",
mcp.WithDescription("Add or update annotations on a Kubernetes resource"),
@@ -568,7 +638,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()),
mcp.WithString("annotations", mcp.Description("Space-separated key=value pairs for annotations"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the resource")),
- ), k8sTool.handleAnnotateResource)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_annotate_resource", k8sTool.handleAnnotateResource)))
s.AddTool(mcp.NewTool("k8s_remove_annotation",
mcp.WithDescription("Remove an annotation from a Kubernetes resource"),
@@ -576,7 +646,7 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()),
mcp.WithString("annotation_key", mcp.Description("The key of the annotation to remove"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the resource")),
- ), k8sTool.handleRemoveAnnotation)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_annotation", k8sTool.handleRemoveAnnotation)))
s.AddTool(mcp.NewTool("k8s_remove_label",
mcp.WithDescription("Remove a label from a Kubernetes resource"),
@@ -584,12 +654,12 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()),
mcp.WithString("label_key", mcp.Description("The key of the label to remove"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace of the resource")),
- ), k8sTool.handleRemoveLabel)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_label", k8sTool.handleRemoveLabel)))
s.AddTool(mcp.NewTool("k8s_create_resource",
mcp.WithDescription("Create a Kubernetes resource from YAML content"),
mcp.WithString("yaml_content", mcp.Description("YAML content of the resource"), mcp.Required()),
- ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
yamlContent := mcp.ParseString(request, "yaml_content", "")
if yamlContent == "" {
@@ -608,26 +678,26 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
}
tmpFile.Close()
- result, err := utils.RunCommandWithContext(ctx, "kubectl", []string{"create", "-f", tmpFile.Name()})
+ result, err := k8sTool.runKubectlCommand(ctx, "create", "-f", tmpFile.Name())
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil
}
- return mcp.NewToolResultText(result), nil
- })
+ return result, nil
+ })))
s.AddTool(mcp.NewTool("k8s_create_resource_from_url",
mcp.WithDescription("Create a Kubernetes resource from a URL pointing to a YAML manifest"),
mcp.WithString("url", mcp.Description("The URL of the manifest"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("The namespace to create the resource in")),
- ), k8sTool.handleCreateResourceFromURL)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource_from_url", k8sTool.handleCreateResourceFromURL)))
s.AddTool(mcp.NewTool("k8s_get_resource_yaml",
mcp.WithDescription("Get the YAML representation of a Kubernetes resource"),
mcp.WithString("resource_type", mcp.Description("Type of resource"), mcp.Required()),
mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")),
- ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resource_yaml", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
resourceType := mcp.ParseString(request, "resource_type", "")
resourceName := mcp.ParseString(request, "resource_name", "")
namespace := mcp.ParseString(request, "namespace", "")
@@ -641,24 +711,24 @@ func RegisterK8sTools(s *server.MCPServer, kubeconfig string) {
args = append(args, "-n", namespace)
}
- result, err := utils.RunCommandWithContext(ctx, "kubectl", args)
+ result, err := k8sTool.runKubectlCommand(ctx, args...)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil
}
- return mcp.NewToolResultText(result), nil
- })
+ return result, nil
+ })))
s.AddTool(mcp.NewTool("k8s_describe_resource",
mcp.WithDescription("Describe a Kubernetes resource in detail"),
mcp.WithString("resource_type", mcp.Description("Type of resource (deployment, service, pod, node, etc.)"), mcp.Required()),
mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()),
mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")),
- ), k8sTool.handleKubectlDescribeTool)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_describe_resource", k8sTool.handleKubectlDescribeTool)))
s.AddTool(mcp.NewTool("k8s_generate_resource",
mcp.WithDescription("Generate a Kubernetes resource YAML from a description"),
mcp.WithString("resource_description", mcp.Description("Detailed description of the resource to generate"), mcp.Required()),
mcp.WithString("resource_type", mcp.Description(fmt.Sprintf("Type of resource to generate (%s)", strings.Join(slices.Collect(resourceTypes), ", "))), mcp.Required()),
- ), k8sTool.handleGenerateResource)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_generate_resource", k8sTool.handleGenerateResource)))
}
diff --git a/pkg/k8s/k8s_test.go b/pkg/k8s/k8s_test.go
index 240e7ac..a71e10f 100644
--- a/pkg/k8s/k8s_test.go
+++ b/pkg/k8s/k8s_test.go
@@ -2,11 +2,9 @@ package k8s
import (
"context"
- "fmt"
- "os"
"testing"
- "github.com/kagent-dev/tools/pkg/utils"
+ "github.com/kagent-dev/tools/internal/cmd"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -38,10 +36,10 @@ func TestHandleGetAvailableAPIResources(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `[{"name": "pods", "singularName": "pod", "namespaced": true, "kind": "Pod"}]`
mock.AddCommandString("kubectl", []string{"api-resources", "-o", "json"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -57,9 +55,9 @@ func TestHandleGetAvailableAPIResources(t *testing.T) {
})
t.Run("kubectl command failure", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
mock.AddCommandString("kubectl", []string{"api-resources", "-o", "json"}, "", assert.AnError)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -75,10 +73,10 @@ func TestHandleScaleDeployment(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment scaled`
mock.AddCommandString("kubectl", []string{"scale", "deployment", "test-deployment", "--replicas", "5", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -99,8 +97,8 @@ func TestHandleScaleDeployment(t *testing.T) {
})
t.Run("missing name parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -122,18 +120,16 @@ func TestHandleScaleDeployment(t *testing.T) {
})
t.Run("missing replicas parameter uses default", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment scaled`
- // Default replicas is 1
mock.AddCommandString("kubectl", []string{"scale", "deployment", "test-deployment", "--replicas", "1", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"name": "test-deployment",
- // Missing replicas parameter - should use default value of 1
}
result, err := k8sTool.handleScaleDeployment(ctx, req)
@@ -156,10 +152,10 @@ func TestHandleGetEvents(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `{"items": [{"metadata": {"name": "test-event"}, "message": "Test event message"}]}`
mock.AddCommandString("kubectl", []string{"get", "events", "-o", "json", "--all-namespaces"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -174,10 +170,10 @@ func TestHandleGetEvents(t *testing.T) {
})
t.Run("with namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `{"items": []}`
mock.AddCommandString("kubectl", []string{"get", "events", "-o", "json", "-n", "custom-namespace"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -197,8 +193,8 @@ func TestHandlePatchResource(t *testing.T) {
ctx := context.Background()
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -219,10 +215,10 @@ func TestHandlePatchResource(t *testing.T) {
})
t.Run("valid parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment patched`
mock.AddCommandString("kubectl", []string{"patch", "deployment", "test-deployment", "-p", `{"spec":{"replicas":5}}`, "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -247,8 +243,8 @@ func TestHandleDeleteResource(t *testing.T) {
ctx := context.Background()
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -269,17 +265,17 @@ func TestHandleDeleteResource(t *testing.T) {
})
t.Run("valid parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `pod "test-pod" deleted`
- mock.AddCommandString("kubectl", []string{"delete", "pod", "test-pod", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock := cmd.NewMockShellExecutor()
+ expectedOutput := `deployment.apps/test-deployment deleted`
+ mock.AddCommandString("kubectl", []string{"delete", "deployment", "test-deployment", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
- "resource_type": "pod",
- "resource_name": "test-pod",
+ "resource_type": "deployment",
+ "resource_name": "test-deployment",
}
result, err := k8sTool.handleDeleteResource(ctx, req)
@@ -296,8 +292,8 @@ func TestHandleCheckServiceConnectivity(t *testing.T) {
ctx := context.Background()
t.Run("missing service_name", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -315,15 +311,15 @@ func TestHandleCheckServiceConnectivity(t *testing.T) {
})
t.Run("valid service_name", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
// Mock the pod creation, wait, and exec commands using partial matchers
mock.AddPartialMatcherString("kubectl", []string{"run", "*", "--image=curlimages/curl", "-n", "default", "--restart=Never", "--", "sleep", "3600"}, "pod/curl-test-123 created", nil)
- mock.AddPartialMatcherString("kubectl", []string{"wait", "--for=condition=ready", "*", "-n", "default", "--timeout=60s"}, "pod/curl-test-123 condition met", nil)
+ mock.AddPartialMatcherString("kubectl", []string{"wait", "--for=condition=ready", "*", "-n", "default", "--timeout=60s", "--timeout", "30s"}, "pod/curl-test-123 condition met", nil)
mock.AddPartialMatcherString("kubectl", []string{"exec", "*", "-n", "default", "--", "curl", "-s", "test-service.default.svc.cluster.local:80"}, "Connection successful", nil)
- mock.AddPartialMatcherString("kubectl", []string{"delete", "pod", "*", "-n", "default", "--ignore-not-found"}, "pod deleted", nil)
+ mock.AddPartialMatcherString("kubectl", []string{"delete", "pod", "*", "-n", "default", "--ignore-not-found", "--timeout", "30s"}, "pod deleted", nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -343,8 +339,8 @@ func TestHandleKubectlDescribeTool(t *testing.T) {
ctx := context.Background()
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -365,12 +361,12 @@ func TestHandleKubectlDescribeTool(t *testing.T) {
})
t.Run("valid parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `Name: test-deployment
Namespace: default
Labels: app=test`
mock.AddCommandString("kubectl", []string{"describe", "deployment", "test-deployment", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -395,8 +391,8 @@ func TestHandleKubectlGetEnhanced(t *testing.T) {
ctx := context.Background()
t.Run("missing resource_type", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
@@ -411,10 +407,10 @@ func TestHandleKubectlGetEnhanced(t *testing.T) {
})
t.Run("valid resource_type", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `{"items": [{"metadata": {"name": "pod1"}}]}`
mock.AddCommandString("kubectl", []string{"get", "pods", "-o", "json"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
@@ -430,8 +426,8 @@ func TestHandleKubectlLogsEnhanced(t *testing.T) {
ctx := context.Background()
t.Run("missing pod_name", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
@@ -446,11 +442,11 @@ func TestHandleKubectlLogsEnhanced(t *testing.T) {
})
t.Run("valid pod_name", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `log line 1
log line 2`
mock.AddCommandString("kubectl", []string{"logs", "test-pod", "-n", "default", "--tail", "50"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
req := mcp.CallToolRequest{}
@@ -463,8 +459,9 @@ log line 2`
}
func TestHandleApplyManifest(t *testing.T) {
+ ctx := context.Background()
t.Run("apply manifest from string", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
manifest := `apiVersion: v1
kind: Pod
metadata:
@@ -476,8 +473,8 @@ spec:
expectedOutput := `pod/test-pod created`
// Use partial matcher to handle dynamic temp file names
- mock.AddPartialMatcherString("kubectl", []string{"apply", "-f", "*"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock.AddPartialMatcherString("kubectl", []string{"apply", "-f"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -507,8 +504,8 @@ spec:
})
t.Run("missing manifest parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -530,15 +527,16 @@ spec:
}
func TestHandleExecCommand(t *testing.T) {
+ ctx := context.Background()
t.Run("exec command in pod", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `total 8
drwxr-xr-x 1 root root 4096 Jan 1 12:00 .
drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..`
// The implementation passes the command as a single string after --
mock.AddCommandString("kubectl", []string{"exec", "mypod", "-n", "default", "--", "ls -la"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -566,8 +564,8 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..`
})
t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -590,12 +588,13 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..`
}
func TestHandleRollout(t *testing.T) {
+ ctx := context.Background()
t.Run("rollout restart deployment", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/myapp restarted`
mock.AddCommandString("kubectl", []string{"rollout", "restart", "deployment/myapp", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -624,8 +623,8 @@ func TestHandleRollout(t *testing.T) {
})
t.Run("missing required parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -773,10 +772,10 @@ func TestHandleAnnotateResource(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment annotated`
- mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1=value1", "key2=value2", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1=value1", "key2=value2", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -798,8 +797,8 @@ func TestHandleAnnotateResource(t *testing.T) {
})
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -825,10 +824,10 @@ func TestHandleLabelResource(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment labeled`
- mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env=prod", "version=1.0", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env=prod", "version=1.0", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -850,8 +849,8 @@ func TestHandleLabelResource(t *testing.T) {
})
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -877,10 +876,10 @@ func TestHandleRemoveAnnotation(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment annotated`
- mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1-", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"annotate", "deployment", "test-deployment", "key1-", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -902,8 +901,8 @@ func TestHandleRemoveAnnotation(t *testing.T) {
})
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -929,10 +928,10 @@ func TestHandleRemoveLabel(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment labeled`
- mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env-", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"label", "deployment", "test-deployment", "env-", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -954,8 +953,8 @@ func TestHandleRemoveLabel(t *testing.T) {
})
t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -981,10 +980,10 @@ func TestHandleCreateResourceFromURL(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `deployment.apps/test-deployment created`
- mock.AddCommandString("kubectl", []string{"create", "-f", "https://example.com/manifest.yaml", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"create", "-f", "https://example.com/manifest.yaml", "-n", "default", "--timeout", "30s"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -1004,8 +1003,8 @@ func TestHandleCreateResourceFromURL(t *testing.T) {
})
t.Run("missing url parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
+ mock := cmd.NewMockShellExecutor()
+ ctx := cmd.WithShellExecutor(context.Background(), mock)
k8sTool := newTestK8sTool()
@@ -1030,7 +1029,7 @@ func TestHandleGetClusterConfiguration(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
+ mock := cmd.NewMockShellExecutor()
expectedOutput := `apiVersion: v1
clusters:
- cluster:
@@ -1046,8 +1045,8 @@ kind: Config
preferences: {}
users:
- name: default`
- mock.AddCommandString("kubectl", []string{"config", "view"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
+ mock.AddCommandString("kubectl", []string{"config", "view", "-o", "json"}, expectedOutput, nil)
+ ctx := cmd.WithShellExecutor(ctx, mock)
k8sTool := newTestK8sTool()
@@ -1062,258 +1061,3 @@ users:
assert.Contains(t, resultText, "clusters")
})
}
-
-// Test the k8s_create_resource handler (inline function in RegisterK8sTools)
-func TestHandleCreateResource(t *testing.T) {
- ctx := context.Background()
-
- t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- yamlContent := `apiVersion: v1
-kind: Pod
-metadata:
- name: test-pod
-spec:
- containers:
- - name: test
- image: nginx`
-
- expectedOutput := `pod/test-pod created`
- // Use partial matcher to handle dynamic temp file names
- mock.AddPartialMatcherString("kubectl", []string{"create", "-f", "*"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
-
- // We need to test the inline function from RegisterK8sTools
- // Let's create a test handler that mimics the inline function
- testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- yamlContent := mcp.ParseString(request, "yaml_content", "")
-
- if yamlContent == "" {
- return mcp.NewToolResultError("yaml_content is required"), nil
- }
-
- // Create temporary file
- tmpFile, err := os.CreateTemp("", "k8s-resource-*.yaml")
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil
- }
- defer os.Remove(tmpFile.Name())
-
- if _, err := tmpFile.WriteString(yamlContent); err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil
- }
- tmpFile.Close()
-
- result, err := utils.RunCommandWithContext(ctx, "kubectl", []string{"create", "-f", tmpFile.Name()})
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil
- }
-
- return mcp.NewToolResultText(result), nil
- }
-
- req := mcp.CallToolRequest{}
- req.Params.Arguments = map[string]interface{}{
- "yaml_content": yamlContent,
- }
-
- result, err := testHandler(ctx, req)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.False(t, result.IsError)
-
- // Verify the expected output
- content := getResultText(result)
- assert.Contains(t, content, "created")
-
- // Verify kubectl create was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Len(t, callLog[0].Args, 3) // create, -f,
- assert.Equal(t, "create", callLog[0].Args[0])
- assert.Equal(t, "-f", callLog[0].Args[1])
- // Third argument should be the temporary file path
- assert.Contains(t, callLog[0].Args[2], "k8s-resource-")
- })
-
- t.Run("missing yaml_content parameter", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- // Test handler for missing parameter
- testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- yamlContent := mcp.ParseString(request, "yaml_content", "")
-
- if yamlContent == "" {
- return mcp.NewToolResultError("yaml_content is required"), nil
- }
- return mcp.NewToolResultText("should not reach here"), nil
- }
-
- req := mcp.CallToolRequest{}
- req.Params.Arguments = map[string]interface{}{
- // Missing yaml_content parameter
- }
-
- result, err := testHandler(ctx, req)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "yaml_content is required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-}
-
-// Test the k8s_get_resource_yaml handler (inline function in RegisterK8sTools)
-func TestHandleGetResourceYAML(t *testing.T) {
- ctx := context.Background()
-
- t.Run("success", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `apiVersion: v1
-kind: Pod
-metadata:
- name: test-pod
- namespace: default
-spec:
- containers:
- - name: test
- image: nginx`
- mock.AddCommandString("kubectl", []string{"get", "pod", "test-pod", "-o", "yaml", "-n", "default"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
-
- // Test handler that mimics the inline function
- testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- resourceType := mcp.ParseString(request, "resource_type", "")
- resourceName := mcp.ParseString(request, "resource_name", "")
- namespace := mcp.ParseString(request, "namespace", "")
-
- if resourceType == "" || resourceName == "" {
- return mcp.NewToolResultError("resource_type and resource_name are required"), nil
- }
-
- args := []string{"get", resourceType, resourceName, "-o", "yaml"}
- if namespace != "" {
- args = append(args, "-n", namespace)
- }
-
- result, err := utils.RunCommandWithContext(ctx, "kubectl", args)
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil
- }
-
- return mcp.NewToolResultText(result), nil
- }
-
- req := mcp.CallToolRequest{}
- req.Params.Arguments = map[string]interface{}{
- "resource_type": "pod",
- "resource_name": "test-pod",
- "namespace": "default",
- }
-
- result, err := testHandler(ctx, req)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.False(t, result.IsError)
-
- resultText := getResultText(result)
- assert.Contains(t, resultText, "test-pod")
- assert.Contains(t, resultText, "apiVersion")
-
- // Verify the correct kubectl command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"get", "pod", "test-pod", "-o", "yaml", "-n", "default"}, callLog[0].Args)
- })
-
- t.Run("missing parameters", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- ctx := utils.WithShellExecutor(context.Background(), mock)
-
- testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- resourceType := mcp.ParseString(request, "resource_type", "")
- resourceName := mcp.ParseString(request, "resource_name", "")
-
- if resourceType == "" || resourceName == "" {
- return mcp.NewToolResultError("resource_type and resource_name are required"), nil
- }
- return mcp.NewToolResultText("should not reach here"), nil
- }
-
- req := mcp.CallToolRequest{}
- req.Params.Arguments = map[string]interface{}{
- "resource_type": "pod",
- // Missing resource_name
- }
-
- result, err := testHandler(ctx, req)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "resource_type and resource_name are required")
-
- // Verify no commands were executed
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-
- t.Run("without namespace", func(t *testing.T) {
- mock := utils.NewMockShellExecutor()
- expectedOutput := `apiVersion: v1
-kind: ClusterRole
-metadata:
- name: test-cluster-role`
- mock.AddCommandString("kubectl", []string{"get", "clusterrole", "test-cluster-role", "-o", "yaml"}, expectedOutput, nil)
- ctx := utils.WithShellExecutor(ctx, mock)
-
- testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- resourceType := mcp.ParseString(request, "resource_type", "")
- resourceName := mcp.ParseString(request, "resource_name", "")
- namespace := mcp.ParseString(request, "namespace", "")
-
- if resourceType == "" || resourceName == "" {
- return mcp.NewToolResultError("resource_type and resource_name are required"), nil
- }
-
- args := []string{"get", resourceType, resourceName, "-o", "yaml"}
- if namespace != "" {
- args = append(args, "-n", namespace)
- }
-
- result, err := utils.RunCommandWithContext(ctx, "kubectl", args)
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil
- }
-
- return mcp.NewToolResultText(result), nil
- }
-
- req := mcp.CallToolRequest{}
- req.Params.Arguments = map[string]interface{}{
- "resource_type": "clusterrole",
- "resource_name": "test-cluster-role",
- // No namespace for cluster-scoped resource
- }
-
- result, err := testHandler(ctx, req)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.False(t, result.IsError)
-
- resultText := getResultText(result)
- assert.Contains(t, resultText, "test-cluster-role")
- assert.Contains(t, resultText, "ClusterRole")
-
- // Verify the correct kubectl command was called (without namespace)
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"get", "clusterrole", "test-cluster-role", "-o", "yaml"}, callLog[0].Args)
- })
-}
diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go
deleted file mode 100644
index 062997d..0000000
--- a/pkg/logger/logger.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package logger
-
-import (
- "github.com/go-logr/logr"
- "github.com/go-logr/stdr"
-)
-
-var globalLogger logr.Logger
-
-// Init initializes the global logger with appropriate configuration
-func Init() {
- // Set log level from environment variable (not directly supported by stdr, but can be extended)
- // For now, just use stdr with default settings
- globalLogger = stdr.New(nil)
-}
-
-// Get returns the global logger instance
-func Get() logr.Logger {
- if globalLogger.GetSink() == nil {
- Init()
- }
- return globalLogger
-}
-
-// LogExecCommand logs information about an exec command being executed
-func LogExecCommand(command string, args []string, caller string) {
- logger := Get()
- logger.Info("executing command",
- "command", command,
- "args", args,
- "caller", caller,
- )
-}
-
-// LogExecCommandResult logs the result of an exec command
-func LogExecCommandResult(command string, args []string, output string, err error, duration float64, caller string) {
- logger := Get()
-
- if err != nil {
- logger.Error(err, "command execution failed",
- "command", command,
- "args", args,
- "duration_seconds", duration,
- "caller", caller,
- )
- } else {
- logger.Info("command execution successful",
- "command", command,
- "args", args,
- "output", output,
- "duration_seconds", duration,
- "caller", caller,
- )
- }
-}
-
-// Sync is a no-op for logr/stdr
-func Sync() {}
diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go
deleted file mode 100644
index 24903d0..0000000
--- a/pkg/logger/logger_test.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package logger
-
-import (
- "os"
- "testing"
-
- "github.com/go-logr/logr"
- "github.com/stretchr/testify/assert"
-)
-
-func TestInit(t *testing.T) {
- // Test initialization
- Init()
- assert.NotNil(t, globalLogger)
-}
-
-func TestGet(t *testing.T) {
- // Reset global logger
- globalLogger = logr.Logger{}
-
- // Test Get without Init
- logger := Get()
- assert.NotNil(t, logger)
- assert.NotNil(t, globalLogger)
-}
-
-func TestLogExecCommand(t *testing.T) {
- // Just test that it does not panic and logs
- assert.NotPanics(t, func() {
- LogExecCommand("test-command", []string{"arg1", "arg2"}, "test.go:123")
- })
-}
-
-func TestLogExecCommandResult(t *testing.T) {
- // Test successful command
- assert.NotPanics(t, func() {
- LogExecCommandResult("test-command", []string{"arg1"}, "success output", nil, 1.5, "test.go:123")
- })
- // Test failed command
- assert.NotPanics(t, func() {
- LogExecCommandResult("test-command", []string{"arg1"}, "error output", assert.AnError, 0.5, "test.go:123")
- })
-}
-
-func TestEnvironmentVariables(t *testing.T) {
- // Test log level from environment (no-op for stdr)
- os.Setenv("KAGENT_LOG_LEVEL", "debug")
- defer os.Unsetenv("KAGENT_LOG_LEVEL")
-
- // Reset global logger
- globalLogger = logr.Logger{}
-
- // Initialize with environment variable
- Init()
-
- // Just check logger is set
- assert.NotNil(t, globalLogger)
-}
-
-func TestDevelopmentMode(t *testing.T) {
- // Test development mode (no-op for stdr)
- os.Setenv("KAGENT_ENV", "development")
- defer os.Unsetenv("KAGENT_ENV")
-
- // Reset global logger
- globalLogger = logr.Logger{}
-
- // Initialize in development mode
- Init()
-
- // In development mode, the logger should be configured (no panic)
- assert.NotNil(t, globalLogger)
-}
-
-func TestSync(t *testing.T) {
- // Test Sync function
- Init()
-
- // Sync should not panic
- assert.NotPanics(t, func() {
- Sync()
- })
-}
diff --git a/pkg/prometheus/prometheus.go b/pkg/prometheus/prometheus.go
index dc4d6cb..1239305 100644
--- a/pkg/prometheus/prometheus.go
+++ b/pkg/prometheus/prometheus.go
@@ -9,6 +9,9 @@ import (
"net/url"
"time"
+ "github.com/kagent-dev/tools/internal/errors"
+ "github.com/kagent-dev/tools/internal/security"
+ "github.com/kagent-dev/tools/internal/telemetry"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
@@ -33,6 +36,16 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest)
return mcp.NewToolResultError("query parameter is required"), nil
}
+ // Validate prometheus URL
+ if err := security.ValidateURL(prometheusURL); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil
+ }
+
+ // Validate PromQL query
+ if err := security.ValidatePromQLQuery(query); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil
+ }
+
// Make request to Prometheus API
apiURL := fmt.Sprintf("%s/api/v1/query", prometheusURL)
params := url.Values{}
@@ -42,19 +55,40 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest)
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
client := getHTTPClient(ctx)
- resp, err := client.Get(fullURL)
+ req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
if err != nil {
- return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil
+ toolErr := errors.NewPrometheusError("create_request", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("query", query)
+ return toolErr.ToMCPResult(), nil
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ toolErr := errors.NewPrometheusError("query_execution", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("query", query).
+ WithContext("api_url", apiURL)
+ return toolErr.ToMCPResult(), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
- return mcp.NewToolResultError("failed to read response: " + err.Error()), nil
+ toolErr := errors.NewPrometheusError("read_response", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("query", query).
+ WithContext("status_code", resp.StatusCode)
+ return toolErr.ToMCPResult(), nil
}
if resp.StatusCode != http.StatusOK {
- return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil
+ toolErr := errors.NewPrometheusError("api_error", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("query", query).
+ WithContext("status_code", resp.StatusCode).
+ WithContext("response_body", string(body))
+ return toolErr.ToMCPResult(), nil
}
// Parse the JSON response to pretty-print it
@@ -82,6 +116,33 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq
return mcp.NewToolResultError("query parameter is required"), nil
}
+ // Validate prometheus URL
+ if err := security.ValidateURL(prometheusURL); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil
+ }
+
+ // Validate PromQL query
+ if err := security.ValidatePromQLQuery(query); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil
+ }
+
+ // Validate time parameters if provided
+ if start != "" {
+ if err := security.ValidateCommandInput(start); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid start time: %v", err)), nil
+ }
+ }
+ if end != "" {
+ if err := security.ValidateCommandInput(end); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid end time: %v", err)), nil
+ }
+ }
+ if step != "" {
+ if err := security.ValidateCommandInput(step); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid step parameter: %v", err)), nil
+ }
+ }
+
// Use default time range if not specified
if start == "" {
start = fmt.Sprintf("%d", time.Now().Add(-1*time.Hour).Unix())
@@ -101,7 +162,12 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
client := getHTTPClient(ctx)
- resp, err := client.Get(fullURL)
+ req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
+ if err != nil {
+ return mcp.NewToolResultError("failed to create request: " + err.Error()), nil
+ }
+
+ resp, err := client.Do(req)
if err != nil {
return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil
}
@@ -133,23 +199,48 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq
func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090")
+ // Validate prometheus URL
+ if err := security.ValidateURL(prometheusURL); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil
+ }
+
// Make request to Prometheus API for labels
apiURL := fmt.Sprintf("%s/api/v1/labels", prometheusURL)
client := getHTTPClient(ctx)
- resp, err := client.Get(apiURL)
+ req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
- return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil
+ toolErr := errors.NewPrometheusError("create_request", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("api_url", apiURL)
+ return toolErr.ToMCPResult(), nil
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ toolErr := errors.NewPrometheusError("query_execution", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("api_url", apiURL)
+ return toolErr.ToMCPResult(), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
- return mcp.NewToolResultError("failed to read response: " + err.Error()), nil
+ toolErr := errors.NewPrometheusError("read_response", err).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("api_url", apiURL).
+ WithContext("status_code", resp.StatusCode)
+ return toolErr.ToMCPResult(), nil
}
if resp.StatusCode != http.StatusOK {
- return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil
+ toolErr := errors.NewPrometheusError("api_error", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))).
+ WithContext("prometheus_url", prometheusURL).
+ WithContext("api_url", apiURL).
+ WithContext("status_code", resp.StatusCode).
+ WithContext("response_body", string(body))
+ return toolErr.ToMCPResult(), nil
}
// Parse the JSON response to pretty-print it
@@ -169,11 +260,21 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe
func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090")
+ // Validate prometheus URL
+ if err := security.ValidateURL(prometheusURL); err != nil {
+ return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil
+ }
+
// Make request to Prometheus API for targets
apiURL := fmt.Sprintf("%s/api/v1/targets", prometheusURL)
client := getHTTPClient(ctx)
- resp, err := client.Get(apiURL)
+ req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
+ if err != nil {
+ return mcp.NewToolResultError("failed to create request: " + err.Error()), nil
+ }
+
+ resp, err := client.Do(req)
if err != nil {
return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil
}
@@ -202,12 +303,12 @@ func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolR
return mcp.NewToolResultText(string(prettyJSON)), nil
}
-func RegisterPrometheusTools(s *server.MCPServer, kubeconfig string) {
+func RegisterTools(s *server.MCPServer) {
s.AddTool(mcp.NewTool("prometheus_query_tool",
mcp.WithDescription("Execute a PromQL query against Prometheus"),
mcp.WithString("query", mcp.Description("PromQL query to execute"), mcp.Required()),
mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")),
- ), handlePrometheusQueryTool)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_tool", handlePrometheusQueryTool)))
s.AddTool(mcp.NewTool("prometheus_query_range_tool",
mcp.WithDescription("Execute a PromQL range query against Prometheus"),
@@ -216,20 +317,20 @@ func RegisterPrometheusTools(s *server.MCPServer, kubeconfig string) {
mcp.WithString("end", mcp.Description("End time (Unix timestamp or relative time)")),
mcp.WithString("step", mcp.Description("Query resolution step (default: 15s)")),
mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")),
- ), handlePrometheusRangeQueryTool)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_range_tool", handlePrometheusRangeQueryTool)))
s.AddTool(mcp.NewTool("prometheus_label_names_tool",
mcp.WithDescription("Get all available labels from Prometheus"),
mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")),
- ), handlePrometheusLabelsQueryTool)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_label_names_tool", handlePrometheusLabelsQueryTool)))
s.AddTool(mcp.NewTool("prometheus_targets_tool",
mcp.WithDescription("Get all Prometheus targets and their status"),
mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")),
- ), handlePrometheusTargetsQueryTool)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_targets_tool", handlePrometheusTargetsQueryTool)))
s.AddTool(mcp.NewTool("prometheus_promql_tool",
mcp.WithDescription("Generate a PromQL query"),
mcp.WithString("query_description", mcp.Description("A string describing the query to generate"), mcp.Required()),
- ), handlePromql)
+ ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_promql_tool", handlePromql)))
}
diff --git a/pkg/prometheus/prometheus_test.go b/pkg/prometheus/prometheus_test.go
index d51b52e..647d1f3 100644
--- a/pkg/prometheus/prometheus_test.go
+++ b/pkg/prometheus/prometheus_test.go
@@ -122,7 +122,7 @@ func TestHandlePrometheusQueryTool(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "failed to query Prometheus")
+ assert.Contains(t, getResultText(result), "**Prometheus Error**")
})
t.Run("HTTP 500 error", func(t *testing.T) {
@@ -139,7 +139,7 @@ func TestHandlePrometheusQueryTool(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "Prometheus API error (500)")
+ assert.Contains(t, getResultText(result), "**Prometheus Error**")
})
t.Run("malformed JSON response", func(t *testing.T) {
@@ -283,7 +283,7 @@ func TestHandlePrometheusLabelsQueryTool(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.True(t, result.IsError)
- assert.Contains(t, getResultText(result), "failed to query Prometheus")
+ assert.Contains(t, getResultText(result), "**Prometheus Error**")
})
t.Run("custom prometheus URL", func(t *testing.T) {
diff --git a/pkg/utils/common.go b/pkg/utils/common.go
index d8be795..ce8b73b 100644
--- a/pkg/utils/common.go
+++ b/pkg/utils/common.go
@@ -3,314 +3,49 @@ package utils
import (
"context"
"fmt"
- "os/exec"
- "runtime"
"strings"
+ "sync"
"time"
- "github.com/kagent-dev/tools/pkg/logger"
+ "github.com/kagent-dev/tools/internal/commands"
+ "github.com/kagent-dev/tools/internal/logger"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/attribute"
- "go.opentelemetry.io/otel/codes"
- "go.opentelemetry.io/otel/metric"
)
-// ShellExecutor defines the interface for executing shell commands
-type ShellExecutor interface {
- Exec(ctx context.Context, command string, args ...string) (output []byte, err error)
+// KubeConfigManager manages kubeconfig path with thread safety
+type KubeConfigManager struct {
+ mu sync.RWMutex
+ kubeconfigPath string
}
-// DefaultShellExecutor implements ShellExecutor using os/exec
-type DefaultShellExecutor struct{}
+// globalKubeConfigManager is the singleton instance
+var globalKubeConfigManager = &KubeConfigManager{}
-// Exec executes a command using os/exec.CommandContext
-func (e *DefaultShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) {
- cmd := exec.CommandContext(ctx, command, args...)
- return cmd.CombinedOutput()
-}
-
-// MockShellExecutor implements ShellExecutor for testing
-type MockShellExecutor struct {
- // Commands maps command+args to expected output and error
- Commands map[string]MockCommandResult
- // CallLog keeps track of all executed commands for verification
- CallLog []MockCommandCall
- // PartialMatchers allows partial matching for dynamic arguments
- PartialMatchers []PartialMatcher
-}
-
-// PartialMatcher represents a partial command matcher for dynamic arguments
-type PartialMatcher struct {
- Command string
- Args []string // Use "*" for wildcard matching
- Result MockCommandResult
-}
-
-// MockCommandResult represents the expected result of a mocked command
-type MockCommandResult struct {
- Output []byte
- Error error
-}
-
-// MockCommandCall represents a logged command execution
-type MockCommandCall struct {
- Command string
- Args []string
-}
-
-// Exec executes a mocked command
-func (m *MockShellExecutor) Exec(ctx context.Context, command string, args ...string) ([]byte, error) {
- // Log the call
- m.CallLog = append(m.CallLog, MockCommandCall{
- Command: command,
- Args: args,
- })
-
- // Try exact match first
- key := m.commandKey(command, args...)
- if result, exists := m.Commands[key]; exists {
- return result.Output, result.Error
- }
-
- // Try partial matchers
- for _, matcher := range m.PartialMatchers {
- if m.matchesPartial(command, args, matcher) {
- return matcher.Result.Output, matcher.Result.Error
- }
- }
-
- // Default behavior for unmocked commands
- return []byte(""), fmt.Errorf("unmocked command: %s %v", command, args)
-}
-
-// matchesPartial checks if a command matches a partial matcher
-func (m *MockShellExecutor) matchesPartial(command string, args []string, matcher PartialMatcher) bool {
- if command != matcher.Command {
- return false
- }
-
- if len(args) != len(matcher.Args) {
- return false
- }
+// SetKubeconfig sets the global kubeconfig path in a thread-safe manner
+func SetKubeconfig(path string) {
+ globalKubeConfigManager.mu.Lock()
+ defer globalKubeConfigManager.mu.Unlock()
- for i, expectedArg := range matcher.Args {
- if expectedArg == "*" {
- continue // Wildcard match
- }
- if args[i] != expectedArg {
- return false
- }
- }
-
- return true
+ globalKubeConfigManager.kubeconfigPath = path
+ logger.Get().Info("Setting shared kubeconfig", "path", path)
}
-// AddCommand adds a command mock
-func (m *MockShellExecutor) AddCommand(command string, args []string, output []byte, err error) {
- if m.Commands == nil {
- m.Commands = make(map[string]MockCommandResult)
- }
- key := m.commandKey(command, args...)
- m.Commands[key] = MockCommandResult{
- Output: output,
- Error: err,
- }
-}
+// GetKubeconfig returns the global kubeconfig path in a thread-safe manner
+func GetKubeconfig() string {
+ globalKubeConfigManager.mu.RLock()
+ defer globalKubeConfigManager.mu.RUnlock()
-// AddCommandString is a convenience method for adding string output
-func (m *MockShellExecutor) AddCommandString(command string, args []string, output string, err error) {
- m.AddCommand(command, args, []byte(output), err)
+ return globalKubeConfigManager.kubeconfigPath
}
-// AddPartialMatcher adds a partial matcher for dynamic arguments
-func (m *MockShellExecutor) AddPartialMatcher(command string, args []string, output []byte, err error) {
- if m.PartialMatchers == nil {
- m.PartialMatchers = []PartialMatcher{}
+// AddKubeconfigArgs adds kubeconfig arguments to command args if configured
+func AddKubeconfigArgs(args []string) []string {
+ kubeconfigPath := GetKubeconfig()
+ if kubeconfigPath != "" {
+ return append([]string{"--kubeconfig", kubeconfigPath}, args...)
}
- m.PartialMatchers = append(m.PartialMatchers, PartialMatcher{
- Command: command,
- Args: args,
- Result: MockCommandResult{
- Output: output,
- Error: err,
- },
- })
-}
-
-// AddPartialMatcherString is a convenience method for adding string output with partial matching
-func (m *MockShellExecutor) AddPartialMatcherString(command string, args []string, output string, err error) {
- m.AddPartialMatcher(command, args, []byte(output), err)
-}
-
-// GetCallLog returns the log of all command calls
-func (m *MockShellExecutor) GetCallLog() []MockCommandCall {
- return m.CallLog
-}
-
-// Reset clears the mock state
-func (m *MockShellExecutor) Reset() {
- m.Commands = make(map[string]MockCommandResult)
- m.CallLog = []MockCommandCall{}
- m.PartialMatchers = []PartialMatcher{}
-}
-
-// commandKey creates a unique key for command+args combination
-func (m *MockShellExecutor) commandKey(command string, args ...string) string {
- return fmt.Sprintf("%s %s", command, strings.Join(args, " "))
-}
-
-// Context key for shell executor injection
-type contextKey string
-
-const shellExecutorKey contextKey = "shellExecutor"
-
-// WithShellExecutor returns a context with the given shell executor
-func WithShellExecutor(ctx context.Context, executor ShellExecutor) context.Context {
- return context.WithValue(ctx, shellExecutorKey, executor)
-}
-
-// GetShellExecutor retrieves the shell executor from context, or returns default
-func GetShellExecutor(ctx context.Context) ShellExecutor {
- if executor, ok := ctx.Value(shellExecutorKey).(ShellExecutor); ok {
- return executor
- }
- return &DefaultShellExecutor{}
-}
-
-// NewMockShellExecutor creates a new mock shell executor for testing
-func NewMockShellExecutor() *MockShellExecutor {
- return &MockShellExecutor{
- Commands: make(map[string]MockCommandResult),
- CallLog: []MockCommandCall{},
- PartialMatchers: []PartialMatcher{},
- }
-}
-
-var (
- tracer = otel.Tracer("kagent-tools")
- meter = otel.Meter("kagent-tools")
-
- // Metrics
- commandExecutionCounter metric.Int64Counter
- commandExecutionDuration metric.Float64Histogram
- commandExecutionErrors metric.Int64Counter
-)
-
-func init() {
- // Initialize metrics (these are safe to call even if OTEL is not configured)
- var err error
-
- commandExecutionCounter, err = meter.Int64Counter(
- "command_executions_total",
- metric.WithDescription("Total number of command executions"),
- )
- if err != nil {
- logger.Get().Error(err, "Failed to create command execution counter")
- }
-
- commandExecutionDuration, err = meter.Float64Histogram(
- "command_execution_duration_seconds",
- metric.WithDescription("Duration of command executions in seconds"),
- metric.WithUnit("s"),
- )
- if err != nil {
- logger.Get().Error(err, "Failed to create command execution duration histogram")
- }
-
- commandExecutionErrors, err = meter.Int64Counter(
- "command_execution_errors_total",
- metric.WithDescription("Total number of command execution errors"),
- )
- if err != nil {
- logger.Get().Error(err, "Failed to create command execution errors counter")
- }
-}
-
-// RunCommand executes a command and returns output or error with OTEL tracing
-func RunCommand(command string, args []string) (string, error) {
- return RunCommandWithContext(context.Background(), command, args)
-}
-
-// RunCommandWithContext executes a command with context and returns output or error with OTEL tracing
-func RunCommandWithContext(ctx context.Context, command string, args []string) (string, error) {
- // Get caller information for tracing
- _, file, line, _ := runtime.Caller(1)
- caller := fmt.Sprintf("%s:%d", file, line)
-
- // Start OpenTelemetry span
- spanName := fmt.Sprintf("exec.%s", command)
- ctx, span := tracer.Start(ctx, spanName)
- defer span.End()
-
- // Set span attributes
- span.SetAttributes(
- attribute.String("command", command),
- attribute.StringSlice("args", args),
- attribute.String("caller", caller),
- )
-
- // Record metrics
- startTime := time.Now()
-
- // Use the shell executor from context (or default)
- executor := GetShellExecutor(ctx)
- output, err := executor.Exec(ctx, command, args...)
-
- duration := time.Since(startTime)
-
- // Set additional span attributes with results
- span.SetAttributes(
- attribute.Float64("duration_seconds", duration.Seconds()),
- attribute.Int("output_size", len(output)),
- )
-
- // Record metrics
- attributes := []attribute.KeyValue{
- attribute.String("command", command),
- attribute.Bool("success", err == nil),
- }
-
- if commandExecutionCounter != nil {
- commandExecutionCounter.Add(ctx, 1, metric.WithAttributes(attributes...))
- }
-
- if commandExecutionDuration != nil {
- commandExecutionDuration.Record(ctx, duration.Seconds(), metric.WithAttributes(attributes...))
- }
-
- if err != nil {
- // Set span status and record error
- span.RecordError(err)
- span.SetStatus(codes.Error, err.Error())
- span.SetAttributes(attribute.String("error", err.Error()))
-
- if commandExecutionErrors != nil {
- commandExecutionErrors.Add(ctx, 1, metric.WithAttributes(attributes...))
- }
-
- logger.Get().Error(err, "CommandExec failed",
- "command", command,
- "args", args,
- "duration", duration,
- "caller", caller,
- )
- return "", fmt.Errorf("command %s failed: %v", command, err)
- }
-
- // Set successful span status
- span.SetStatus(codes.Ok, "CommandExec")
-
- logger.Get().Info("CommandExec",
- "command", command,
- "args", args,
- "duration", duration,
- "outputSize", len(output),
- "caller", caller,
- )
-
- return strings.TrimSpace(string(output)), nil
+ return args
}
// shellTool provides shell command execution functionality
@@ -328,10 +63,21 @@ func shellTool(ctx context.Context, params shellParams) (string, error) {
cmd := parts[0]
args := parts[1:]
- return RunCommandWithContext(ctx, cmd, args)
+ return commands.NewCommandBuilder(cmd).WithArgs(args...).Execute(ctx)
+}
+
+// handleGetCurrentDateTimeTool provides datetime functionality for both MCP and testing
+func handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ // Returns the current date and time in ISO 8601 format (RFC3339)
+ // This matches the Python implementation: datetime.datetime.now().isoformat()
+ now := time.Now()
+ return mcp.NewToolResultText(now.Format(time.RFC3339)), nil
}
-func RegisterCommonTools(s *server.MCPServer) {
+func RegisterTools(s *server.MCPServer) {
+ logger.Get().Info("RegisterTools initialized")
+
+ // Register shell tool
s.AddTool(mcp.NewTool("shell",
mcp.WithDescription("Execute shell commands"),
mcp.WithString("command", mcp.Description("The shell command to execute"), mcp.Required()),
@@ -350,5 +96,10 @@ func RegisterCommonTools(s *server.MCPServer) {
return mcp.NewToolResultText(result), nil
})
+ // Register datetime tool
+ s.AddTool(mcp.NewTool("datetime_get_current_time",
+ mcp.WithDescription("Returns the current date and time in ISO 8601 format."),
+ ), handleGetCurrentDateTimeTool)
+
// Note: LLM Tool implementation would go here if needed
}
diff --git a/pkg/utils/common_test.go b/pkg/utils/common_test.go
deleted file mode 100644
index e21cf76..0000000
--- a/pkg/utils/common_test.go
+++ /dev/null
@@ -1,288 +0,0 @@
-package utils
-
-import (
- "context"
- "errors"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestDefaultShellExecutor(t *testing.T) {
- executor := &DefaultShellExecutor{}
-
- // Test successful command
- output, err := executor.Exec(context.Background(), "echo", "hello")
- assert.NoError(t, err)
- assert.Equal(t, "hello\n", string(output))
-
- // Test command with error
- output, err = executor.Exec(context.Background(), "nonexistent-command")
- assert.Error(t, err)
- assert.Empty(t, output)
-}
-
-func TestMockShellExecutor(t *testing.T) {
- mock := NewMockShellExecutor()
-
- t.Run("unmocked command returns error", func(t *testing.T) {
- output, err := mock.Exec(context.Background(), "unmocked", "command")
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "unmocked command")
- assert.Empty(t, output)
- })
-
- t.Run("mocked command returns expected result", func(t *testing.T) {
- expectedOutput := "mocked output"
- mock.AddCommandString("kubectl", []string{"get", "pods"}, expectedOutput, nil)
-
- output, err := mock.Exec(context.Background(), "kubectl", "get", "pods")
- assert.NoError(t, err)
- assert.Equal(t, expectedOutput, string(output))
- })
-
- t.Run("mocked command with error", func(t *testing.T) {
- expectedError := errors.New("mocked error")
- mock.AddCommandString("helm", []string{"install", "app"}, "", expectedError)
-
- output, err := mock.Exec(context.Background(), "helm", "install", "app")
- assert.Error(t, err)
- assert.Equal(t, expectedError, err)
- assert.Empty(t, output)
- })
-
- t.Run("call log tracking", func(t *testing.T) {
- mock.Reset()
-
- // Execute some commands
- mock.AddCommandString("cmd1", []string{"arg1"}, "output1", nil)
- mock.AddCommandString("cmd2", []string{"arg2", "arg3"}, "output2", nil)
-
- _, _ = mock.Exec(context.Background(), "cmd1", "arg1")
- _, _ = mock.Exec(context.Background(), "cmd2", "arg2", "arg3")
- _, _ = mock.Exec(context.Background(), "unmocked", "command")
-
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 3)
-
- assert.Equal(t, "cmd1", callLog[0].Command)
- assert.Equal(t, []string{"arg1"}, callLog[0].Args)
-
- assert.Equal(t, "cmd2", callLog[1].Command)
- assert.Equal(t, []string{"arg2", "arg3"}, callLog[1].Args)
-
- assert.Equal(t, "unmocked", callLog[2].Command)
- assert.Equal(t, []string{"command"}, callLog[2].Args)
- })
-
- t.Run("reset functionality", func(t *testing.T) {
- // Create a fresh mock for this test
- freshMock := NewMockShellExecutor()
- freshMock.AddCommandString("test", []string{}, "output", nil)
- _, _ = freshMock.Exec(context.Background(), "test")
-
- assert.Len(t, freshMock.Commands, 1)
- assert.Len(t, freshMock.CallLog, 1)
-
- freshMock.Reset()
-
- assert.Len(t, freshMock.Commands, 0)
- assert.Len(t, freshMock.CallLog, 0)
- })
-}
-
-func TestContextShellExecutor(t *testing.T) {
- t.Run("default executor when no context value", func(t *testing.T) {
- ctx := context.Background()
- executor := GetShellExecutor(ctx)
-
- _, ok := executor.(*DefaultShellExecutor)
- assert.True(t, ok, "should return DefaultShellExecutor when no context value")
- })
-
- t.Run("mock executor from context", func(t *testing.T) {
- mock := NewMockShellExecutor()
- ctx := WithShellExecutor(context.Background(), mock)
-
- executor := GetShellExecutor(ctx)
- assert.Equal(t, mock, executor, "should return the mock executor from context")
- })
-
- t.Run("context propagation", func(t *testing.T) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("test", []string{"arg"}, "test output", nil)
-
- ctx := WithShellExecutor(context.Background(), mock)
-
- // Test that RunCommandWithContext uses the mock
- output, err := RunCommandWithContext(ctx, "test", []string{"arg"})
- assert.NoError(t, err)
- assert.Equal(t, "test output", output)
-
- // Verify the command was logged
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "test", callLog[0].Command)
- assert.Equal(t, []string{"arg"}, callLog[0].Args)
- })
-}
-
-func TestRunCommandWithMocking(t *testing.T) {
- t.Run("successful command execution with mock", func(t *testing.T) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("kubectl", []string{"get", "pods", "-n", "default"}, "pod1\npod2", nil)
-
- ctx := WithShellExecutor(context.Background(), mock)
-
- output, err := RunCommandWithContext(ctx, "kubectl", []string{"get", "pods", "-n", "default"})
- assert.NoError(t, err)
- assert.Equal(t, "pod1\npod2", output)
-
- // Verify command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"get", "pods", "-n", "default"}, callLog[0].Args)
- })
-
- t.Run("command failure with mock", func(t *testing.T) {
- mock := NewMockShellExecutor()
- expectedError := errors.New("command failed")
- mock.AddCommandString("helm", []string{"install", "app"}, "", expectedError)
-
- ctx := WithShellExecutor(context.Background(), mock)
-
- output, err := RunCommandWithContext(ctx, "helm", []string{"install", "app"})
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "command helm failed")
- assert.Empty(t, output)
- })
-
- t.Run("multiple commands with mock", func(t *testing.T) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("kubectl", []string{"get", "pods"}, "pod-list", nil)
- mock.AddCommandString("kubectl", []string{"get", "services"}, "service-list", nil)
- mock.AddCommandString("helm", []string{"list"}, "helm-releases", nil)
-
- ctx := WithShellExecutor(context.Background(), mock)
-
- // Execute multiple commands
- output1, err1 := RunCommandWithContext(ctx, "kubectl", []string{"get", "pods"})
- assert.NoError(t, err1)
- assert.Equal(t, "pod-list", output1)
-
- output2, err2 := RunCommandWithContext(ctx, "kubectl", []string{"get", "services"})
- assert.NoError(t, err2)
- assert.Equal(t, "service-list", output2)
-
- output3, err3 := RunCommandWithContext(ctx, "helm", []string{"list"})
- assert.NoError(t, err3)
- assert.Equal(t, "helm-releases", output3)
-
- // Verify all commands were logged
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 3)
-
- assert.Equal(t, "kubectl", callLog[0].Command)
- assert.Equal(t, []string{"get", "pods"}, callLog[0].Args)
-
- assert.Equal(t, "kubectl", callLog[1].Command)
- assert.Equal(t, []string{"get", "services"}, callLog[1].Args)
-
- assert.Equal(t, "helm", callLog[2].Command)
- assert.Equal(t, []string{"list"}, callLog[2].Args)
- })
-}
-
-func TestShellToolWithMocking(t *testing.T) {
- t.Run("shell tool uses mock executor", func(t *testing.T) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("echo", []string{"hello", "world"}, "hello world", nil)
-
- ctx := WithShellExecutor(context.Background(), mock)
-
- params := shellParams{Command: "echo hello world"}
- output, err := shellTool(ctx, params)
- assert.NoError(t, err)
- assert.Equal(t, "hello world", output)
-
- // Verify command was called
- callLog := mock.GetCallLog()
- require.Len(t, callLog, 1)
- assert.Equal(t, "echo", callLog[0].Command)
- assert.Equal(t, []string{"hello", "world"}, callLog[0].Args)
- })
-
- t.Run("shell tool with empty command", func(t *testing.T) {
- mock := NewMockShellExecutor()
- ctx := WithShellExecutor(context.Background(), mock)
-
- params := shellParams{Command: ""}
- output, err := shellTool(ctx, params)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "empty command")
- assert.Empty(t, output)
-
- // No commands should be logged
- callLog := mock.GetCallLog()
- assert.Len(t, callLog, 0)
- })
-}
-
-func TestMockShellExecutorCommandKey(t *testing.T) {
- mock := NewMockShellExecutor()
-
- // Test that different argument combinations create different keys
- mock.AddCommandString("kubectl", []string{"get", "pods"}, "pods", nil)
- mock.AddCommandString("kubectl", []string{"get", "services"}, "services", nil)
- mock.AddCommandString("kubectl", []string{}, "kubectl-help", nil)
-
- // Test first command
- output, err := mock.Exec(context.Background(), "kubectl", "get", "pods")
- assert.NoError(t, err)
- assert.Equal(t, "pods", string(output))
-
- // Test second command
- output, err = mock.Exec(context.Background(), "kubectl", "get", "services")
- assert.NoError(t, err)
- assert.Equal(t, "services", string(output))
-
- // Test third command (no args)
- output, err = mock.Exec(context.Background(), "kubectl")
- assert.NoError(t, err)
- assert.Equal(t, "kubectl-help", string(output))
-}
-
-// Benchmark tests to ensure mocking doesn't add significant overhead
-func BenchmarkDefaultShellExecutor(b *testing.B) {
- executor := &DefaultShellExecutor{}
- ctx := context.Background()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = executor.Exec(ctx, "echo", "test")
- }
-}
-
-func BenchmarkMockShellExecutor(b *testing.B) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("echo", []string{"test"}, "test", nil)
- ctx := context.Background()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = mock.Exec(ctx, "echo", "test")
- }
-}
-
-func BenchmarkRunCommandWithContext(b *testing.B) {
- mock := NewMockShellExecutor()
- mock.AddCommandString("echo", []string{"test"}, "test", nil)
- ctx := WithShellExecutor(context.Background(), mock)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = RunCommandWithContext(ctx, "echo", []string{"test"})
- }
-}
diff --git a/pkg/utils/datetime.go b/pkg/utils/datetime.go
index 165ea68..3bca950 100644
--- a/pkg/utils/datetime.go
+++ b/pkg/utils/datetime.go
@@ -1,31 +1,5 @@
package utils
-import (
- "context"
- "github.com/kagent-dev/tools/pkg/logger"
- "time"
-
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/mark3labs/mcp-go/server"
-)
-
-var kubeConfig = ""
-
-// DateTime tools using direct Go time package
-// This implementation matches the Python version exactly
-func handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- // Returns the current date and time in ISO 8601 format (RFC3339)
- // This matches the Python implementation: datetime.datetime.now().isoformat()
- now := time.Now()
- return mcp.NewToolResultText(now.Format(time.RFC3339)), nil
-}
-
-func RegisterDateTimeTools(s *server.MCPServer, kubeconfig string) {
- kubeConfig = kubeconfig
- logger.Get().Info("kubeConfig", kubeConfig)
-
- // Register the GetCurrentDateTime tool to match Python implementation exactly
- s.AddTool(mcp.NewTool("datetime_get_current_time",
- mcp.WithDescription("Returns the current date and time in ISO 8601 format."),
- ), handleGetCurrentDateTimeTool)
-}
+// DateTime tools implementation moved to RegisterTools function in common.go
+// This file remains for backwards compatibility but the tools are now registered
+// through the unified RegisterTools function.