Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ coverage.xml
# Virtual environments
.env
.venv
.venv-*/
env/
venv/
ENV/
Expand Down
38 changes: 32 additions & 6 deletions go/cmd/prompd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,19 +383,45 @@ func handleGitCommit() {
break
}
}

if message == "" {
fmt.Println("Error: git commit requires -m <message>")
os.Exit(1)
}


// SECURITY: Sanitize commit message
if err := validateGitMessage(message); err != nil {
fmt.Printf("Error: invalid commit message: %v\n", err)
os.Exit(1)
}

cmd := exec.Command("git", "commit", "-m", message)
if err := cmd.Run(); err != nil {
fmt.Printf("Error committing: %v\n", err)
os.Exit(1)
}

fmt.Printf("✓ Committed with message: %s\n", message)

fmt.Printf("Committed with message: %s\n", message)
}

// validateGitMessage sanitizes git commit messages to prevent injection
func validateGitMessage(msg string) error {
const maxMessageLength = 5000
if len(msg) > maxMessageLength {
return fmt.Errorf("message too long (%d chars, max %d)", len(msg), maxMessageLength)
}

// Reject null bytes and control characters (except newline, tab, carriage return)
for i, c := range msg {
if c == 0 {
return fmt.Errorf("message contains null byte at position %d", i)
}
if c < 32 && c != '\n' && c != '\r' && c != '\t' {
return fmt.Errorf("message contains control character (0x%02x) at position %d", c, i)
}
}

return nil
}


Expand Down Expand Up @@ -864,8 +890,8 @@ func saveConfig(config *Config) error {
return fmt.Errorf("failed to marshal config: %w", err)
}

// Write to file
if err := os.WriteFile(path, data, 0644); err != nil {
// SECURITY: Write config with restrictive permissions (owner read/write only)
if err := os.WriteFile(path, data, 0600); err != nil {
continue // Try next path
}

Expand Down
2 changes: 1 addition & 1 deletion go/cmd/prompd/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func handleCreate() {
if paramName == "" {
break
}
paramType := promptWithDefault("Parameter type [string/integer/float/boolean]", "string")
paramType := promptWithDefault("Parameter type [string/number/integer/float/boolean/array/object/json/file/base64]", "string")
paramDesc := promptWithDefault("Parameter description", "")
paramRequired := promptYesNo("Required?", false)

Expand Down
116 changes: 110 additions & 6 deletions go/cmd/prompd/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,28 @@ func createPackage(sourceDir, outputPath string, manifest PackageManifest, exclu
return err
}

// SECURITY: Track total size to prevent oversized packages
const maxTotalSize int64 = 200 * 1024 * 1024 // 200MB total package limit
const maxFileSize int64 = 50 * 1024 * 1024 // 50MB per file limit
var totalSize int64

// Walk source directory and add files
return filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

// SECURITY: Reject symlinks to prevent including files from outside the package directory
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("security violation: symlinks not allowed in packages: %s", path)
}

// SECURITY: Double-check with Lstat to catch symlinks that Walk may follow
lstatInfo, lstatErr := os.Lstat(path)
if lstatErr == nil && lstatInfo.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("security violation: symlink detected: %s", path)
}

// Get relative path
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
Expand All @@ -445,6 +461,11 @@ func createPackage(sourceDir, outputPath string, manifest PackageManifest, exclu
return nil
}

// SECURITY: Validate path is safe (no traversal, null bytes, etc.)
if pathErr := isSecurePath(relPath, sourceDir); pathErr != nil {
return fmt.Errorf("security violation in path %s: %v", relPath, pathErr)
}

// Check exclusions
if shouldExclude(relPath, info, exclusions) {
if info.IsDir() {
Expand All @@ -458,9 +479,20 @@ func createPackage(sourceDir, outputPath string, manifest PackageManifest, exclu
return nil
}

// SECURITY: Enforce per-file size limit
if info.Size() > maxFileSize {
return fmt.Errorf("file too large: %s (%d bytes, max %d bytes)", relPath, info.Size(), maxFileSize)
}

// SECURITY: Enforce total package size limit
totalSize += info.Size()
if totalSize > maxTotalSize {
return fmt.Errorf("total package size exceeds limit (%d bytes max)", maxTotalSize)
}

// Add file to zip
zipPath := filepath.ToSlash(relPath) // Ensure forward slashes in zip

zipFileWriter, err := zipWriter.Create(zipPath)
if err != nil {
return err
Expand All @@ -472,7 +504,8 @@ func createPackage(sourceDir, outputPath string, manifest PackageManifest, exclu
}
defer fileReader.Close()

_, err = io.Copy(zipFileWriter, fileReader)
// SECURITY: Use LimitReader to enforce file size limit during copy
_, err = io.Copy(zipFileWriter, io.LimitReader(fileReader, maxFileSize+1))
return err
})
}
Expand Down Expand Up @@ -507,39 +540,97 @@ func shouldExclude(relPath string, info os.FileInfo, exclusions PDProjExclusions


func validatePdpkgFile(filePath string) error {
// SECURITY: Check package file size before opening
stat, err := os.Stat(filePath)
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
}
const maxPackageSize int64 = 200 * 1024 * 1024 // 200MB
if stat.Size() > maxPackageSize {
return fmt.Errorf("package file too large: %d bytes (max %d bytes)", stat.Size(), maxPackageSize)
}

// Open ZIP file
zipReader, err := zip.OpenReader(filePath)
if err != nil {
return fmt.Errorf("failed to open ZIP file: %v", err)
}
defer zipReader.Close()

// SECURITY: Check for ZIP slip/directory traversal attacks
// SECURITY: Check for ZIP slip/directory traversal, symlinks, and decompression bombs
const maxDecompressedSize uint64 = 500 * 1024 * 1024 // 500MB total decompressed limit
const maxCompressionRatio uint64 = 100 // 100:1 max ratio
const maxFileCount = 1000
var totalDecompressedSize uint64

if len(zipReader.File) > maxFileCount {
return fmt.Errorf("too many files in package: %d (max %d)", len(zipReader.File), maxFileCount)
}

for _, file := range zipReader.File {
// SECURITY: Check for null bytes in raw name before cleaning
if strings.Contains(file.Name, "\x00") {
return fmt.Errorf("security violation: null byte in file name: %s", file.Name)
}

// Normalize path and check for traversal
cleanPath := filepath.Clean(file.Name)
if strings.Contains(cleanPath, "..") || filepath.IsAbs(file.Name) {
if strings.Contains(cleanPath, "..") || filepath.IsAbs(file.Name) || filepath.IsAbs(cleanPath) {
return fmt.Errorf("security violation: path traversal detected in %s", file.Name)
}

// SECURITY: Check for backslash-based traversal on all platforms
if strings.Contains(file.Name, "\\..") || strings.Contains(file.Name, "..\\") {
return fmt.Errorf("security violation: path traversal detected in %s", file.Name)
}

// SECURITY: Detect symlinks in ZIP entries
if file.FileInfo().Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("security violation: symlink detected in package: %s", file.Name)
}

// SECURITY: Track cumulative decompressed size for bomb detection
totalDecompressedSize += file.UncompressedSize64
if totalDecompressedSize > maxDecompressedSize {
return fmt.Errorf("security violation: total decompressed size exceeds limit (%d bytes max)", maxDecompressedSize)
}

// SECURITY: Check individual file compression ratio
if file.CompressedSize64 > 0 {
ratio := file.UncompressedSize64 / file.CompressedSize64
if ratio > maxCompressionRatio {
return fmt.Errorf("security violation: suspicious compression ratio %d:1 in %s (max %d:1)", ratio, file.Name, maxCompressionRatio)
}
}
}

// Check for manifest.json
var manifestFound bool
for _, file := range zipReader.File {
if file.Name == "manifest.json" {
manifestFound = true


// SECURITY: Enforce manifest size limit
const maxManifestSize uint64 = 1024 * 1024 // 1MB
if file.UncompressedSize64 > maxManifestSize {
return fmt.Errorf("manifest.json too large: %d bytes (max %d bytes)", file.UncompressedSize64, maxManifestSize)
}

// Read and validate manifest
reader, err := file.Open()
if err != nil {
return fmt.Errorf("failed to read manifest.json: %v", err)
}
defer reader.Close()

content, err := io.ReadAll(reader)
// SECURITY: Use LimitReader to enforce size during read
content, err := io.ReadAll(io.LimitReader(reader, int64(maxManifestSize)+1))
if err != nil {
return fmt.Errorf("failed to read manifest content: %v", err)
}
if uint64(len(content)) > maxManifestSize {
return fmt.Errorf("manifest.json content exceeds size limit")
}

var manifest PackageManifest
if err := json.Unmarshal(content, &manifest); err != nil {
Expand All @@ -557,6 +648,19 @@ func validatePdpkgFile(filePath string) error {
return fmt.Errorf("missing 'description' in manifest.json")
}

// SECURITY: Validate manifest type field
if manifest.Type != "" && manifest.Type != "package" {
return fmt.Errorf("invalid manifest type: %s (expected 'package')", manifest.Type)
}

// SECURITY: Validate manifest field formats
if err := validatePackageName(manifest.Name); err != nil {
return fmt.Errorf("invalid package name in manifest: %v", err)
}
if err := validateVersion(manifest.Version); err != nil {
return fmt.Errorf("invalid version in manifest: %v", err)
}

break
}
}
Expand Down
25 changes: 15 additions & 10 deletions go/cmd/prompd/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,15 @@ func validateFile(filename string) error {
allParams := append(prompd.Metadata.Parameters, prompd.Metadata.Variables...)
validTypes := map[string]bool{
"string": true,
"number": true,
"integer": true,
"float": true,
"boolean": true,
"array": true,
"object": true,
"json": true,
"file": true,
"base64": true,
}

for _, param := range allParams {
Expand All @@ -132,7 +135,7 @@ func validateFile(filename string) error {

// Validate parameter type
if param.Type != "" && !validTypes[param.Type] {
return fmt.Errorf("invalid parameter type '%s' for parameter '%s'. Must be one of: string, integer, float, boolean, array, object, file", param.Type, param.Name)
return fmt.Errorf("invalid parameter type '%s' for parameter '%s'. Must be one of: string, number, integer, float, boolean, array, object, json, file, base64", param.Type, param.Name)
}

// Validate pattern if present (for string types)
Expand All @@ -148,7 +151,7 @@ func validateFile(filename string) error {

// Validate min/max constraints (for numeric types)
if param.Min != nil || param.Max != nil {
if param.Type != "" && param.Type != "integer" && param.Type != "float" {
if param.Type != "" && param.Type != "integer" && param.Type != "float" && param.Type != "number" {
return fmt.Errorf("min/max constraints are only valid for numeric types, but '%s' has type '%s'", param.Name, param.Type)
}
if param.Min != nil && param.Max != nil && *param.Min > *param.Max {
Expand Down Expand Up @@ -182,10 +185,17 @@ func validateFile(filename string) error {

func validateDefaultType(paramName, paramType string, defaultValue interface{}) error {
switch paramType {
case "string":
case "string", "file", "base64":
if _, ok := defaultValue.(string); !ok {
return fmt.Errorf("default value for parameter '%s' must be a string", paramName)
}
case "number", "float":
switch defaultValue.(type) {
case float32, float64, int, int32, int64:
// Valid numeric types
default:
return fmt.Errorf("default value for parameter '%s' must be a number", paramName)
}
case "integer":
switch v := defaultValue.(type) {
case int, int32, int64:
Expand All @@ -198,13 +208,6 @@ func validateDefaultType(paramName, paramType string, defaultValue interface{})
default:
return fmt.Errorf("default value for parameter '%s' must be an integer", paramName)
}
case "float":
switch defaultValue.(type) {
case float32, float64, int, int32, int64:
// Valid numeric types
default:
return fmt.Errorf("default value for parameter '%s' must be a float", paramName)
}
case "boolean":
if _, ok := defaultValue.(bool); !ok {
return fmt.Errorf("default value for parameter '%s' must be a boolean", paramName)
Expand All @@ -220,6 +223,8 @@ func validateDefaultType(paramName, paramType string, defaultValue interface{})
if _, ok := defaultValue.(map[string]interface{}); !ok {
return fmt.Errorf("default value for parameter '%s' must be an object", paramName)
}
case "json":
// Any non-nil value is acceptable as a default for json type
}
return nil
}
Expand Down
17 changes: 12 additions & 5 deletions go/cmd/prompd/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@ type SecretMatch struct {

// secretPatterns defines patterns for detecting various types of secrets
var secretPatterns = map[string]*regexp.Regexp{
"OpenAI API Key": regexp.MustCompile(`sk-[a-zA-Z0-9]{48}`),
"Anthropic API Key": regexp.MustCompile(`sk-ant-api[0-9]{2}-[a-zA-Z0-9_-]{95}`),
"OpenAI API Key": regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`),
"Anthropic API Key": regexp.MustCompile(`sk-ant-[a-zA-Z0-9_-]{20,}`),
"AWS Access Key": regexp.MustCompile(`AKIA[0-9A-Z]{16}`),
"AWS Secret Key": regexp.MustCompile(`(?i)aws[_-]?secret[_-]?access[_-]?key[=:\s]+['"]?[a-zA-Z0-9/+=]{40}['"]?`),
"GitHub Token": regexp.MustCompile(`gh[ps]_[a-zA-Z0-9]{36}`),
"GitHub Fine-Grained": regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{22,}`),
"Prompd Registry Token": regexp.MustCompile(`prompd_[a-zA-Z0-9]{32,}`),
"Private Key": regexp.MustCompile(`-----BEGIN (?:RSA |EC |DSA )?PRIVATE KEY-----`),
"Generic API Key": regexp.MustCompile(`(?i)api[_-]?key[_-]?[=:]\s*['"]?([a-zA-Z0-9_\-]{32,})['"]?`),
"Bearer Token": regexp.MustCompile(`[Bb]earer\s+[a-zA-Z0-9_\-\.]{32,}`),
"Private Key": regexp.MustCompile(`-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`),
"Generic API Key": regexp.MustCompile(`(?i)(?:api[_-]?key|apikey|api_secret|apisecret)[_-]?[=:]\s*['"]?([a-zA-Z0-9_\-]{20,})['"]?`),
"Generic Secret": regexp.MustCompile(`(?i)(?:secret|password|passwd|token)[_-]?[=:]\s*['"]?([a-zA-Z0-9_\-!@#$%^&*]{16,})['"]?`),
"Bearer Token": regexp.MustCompile(`[Bb]earer\s+[a-zA-Z0-9_\-.]{32,256}`),
"JWT Token": regexp.MustCompile(`eyJ[a-zA-Z0-9_-]{10,}\.eyJ[a-zA-Z0-9_-]{10,}\.[a-zA-Z0-9_-]{10,}`),
"URL-Embedded Creds": regexp.MustCompile(`https?://[^:\s]+:[^@\s]+@[a-zA-Z0-9.-]+`),
"Slack Token": regexp.MustCompile(`xox[bpors]-[a-zA-Z0-9-]{10,}`),
"Google API Key": regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`),
"Stripe Key": regexp.MustCompile(`(?:sk|pk)_(?:test|live)_[a-zA-Z0-9]{20,}`),
}

// detectSecretsInContent scans content string for secrets
Expand Down
Loading
Loading