diff --git a/config.go b/config.go index 70ede9a..fb98150 100644 --- a/config.go +++ b/config.go @@ -18,14 +18,16 @@ type Config struct { type SecurityConfig struct { Enabled bool `yaml:"enabled"` - AllowedCommands []string `yaml:"allowed_commands"` - BlockedCommands []string `yaml:"blocked_commands"` - BlockedPatterns []string `yaml:"blocked_patterns"` + AllowedCommands []string `yaml:"allowed_commands"` // Deprecated: use AllowedExecutables + BlockedCommands []string `yaml:"blocked_commands"` // Deprecated: use validation instead + BlockedPatterns []string `yaml:"blocked_patterns"` // Deprecated: use validation instead + AllowedExecutables []string `yaml:"allowed_executables"` // Secure: list of allowed executable paths MaxExecutionTime time.Duration `yaml:"max_execution_time"` WorkingDirectory string `yaml:"working_directory"` RunAsUser string `yaml:"run_as_user"` MaxOutputSize int `yaml:"max_output_size"` AuditLog bool `yaml:"audit_log"` + UseShellExecution bool `yaml:"use_shell_execution"` // Legacy mode - enables shell execution (DANGEROUS) } type ServerConfig struct { @@ -83,11 +85,13 @@ func loadSecurityFromFile(config *Config, filename string) error { AllowedCommands []string `yaml:"allowed_commands"` BlockedCommands []string `yaml:"blocked_commands"` BlockedPatterns []string `yaml:"blocked_patterns"` + AllowedExecutables []string `yaml:"allowed_executables"` MaxExecutionTime string `yaml:"max_execution_time"` WorkingDirectory string `yaml:"working_directory"` RunAsUser string `yaml:"run_as_user"` MaxOutputSize int `yaml:"max_output_size"` AuditLog bool `yaml:"audit_log"` + UseShellExecution bool `yaml:"use_shell_execution"` } `yaml:"security"` } @@ -99,10 +103,12 @@ func loadSecurityFromFile(config *Config, filename string) error { config.Security.AllowedCommands = yamlConfig.Security.AllowedCommands config.Security.BlockedCommands = yamlConfig.Security.BlockedCommands config.Security.BlockedPatterns = yamlConfig.Security.BlockedPatterns + config.Security.AllowedExecutables = yamlConfig.Security.AllowedExecutables config.Security.WorkingDirectory = yamlConfig.Security.WorkingDirectory config.Security.RunAsUser = yamlConfig.Security.RunAsUser config.Security.MaxOutputSize = yamlConfig.Security.MaxOutputSize config.Security.AuditLog = yamlConfig.Security.AuditLog + config.Security.UseShellExecution = yamlConfig.Security.UseShellExecution if yamlConfig.Security.MaxExecutionTime != "" { duration, err := time.ParseDuration(yamlConfig.Security.MaxExecutionTime) diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..81de865 --- /dev/null +++ b/config_test.go @@ -0,0 +1,356 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadConfig_defaults(t *testing.T) { + // Clear environment variables + os.Unsetenv("MCP_SHELL_SEC_CONFIG_FILE") + os.Unsetenv("MCP_SHELL_SERVER_NAME") + os.Unsetenv("MCP_SHELL_LOG_LEVEL") + + config, err := loadConfig() + require.NoError(t, err) + + // Check defaults + assert.False(t, config.Security.Enabled) + assert.Equal(t, "mcp-shell 🐚", config.Server.Name) + assert.Equal(t, "info", config.Logging.Level) + assert.Equal(t, "console", config.Logging.Format) + assert.Equal(t, "stderr", config.Logging.Output) +} + +func TestLoadConfig_environment_variables(t *testing.T) { + // Set environment variables + os.Setenv("MCP_SHELL_SERVER_NAME", "test-server") + os.Setenv("MCP_SHELL_LOG_LEVEL", "debug") + os.Setenv("MCP_SHELL_LOG_FORMAT", "json") + os.Setenv("MCP_SHELL_LOG_OUTPUT", "stdout") + + defer func() { + os.Unsetenv("MCP_SHELL_SERVER_NAME") + os.Unsetenv("MCP_SHELL_LOG_LEVEL") + os.Unsetenv("MCP_SHELL_LOG_FORMAT") + os.Unsetenv("MCP_SHELL_LOG_OUTPUT") + }() + + config, err := loadConfig() + require.NoError(t, err) + + assert.Equal(t, "test-server", config.Server.Name) + assert.Equal(t, "debug", config.Logging.Level) + assert.Equal(t, "json", config.Logging.Format) + assert.Equal(t, "stdout", config.Logging.Output) +} + +func TestLoadSecurityFromFile(t *testing.T) { + tests := []struct { + name string + yamlContent string + expectError bool + validateConfig func(t *testing.T, config *Config) + }{ + { + name: "secure configuration", + yamlContent: ` +security: + enabled: true + use_shell_execution: false + allowed_executables: + - "ls" + - "echo" + - "/usr/bin/git" + max_execution_time: "10s" + working_directory: "/tmp" + run_as_user: "nobody" + max_output_size: 2048 + audit_log: true +`, + expectError: false, + validateConfig: func(t *testing.T, config *Config) { + assert.True(t, config.Security.Enabled) + assert.False(t, config.Security.UseShellExecution) + assert.Equal(t, []string{"ls", "echo", "/usr/bin/git"}, config.Security.AllowedExecutables) + assert.Equal(t, 10*time.Second, config.Security.MaxExecutionTime) + assert.Equal(t, "/tmp", config.Security.WorkingDirectory) + assert.Equal(t, "nobody", config.Security.RunAsUser) + assert.Equal(t, 2048, config.Security.MaxOutputSize) + assert.True(t, config.Security.AuditLog) + }, + }, + { + name: "legacy configuration", + yamlContent: ` +security: + enabled: true + use_shell_execution: true + allowed_commands: + - "echo" + - "ls" + blocked_commands: + - "rm" + - "chmod" + blocked_patterns: + - "rm\\s+-rf" + max_execution_time: "30s" + audit_log: false +`, + expectError: false, + validateConfig: func(t *testing.T, config *Config) { + assert.True(t, config.Security.Enabled) + assert.True(t, config.Security.UseShellExecution) + assert.Equal(t, []string{"echo", "ls"}, config.Security.AllowedCommands) + assert.Equal(t, []string{"rm", "chmod"}, config.Security.BlockedCommands) + assert.Equal(t, []string{"rm\\s+-rf"}, config.Security.BlockedPatterns) + assert.Equal(t, 30*time.Second, config.Security.MaxExecutionTime) + assert.False(t, config.Security.AuditLog) + }, + }, + { + name: "invalid max_execution_time", + yamlContent: ` +security: + enabled: true + max_execution_time: "invalid_duration" +`, + expectError: true, + }, + { + name: "invalid yaml", + yamlContent: ` +security: + enabled: true + invalid_yaml: [unclosed +`, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "security.yaml") + + err := os.WriteFile(configFile, []byte(tt.yamlContent), 0644) + require.NoError(t, err) + + // Set environment variable + os.Setenv("MCP_SHELL_SEC_CONFIG_FILE", configFile) + defer os.Unsetenv("MCP_SHELL_SEC_CONFIG_FILE") + + config, err := loadConfig() + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + if tt.validateConfig != nil { + tt.validateConfig(t, config) + } + } + }) + } +} + +func TestValidateConfig(t *testing.T) { + tests := []struct { + name string + config Config + expectError bool + errorMsg string + }{ + { + name: "valid config", + config: Config{ + Security: SecurityConfig{ + MaxOutputSize: 1024, + }, + Logging: LoggingConfig{ + Level: "info", + }, + }, + expectError: false, + }, + { + name: "negative max_output_size", + config: Config{ + Security: SecurityConfig{ + MaxOutputSize: -1, + }, + Logging: LoggingConfig{ + Level: "info", + }, + }, + expectError: true, + errorMsg: "max_output_size cannot be negative", + }, + { + name: "invalid log level", + config: Config{ + Security: SecurityConfig{ + MaxOutputSize: 1024, + }, + Logging: LoggingConfig{ + Level: "invalid", + }, + }, + expectError: true, + errorMsg: "invalid log level", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateConfig(&tt.config) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetEnv_functions(t *testing.T) { + t.Run("getEnv", func(t *testing.T) { + // Test with existing environment variable + os.Setenv("TEST_VAR", "test_value") + defer os.Unsetenv("TEST_VAR") + + value := getEnv("TEST_VAR", "default") + assert.Equal(t, "test_value", value) + + // Test with non-existing environment variable + value = getEnv("NON_EXISTING_VAR", "default") + assert.Equal(t, "default", value) + }) + + t.Run("getBoolEnv", func(t *testing.T) { + // Test with true value + os.Setenv("TEST_BOOL", "true") + defer os.Unsetenv("TEST_BOOL") + + value := getBoolEnv("TEST_BOOL", false) + assert.True(t, value) + + // Test with false value + os.Setenv("TEST_BOOL", "false") + value = getBoolEnv("TEST_BOOL", true) + assert.False(t, value) + + // Test with invalid value (should return default) + os.Setenv("TEST_BOOL", "invalid") + value = getBoolEnv("TEST_BOOL", true) + assert.True(t, value) + + // Test with non-existing variable + value = getBoolEnv("NON_EXISTING_BOOL", false) + assert.False(t, value) + }) + + t.Run("getIntEnv", func(t *testing.T) { + // Test with valid integer + os.Setenv("TEST_INT", "42") + defer os.Unsetenv("TEST_INT") + + value := getIntEnv("TEST_INT", 0) + assert.Equal(t, 42, value) + + // Test with invalid integer (should return default) + os.Setenv("TEST_INT", "invalid") + value = getIntEnv("TEST_INT", 100) + assert.Equal(t, 100, value) + + // Test with non-existing variable + value = getIntEnv("NON_EXISTING_INT", 50) + assert.Equal(t, 50, value) + }) +} + +func TestConfig_security_model_examples(t *testing.T) { + t.Run("secure_example_config", func(t *testing.T) { + yamlContent := ` +security: + enabled: true + use_shell_execution: false + allowed_executables: + - "ls" + - "pwd" + - "echo" + - "cat" + - "/usr/bin/git" + max_execution_time: "30s" + working_directory: "/tmp" + audit_log: true +` + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "secure.yaml") + + err := os.WriteFile(configFile, []byte(yamlContent), 0644) + require.NoError(t, err) + + os.Setenv("MCP_SHELL_SEC_CONFIG_FILE", configFile) + defer os.Unsetenv("MCP_SHELL_SEC_CONFIG_FILE") + + config, err := loadConfig() + require.NoError(t, err) + + // Verify secure configuration + assert.True(t, config.Security.Enabled) + assert.False(t, config.Security.UseShellExecution) + assert.Contains(t, config.Security.AllowedExecutables, "ls") + assert.Contains(t, config.Security.AllowedExecutables, "/usr/bin/git") + assert.Equal(t, 30*time.Second, config.Security.MaxExecutionTime) + assert.Equal(t, "/tmp", config.Security.WorkingDirectory) + assert.True(t, config.Security.AuditLog) + }) + + t.Run("legacy_example_config", func(t *testing.T) { + yamlContent := ` +security: + enabled: true + use_shell_execution: true + allowed_commands: + - "ls" + - "echo" + blocked_commands: + - "rm" + - "chmod" + - "sudo" + blocked_patterns: + - "rm\\s+-rf" + - "sudo\\s+" + max_execution_time: "30s" + audit_log: true +` + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "legacy.yaml") + + err := os.WriteFile(configFile, []byte(yamlContent), 0644) + require.NoError(t, err) + + os.Setenv("MCP_SHELL_SEC_CONFIG_FILE", configFile) + defer os.Unsetenv("MCP_SHELL_SEC_CONFIG_FILE") + + config, err := loadConfig() + require.NoError(t, err) + + // Verify legacy configuration + assert.True(t, config.Security.Enabled) + assert.True(t, config.Security.UseShellExecution) + assert.Contains(t, config.Security.AllowedCommands, "ls") + assert.Contains(t, config.Security.BlockedCommands, "rm") + assert.Contains(t, config.Security.BlockedPatterns, "rm\\s+-rf") + assert.True(t, config.Security.AuditLog) + }) +} diff --git a/executor.go b/executor.go index c960de3..06c1f30 100644 --- a/executor.go +++ b/executor.go @@ -97,12 +97,71 @@ func (e *CommandExecutor) execute( return result, nil } +// parseCommand securely parses a command string into executable and arguments +// without using shell interpretation. This prevents command injection through +// shell metacharacters and substitution. +func (e *CommandExecutor) parseCommand(command string) (string, []string, error) { + command = strings.TrimSpace(command) + if command == "" { + return "", nil, fmt.Errorf("empty command") + } + + // Simple whitespace-based splitting - no shell interpretation + parts := strings.Fields(command) + if len(parts) == 0 { + return "", nil, fmt.Errorf("no command found") + } + + executable := parts[0] + args := parts[1:] + + // Validate that the executable doesn't contain shell metacharacters + if containsShellMetacharacters(executable) { + return "", nil, fmt.Errorf("executable contains shell metacharacters: %s", executable) + } + + // Validate arguments don't contain dangerous shell constructs + // In secure mode, this should be an error, not just a warning + for _, arg := range args { + if containsDangerousShellConstructs(arg) { + return "", nil, fmt.Errorf("argument contains dangerous shell constructs: %s", arg) + } + } + + return executable, args, nil +} + func (e *CommandExecutor) executeSecureCommand( ctx context.Context, command string, useBase64 bool, ) (*ExecutionResult, error) { - cmd := exec.CommandContext(ctx, "bash", "-c", command) + var cmd *exec.Cmd + + // Use secure execution unless legacy shell mode is explicitly enabled + if e.config.UseShellExecution { + e.logger.Warn(). + Str("command", command). + Msg("Using legacy shell execution mode - vulnerable to injection attacks") + cmd = exec.CommandContext(ctx, "bash", "-c", command) + } else { + // Secure execution: parse command and execute directly + executable, args, err := e.parseCommand(command) + if err != nil { + e.logger.Error(). + Err(err). + Str("command", command). + Msg("Failed to parse command securely") + return nil, fmt.Errorf("command parsing failed: %w", err) + } + + e.logger.Debug(). + Str("executable", executable). + Strs("args", args). + Msg("Executing command with direct execution") + + cmd = exec.CommandContext(ctx, executable, args...) + } if e.config.WorkingDirectory != "" { if err := os.MkdirAll(e.config.WorkingDirectory, 0755); err == nil { diff --git a/executor_test.go b/executor_test.go new file mode 100644 index 0000000..e846852 --- /dev/null +++ b/executor_test.go @@ -0,0 +1,338 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCommandExecutor_parseCommand(t *testing.T) { + tests := []struct { + name string + command string + expectExec string + expectArgs []string + expectError bool + errorContains string + }{ + { + name: "simple command", + command: "ls -la", + expectExec: "ls", + expectArgs: []string{"-la"}, + expectError: false, + }, + { + name: "command with multiple args", + command: "grep -n test file.txt", + expectExec: "grep", + expectArgs: []string{"-n", "test", "file.txt"}, + expectError: false, + }, + { + name: "command with spaces around", + command: " echo hello world ", + expectExec: "echo", + expectArgs: []string{"hello", "world"}, + expectError: false, + }, + { + name: "single command no args", + command: "pwd", + expectExec: "pwd", + expectArgs: []string{}, + expectError: false, + }, + { + name: "empty command", + command: "", + expectError: true, + errorContains: "empty command", + }, + { + name: "whitespace only", + command: " ", + expectError: true, + errorContains: "empty command", + }, + { + name: "command with pipe (shell metacharacter)", + command: "ls | grep test", + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with semicolon", + command: "echo hello; rm file", + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with command substitution", + command: "echo $(whoami)", + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with backticks", + command: "echo `whoami`", + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with redirection", + command: "echo hello > file.txt", + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with background process", + command: "sleep 10 &", + expectError: true, + errorContains: "dangerous shell constructs", + }, + } + + logger := zerolog.New(zerolog.NewTestWriter(t)) + config := SecurityConfig{} + executor := newCommandExecutor(config, logger) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exec, args, err := executor.parseCommand(tt.command) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectExec, exec) + assert.Equal(t, tt.expectArgs, args) + } + }) + } +} + +func TestCommandExecutor_containsShellMetacharacters(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"normal command", "ls", false}, + {"path with slash", "/usr/bin/ls", false}, + {"command with dash", "ls-extended", false}, + {"command with underscore", "my_command", false}, + {"command with dot", "node.js", false}, + {"pipe character", "ls|grep", true}, + {"ampersand", "command&", true}, + {"semicolon", "cmd;", true}, + {"less than", "cmd<", true}, + {"greater than", "cmd>", true}, + {"parentheses", "cmd()", true}, + {"braces", "cmd{}", true}, + {"brackets", "cmd[]", true}, + {"dollar sign", "cmd$", true}, + {"backtick", "cmd`", true}, + {"backslash", "cmd\\", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsShellMetacharacters(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCommandExecutor_containsDangerousShellConstructs(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"normal argument", "file.txt", false}, + {"normal flag", "-la", false}, + {"command substitution", "$(whoami)", true}, + {"backtick substitution", "`whoami`", true}, + {"variable expansion", "${HOME}", true}, + {"logical AND", "cmd && cmd2", true}, + {"logical OR", "cmd || cmd2", true}, + {"command separator", "cmd; cmd2", true}, + {"pipe", "cmd | cmd2", true}, + {"redirection out", "cmd > file", true}, + {"redirection in", "cmd < file", true}, + {"append redirection", "cmd >> file", true}, + {"here document", "cmd << EOF", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsDangerousShellConstructs(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCommandExecutor_executeSecureCommand_secure_vs_legacy(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + ctx := context.Background() + + tests := []struct { + name string + command string + useShellExecution bool + expectError bool + errorContains string + }{ + { + name: "safe command - secure mode", + command: "echo hello", + useShellExecution: false, + expectError: false, + }, + { + name: "safe command - legacy mode", + command: "echo hello", + useShellExecution: true, + expectError: false, + }, + { + name: "command with pipe - secure mode blocks", + command: "echo hello | cat", + useShellExecution: false, + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command with pipe - legacy mode allows", + command: "echo hello | cat", + useShellExecution: true, + expectError: false, + }, + { + name: "command substitution - secure mode blocks", + command: "echo $(whoami)", + useShellExecution: false, + expectError: true, + errorContains: "dangerous shell constructs", + }, + { + name: "command substitution - legacy mode allows", + command: "echo $(whoami)", + useShellExecution: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := SecurityConfig{ + UseShellExecution: tt.useShellExecution, + MaxExecutionTime: time.Second * 5, + } + executor := newCommandExecutor(config, logger) + + result, err := executor.executeSecureCommand(ctx, tt.command, false) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.NotNil(t, result) + } + }) + } +} + +func TestCommandExecutor_vulnerability_prevention(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + ctx := context.Background() + + // These are actual injection payloads that should be blocked + vulnerabilityTests := []struct { + name string + command string + description string + }{ + { + name: "VULN.md example - obfuscated chmod", + command: "echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))", + description: "Command substitution to reconstruct 'chmod' command", + }, + { + name: "command injection via semicolon", + command: "ls; rm -rf /", + description: "Command separator to execute dangerous command", + }, + { + name: "command injection via pipe", + command: "echo safe | rm -rf /", + description: "Pipe to execute dangerous command", + }, + { + name: "command injection via background", + command: "echo safe & rm -rf /", + description: "Background execution to hide dangerous command", + }, + { + name: "variable expansion injection", + command: "echo ${IFS}rm${IFS}-rf${IFS}/", + description: "Using IFS variable to obfuscate dangerous command", + }, + { + name: "backtick command substitution", + command: "echo `rm -rf /`", + description: "Backtick command substitution for injection", + }, + } + + // Test with secure execution (should block all) + t.Run("secure_execution_blocks_vulnerabilities", func(t *testing.T) { + config := SecurityConfig{ + UseShellExecution: false, + MaxExecutionTime: time.Second * 5, + } + executor := newCommandExecutor(config, logger) + + for _, vt := range vulnerabilityTests { + t.Run(vt.name, func(t *testing.T) { + _, err := executor.executeSecureCommand(ctx, vt.command, false) + assert.Error(t, err, "Secure execution should block: %s", vt.description) + }) + } + }) + + // Test with legacy execution (vulnerable - allows these) + t.Run("legacy_execution_allows_vulnerabilities", func(t *testing.T) { + config := SecurityConfig{ + UseShellExecution: true, + MaxExecutionTime: time.Second * 5, + } + executor := newCommandExecutor(config, logger) + + for _, vt := range vulnerabilityTests { + t.Run(vt.name, func(t *testing.T) { + // Note: We don't actually want these to succeed in tests, + // but we verify they reach the execution stage (not blocked by parsing) + _, err := executor.executeSecureCommand(ctx, vt.command, false) + // These may fail due to actual command execution, but should not fail due to parsing + if err != nil { + assert.NotContains(t, err.Error(), "shell metacharacters", + "Legacy mode should not block based on metacharacters") + assert.NotContains(t, err.Error(), "command parsing failed", + "Legacy mode should not fail at parsing stage") + } + }) + } + }) +} diff --git a/go.mod b/go.mod index ac8a983..8089551 100644 --- a/go.mod +++ b/go.mod @@ -6,19 +6,22 @@ require ( github.com/joho/godotenv v1.5.1 github.com/mark3labs/mcp-go v0.39.1 github.com/rs/zerolog v1.34.0 + github.com/stretchr/testify v1.11.1 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect - github.com/mailru/easyjson v0.9.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/spf13/cast v1.9.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/cast v1.10.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.36.0 // indirect ) diff --git a/go.sum b/go.sum index 49a2f1a..4a40331 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k= github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -39,10 +39,10 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= -github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= -github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -50,8 +50,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT0 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..7ed62b7 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,290 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShellHandler_handle_secure_mode(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + ctx := context.Background() + + tests := []struct { + name string + config SecurityConfig + requestArgs map[string]interface{} + expectError bool + expectErrorText string + }{ + { + name: "secure mode allows safe command", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo", "pwd"}, + MaxExecutionTime: time.Second * 5, + }, + requestArgs: map[string]interface{}{ + "command": "echo hello world", + "base64": false, + }, + expectError: false, + }, + { + name: "secure mode blocks dangerous command", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo", "pwd"}, + }, + requestArgs: map[string]interface{}{ + "command": "rm -rf /", + }, + expectError: true, + expectErrorText: "not in allowed list", + }, + { + name: "secure mode blocks injection attempt", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo"}, + }, + requestArgs: map[string]interface{}{ + "command": "echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))", + }, + expectError: true, + expectErrorText: "not in allowed list", + }, + { + name: "missing command parameter", + config: SecurityConfig{ + Enabled: false, + }, + requestArgs: map[string]interface{}{ + "base64": false, + }, + expectError: true, + expectErrorText: "Missing 'command' parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := newSecurityValidator(tt.config, logger) + executor := newCommandExecutor(tt.config, logger) + handler := newShellHandler(validator, executor, logger) + + // Create MCP request using the arguments map + request := mcp.CallToolRequest{} + request.Params.Arguments = tt.requestArgs + request.Params.Name = "shell_exec" + + result, err := handler.handle(ctx, request) + + require.NoError(t, err, "Handler should not return error, but result should contain error") + require.NotNil(t, result) + + if tt.expectError { + // Check if result contains error + assert.True(t, result.IsError) + } else { + // Check if result is successful + assert.False(t, result.IsError) + } + }) + } +} + +func TestShellHandler_vulnerability_prevention_integration(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + ctx := context.Background() + + // Test the exact vulnerability from VULN.md + vulnerabilityRequest := mcp.CallToolRequest{} + vulnerabilityRequest.Params.Arguments = map[string]interface{}{ + "command": "echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))", + "base64": false, + } + + t.Run("secure_mode_blocks_vuln_md_example", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo", "ls", "pwd"}, + } + + validator := newSecurityValidator(config, logger) + executor := newCommandExecutor(config, logger) + handler := newShellHandler(validator, executor, logger) + + result, err := handler.handle(ctx, vulnerabilityRequest) + require.NoError(t, err) + + // Should be blocked at validation stage + assert.True(t, result.IsError, "Secure mode should block the injection attempt") + }) + + t.Run("legacy_mode_vulnerable_without_blocks", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + // No blocks - vulnerable + MaxExecutionTime: time.Second * 1, + } + + validator := newSecurityValidator(config, logger) + executor := newCommandExecutor(config, logger) + handler := newShellHandler(validator, executor, logger) + + result, err := handler.handle(ctx, vulnerabilityRequest) + require.NoError(t, err) + + // This demonstrates the vulnerability - legacy mode allows dangerous commands + // In a real attack, this would execute the obfuscated chmod + t.Logf("Legacy mode result - IsError: %v", result.IsError) + }) + + t.Run("legacy_mode_with_proper_blocks", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + BlockedCommands: []string{"chmod"}, // This should catch the obfuscated chmod + } + + validator := newSecurityValidator(config, logger) + executor := newCommandExecutor(config, logger) + handler := newShellHandler(validator, executor, logger) + + result, err := handler.handle(ctx, vulnerabilityRequest) + require.NoError(t, err) + + // This demonstrates the vulnerability - legacy mode cannot detect obfuscated commands + // even with keyword blocking, since "chmod" doesn't appear literally + assert.False(t, result.IsError, "Legacy mode with blocks still vulnerable to obfuscation") + }) +} + +func TestShellHandler_base64_encoding(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + ctx := context.Background() + + config := SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo"}, + MaxExecutionTime: time.Second * 5, + } + + validator := newSecurityValidator(config, logger) + executor := newCommandExecutor(config, logger) + handler := newShellHandler(validator, executor, logger) + + tests := []struct { + name string + base64 bool + command string + }{ + { + name: "without base64 encoding", + base64: false, + command: "echo hello world", + }, + { + name: "with base64 encoding", + base64: true, + command: "echo hello world", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]interface{}{ + "command": tt.command, + "base64": tt.base64, + } + + result, err := handler.handle(ctx, request) + require.NoError(t, err) + assert.False(t, result.IsError, "Base64 encoding test should succeed") + }) + } +} + +// Test direct security validation and execution without MCP wrapper +func TestShellHandler_direct_security_tests(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + t.Run("secure_execution_blocks_injection", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo"}, + } + + validator := newSecurityValidator(config, logger) + executor := newCommandExecutor(config, logger) + + // Test validation - should pass (this is the vulnerability) + err := validator.validateCommand("echo $(rm -rf /)") + assert.Error(t, err, "Should block command with shell metacharacters") + + // Test parsing - would also fail in executor + _, _, err = executor.parseCommand("echo $(rm -rf /)") + assert.Error(t, err, "Should fail to parse command with shell metacharacters") + }) + + t.Run("legacy_execution_allows_injection", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + // No restrictions - vulnerable + } + + validator := newSecurityValidator(config, logger) + + // Test validation - should pass (this is the vulnerability) + err := validator.validateCommand("echo $(rm -rf /)") + assert.NoError(t, err, "Legacy mode without blocks allows dangerous commands") + }) + + t.Run("VULN_MD_example_security_comparison", func(t *testing.T) { + vulnCommand := "echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))" + + // Secure mode + secureConfig := SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo"}, + } + secureValidator := newSecurityValidator(secureConfig, logger) + err := secureValidator.validateCommand(vulnCommand) + assert.Error(t, err, "Secure mode should block VULN.md example") + + // Legacy mode without proper blocks (vulnerable) + vulnerableConfig := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + } + vulnerableValidator := newSecurityValidator(vulnerableConfig, logger) + err = vulnerableValidator.validateCommand(vulnCommand) + assert.NoError(t, err, "Legacy mode without blocks is vulnerable") + + // Legacy mode with proper blocks - still vulnerable to obfuscated commands + protectedConfig := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + BlockedCommands: []string{"chmod"}, + } + protectedValidator := newSecurityValidator(protectedConfig, logger) + err = protectedValidator.validateCommand(vulnCommand) + assert.NoError(t, err, "Legacy mode cannot detect obfuscated commands even with blocks") + }) +} diff --git a/security-legacy.yaml b/security-legacy.yaml new file mode 100644 index 0000000..e9b274a --- /dev/null +++ b/security-legacy.yaml @@ -0,0 +1,43 @@ +# Legacy MCP Shell Configuration (VULNERABLE) +# This configuration uses the old shell execution model and is vulnerable to injection +# Only use this for backwards compatibility and when you fully trust the input +security: + # Enable security features + enabled: true + + # Enable legacy shell execution (DANGEROUS - vulnerable to injection) + use_shell_execution: true + + # Legacy allowlist approach (vulnerable to bypass) + allowed_commands: + - "ls" + - "pwd" + - "echo" + - "cat" + + # Legacy blocklist approach (easily bypassed) + blocked_commands: + - "rm" + - "chmod" + - "chown" + - "sudo" + - "su" + + blocked_patterns: + - "rm\\s+-rf" + - "sudo\\s+" + - "passwd" + + # These settings are ignored in legacy mode + allowed_executables: [] + + # Execution limits + max_execution_time: "30s" + max_output_size: 1048576 # 1MB + + # Security context + working_directory: "/tmp" + run_as_user: "" + + # Logging + audit_log: true diff --git a/security.go b/security.go index e872475..bfabeda 100644 --- a/security.go +++ b/security.go @@ -2,6 +2,8 @@ package main import ( "fmt" + "os/exec" + "path/filepath" "regexp" "strings" @@ -28,6 +30,128 @@ func (v *SecurityValidator) validateCommand(command string) error { v.logger.Debug().Str("command", command).Msg("Validating command") + // If shell execution is disabled and we have allowed executables configured, + // use the secure validation approach + if !v.config.UseShellExecution && len(v.config.AllowedExecutables) > 0 { + return v.validateExecutableCommand(command) + } + + // Legacy validation for backwards compatibility + if v.config.UseShellExecution { + v.logger.Warn(). + Str("command", command). + Msg("Using legacy shell execution mode - this is vulnerable to injection attacks") + return v.validateLegacyCommand(command) + } + + // If no allowed executables are configured but security is enabled, + // block everything for safety + if len(v.config.AllowedExecutables) == 0 { + v.logger.Warn(). + Str("command", command). + Msg("No allowed executables configured - blocking all commands") + return fmt.Errorf("no allowed executables configured - all commands blocked for security") + } + + return v.validateExecutableCommand(command) +} + +// validateExecutableCommand validates commands using the secure executable allowlist approach +func (v *SecurityValidator) validateExecutableCommand(command string) error { + command = strings.TrimSpace(command) + if command == "" { + return fmt.Errorf("empty command") + } + + // Check for shell metacharacters first - reject commands that try to use shell features + if containsShellMetacharacters(command) { + return fmt.Errorf("command contains shell metacharacters (not allowed in secure mode): %s", command) + } + + // Check for dangerous shell constructs in the entire command + if containsDangerousShellConstructs(command) { + return fmt.Errorf("command contains dangerous shell constructs (not allowed in secure mode): %s", command) + } + + // Simple whitespace-based splitting to get the executable + parts := strings.Fields(command) + if len(parts) == 0 { + return fmt.Errorf("no command found") + } + + executable := parts[0] + + // Check if the executable is in the allowlist + for _, allowed := range v.config.AllowedExecutables { + if v.matchesExecutable(executable, allowed) { + v.logger.Debug(). + Str("executable", executable). + Str("allowed_pattern", allowed). + Msg("Command validated against allowed executable") + return nil + } + } + + v.logger.Warn(). + Str("executable", executable). + Strs("allowed_executables", v.config.AllowedExecutables). + Msg("Executable not in allowed list") + return fmt.Errorf("executable '%s' not in allowed list", executable) +} + +// matchesExecutable checks if an executable matches an allowed pattern +func (v *SecurityValidator) matchesExecutable(executable, pattern string) bool { + // Exact match + if executable == pattern { + return true + } + + // Check if it's a full path match + if filepath.IsAbs(pattern) { + if absExec, err := filepath.Abs(executable); err == nil { + return absExec == pattern + } + return false + } + + // Check if it's a basename match for simple commands (only if executable is not absolute) + if !filepath.IsAbs(executable) && filepath.Base(executable) == pattern { + // Verify the executable exists in PATH + if _, err := exec.LookPath(executable); err == nil { + return true + } + } + + return false +} + +// containsShellMetacharacters checks if a string contains shell metacharacters +// that could be used for command injection +func containsShellMetacharacters(s string) bool { + metachars := "|&;<>(){}[]$`\\" + for _, char := range s { + if strings.ContainsRune(metachars, char) { + return true + } + } + return false +} + +// containsDangerousShellConstructs checks for potentially dangerous shell constructs +func containsDangerousShellConstructs(s string) bool { + dangerous := []string{ + "$(", "`", "${", "&&", "||", ";", "|", ">", "<", ">>", "<<", "&", + } + for _, construct := range dangerous { + if strings.Contains(s, construct) { + return true + } + } + return false +} + +// validateLegacyCommand performs the old validation for backwards compatibility +func (v *SecurityValidator) validateLegacyCommand(command string) error { for _, pattern := range v.config.BlockedPatterns { if matched, err := regexp.MatchString(pattern, command); err == nil && matched { v.logger.Warn(). @@ -65,7 +189,7 @@ func (v *SecurityValidator) validateCommand(command string) error { } } - v.logger.Debug().Str("command", command).Msg("Command validation passed") + v.logger.Debug().Str("command", command).Msg("Legacy command validation passed") return nil } diff --git a/security.yaml b/security.yaml index 59ca3f5..41ae953 100644 --- a/security.yaml +++ b/security.yaml @@ -1,32 +1,46 @@ +# Secure MCP Shell Configuration +# This configuration uses the new secure execution model that prevents command injection security: + # Enable security features enabled: true - allowed_commands: - - ls - - cat - - grep - - find - - echo - - pwd - - whoami - - date - - curl - - wget - blocked_commands: - - rm -rf - - sudo - - chmod - - dd - - mkfs - - fdisk - blocked_patterns: - - 'rm\s+.*-rf.*' - - 'sudo\s+.*' - - 'chmod\s+(777|666)' - - '>/dev/' - - 'format\s+' - - 'del\s+.*\*' - max_execution_time: 30s - working_directory: /tmp/mcp-workspace - run_as_user: "" - max_output_size: 1048576 + + # Use secure execution (disable shell interpretation) + # This prevents command injection through shell metacharacters + use_shell_execution: false + + # Allowed executables - only these commands can be executed + # Use absolute paths for maximum security, or command names for PATH lookup + allowed_executables: + - "ls" + - "pwd" + - "whoami" + - "date" + - "echo" + - "cat" + - "grep" + - "find" + - "wc" + - "head" + - "tail" + - "sort" + - "uniq" + - "/usr/bin/git" + - "/usr/bin/python3" + - "/bin/bash" # Only allow if you trust the arguments + + # Legacy settings (deprecated but kept for backwards compatibility) + # These are ignored when use_shell_execution is false + allowed_commands: [] + blocked_commands: [] + blocked_patterns: [] + + # Execution limits + max_execution_time: "30s" + max_output_size: 1048576 # 1MB + + # Security context + working_directory: "/tmp" + run_as_user: "" # Leave empty to run as current user + + # Logging audit_log: true diff --git a/security_test.go b/security_test.go new file mode 100644 index 0000000..8853a9a --- /dev/null +++ b/security_test.go @@ -0,0 +1,421 @@ +package main + +import ( + "strings" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSecurityValidator_validateCommand(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + tests := []struct { + name string + config SecurityConfig + command string + expectError bool + errorContains string + }{ + { + name: "security disabled allows everything", + config: SecurityConfig{ + Enabled: false, + }, + command: "rm -rf /", + expectError: false, + }, + { + name: "secure mode with allowed executables - allows ls", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"ls", "pwd", "echo"}, + }, + command: "ls -la", + expectError: false, + }, + { + name: "secure mode with allowed executables - blocks rm", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"ls", "pwd", "echo"}, + }, + command: "rm -rf /", + expectError: true, + errorContains: "not in allowed list", + }, + { + name: "secure mode with no allowed executables - blocks everything", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: false, + }, + command: "echo hello", + expectError: true, + errorContains: "no allowed executables configured", + }, + { + name: "legacy mode with allowed commands - allows echo", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: true, + AllowedCommands: []string{"echo", "ls"}, + }, + command: "echo hello", + expectError: false, + }, + { + name: "legacy mode with allowed commands - blocks rm", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: true, + AllowedCommands: []string{"echo", "ls"}, + }, + command: "rm file", + expectError: true, + errorContains: "not in allowed list", + }, + { + name: "legacy mode with blocked commands - blocks rm", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: true, + BlockedCommands: []string{"rm", "chmod", "sudo"}, + }, + command: "rm file", + expectError: true, + errorContains: "blocked keyword", + }, + { + name: "legacy mode with blocked patterns - blocks rm -rf", + config: SecurityConfig{ + Enabled: true, + UseShellExecution: true, + BlockedPatterns: []string{"rm\\s+-rf", "sudo\\s+"}, + }, + command: "rm -rf /tmp", + expectError: true, + errorContains: "blocked pattern", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := newSecurityValidator(tt.config, logger) + err := validator.validateCommand(tt.command) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSecurityValidator_validateExecutableCommand(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + tests := []struct { + name string + allowedExecutables []string + command string + expectError bool + errorContains string + }{ + { + name: "simple command in allowlist", + allowedExecutables: []string{"ls", "pwd", "echo"}, + command: "ls -la", + expectError: false, + }, + { + name: "command not in allowlist", + allowedExecutables: []string{"ls", "pwd", "echo"}, + command: "rm file.txt", + expectError: true, + errorContains: "not in allowed list", + }, + { + name: "absolute path exact match", + allowedExecutables: []string{"/usr/bin/git", "/bin/ls"}, + command: "/usr/bin/git status", + expectError: false, + }, + { + name: "absolute path mismatch", + allowedExecutables: []string{"/usr/bin/git"}, + command: "/bin/git status", + expectError: true, + errorContains: "not in allowed list", + }, + { + name: "empty command", + allowedExecutables: []string{"ls"}, + command: "", + expectError: true, + errorContains: "empty command", + }, + { + name: "whitespace only command", + allowedExecutables: []string{"ls"}, + command: " ", + expectError: true, + errorContains: "empty command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := SecurityConfig{ + AllowedExecutables: tt.allowedExecutables, + } + validator := newSecurityValidator(config, logger) + err := validator.validateExecutableCommand(tt.command) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSecurityValidator_matchesExecutable(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + validator := newSecurityValidator(SecurityConfig{}, logger) + + tests := []struct { + name string + executable string + pattern string + expected bool + }{ + { + name: "exact match", + executable: "ls", + pattern: "ls", + expected: true, + }, + { + name: "no match", + executable: "ls", + pattern: "rm", + expected: false, + }, + { + name: "absolute path exact match", + executable: "/usr/bin/git", + pattern: "/usr/bin/git", + expected: true, + }, + { + name: "basename match for command in PATH", + executable: "git", + pattern: "git", + expected: true, // This should work if git is in PATH + }, + { + name: "absolute path vs basename no match", + executable: "/usr/bin/git", + pattern: "git", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := validator.matchesExecutable(tt.executable, tt.pattern) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSecurityValidator_validateLegacyCommand(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + tests := []struct { + name string + config SecurityConfig + command string + expectError bool + errorContains string + }{ + { + name: "no restrictions - allows everything", + config: SecurityConfig{ + AllowedCommands: []string{}, + BlockedCommands: []string{}, + BlockedPatterns: []string{}, + }, + command: "any command here", + expectError: false, + }, + { + name: "blocked command keyword", + config: SecurityConfig{ + BlockedCommands: []string{"rm", "chmod"}, + }, + command: "rm -rf /", + expectError: true, + errorContains: "blocked keyword", + }, + { + name: "blocked pattern match", + config: SecurityConfig{ + BlockedPatterns: []string{"rm\\s+-rf"}, + }, + command: "rm -rf /tmp", + expectError: true, + errorContains: "blocked pattern", + }, + { + name: "allowed command prefix match", + config: SecurityConfig{ + AllowedCommands: []string{"echo", "ls -"}, + }, + command: "echo hello world", + expectError: false, + }, + { + name: "command not in allowed list", + config: SecurityConfig{ + AllowedCommands: []string{"echo", "ls"}, + }, + command: "rm file", + expectError: true, + errorContains: "not in allowed list", + }, + { + name: "complex injection attempt blocked by keyword", + config: SecurityConfig{ + BlockedCommands: []string{"chmod"}, + }, + command: "chmod 777 /etc/passwd", + expectError: true, + errorContains: "blocked keyword", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := newSecurityValidator(tt.config, logger) + err := validator.validateLegacyCommand(tt.command) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSecurityValidator_vulnerability_scenarios(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + // Test scenarios based on the VULN.md report + vulnerabilityPayloads := []struct { + name string + command string + description string + }{ + { + name: "VULN.md example", + command: "echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))", + description: "Obfuscated chmod reconstruction", + }, + { + name: "simple command injection", + command: "ls; rm -rf /", + description: "Command separator injection", + }, + { + name: "pipe injection", + command: "echo safe | rm dangerous", + description: "Pipe-based command injection", + }, + { + name: "background injection", + command: "echo safe & rm dangerous", + description: "Background process injection", + }, + } + + t.Run("secure_mode_blocks_all_vulnerabilities", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: false, + AllowedExecutables: []string{"echo", "ls"}, // Only safe commands + } + validator := newSecurityValidator(config, logger) + + for _, payload := range vulnerabilityPayloads { + t.Run(payload.name, func(t *testing.T) { + err := validator.validateCommand(payload.command) + if err != nil { + assert.Error(t, err, "Secure mode should block: %s", payload.description) + // Check for either error message since they both indicate blocking + errorMsg := err.Error() + shouldContainOne := strings.Contains(errorMsg, "not in allowed list") || + strings.Contains(errorMsg, "shell metacharacters") || + strings.Contains(errorMsg, "dangerous shell constructs") + assert.True(t, shouldContainOne, "Error should indicate blocking: %s", errorMsg) + } else { + t.Errorf("Secure mode should block: %s", payload.description) + } + }) + } + }) + + t.Run("legacy_mode_with_proper_blocks", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + BlockedCommands: []string{"rm", "chmod", "chown", "sudo"}, + BlockedPatterns: []string{"rm\\s+-rf", "chmod\\s+"}, + } + validator := newSecurityValidator(config, logger) + + // The VULN.md example demonstrates the vulnerability - obfuscated commands bypass keyword matching + err := validator.validateCommand("echo $($(echo -n c; echo -n h; echo -n m; echo -n o; echo -n d))") + // This should pass because "chmod" doesn't appear literally in the command + assert.NoError(t, err, "Legacy mode cannot detect obfuscated commands") + + // But a simple rm should be blocked + err = validator.validateCommand("rm file") + assert.Error(t, err) + assert.Contains(t, err.Error(), "blocked keyword") + }) + + t.Run("legacy_mode_vulnerable_without_proper_blocks", func(t *testing.T) { + config := SecurityConfig{ + Enabled: true, + UseShellExecution: true, + // No blocks configured - vulnerable + } + validator := newSecurityValidator(config, logger) + + // All payloads would pass validation (but still be dangerous) + for _, payload := range vulnerabilityPayloads { + t.Run(payload.name, func(t *testing.T) { + err := validator.validateCommand(payload.command) + assert.NoError(t, err, "Legacy mode without blocks allows: %s", payload.description) + }) + } + }) +}