diff --git a/internal/truncate/fallback.go b/internal/truncate/fallback.go new file mode 100644 index 00000000..723d2de3 --- /dev/null +++ b/internal/truncate/fallback.go @@ -0,0 +1,41 @@ +package truncate + +import ( + "strings" +) + +// FallbackLanguage provides line-based truncation for unknown or unsupported languages. +// It treats the entire file as a single block and provides no import detection. +// This ensures graceful degradation when language-specific parsing is unavailable. +type FallbackLanguage struct{} + +// DetectImportEnd returns 0 for fallback as there's no language-specific import detection. +// The fallback treats the entire file uniformly without distinguishing imports. +func (f FallbackLanguage) DetectImportEnd(lines []string) int { + return 0 +} + +// DetectBlocks returns the entire content as a single block. +// Since we don't understand the language structure, we treat everything as one unit. +// This allows line-based truncation while maintaining the Block interface contract. +func (f FallbackLanguage) DetectBlocks(content string) []Block { + if content == "" { + return []Block{} + } + + lines := strings.Split(content, "\n") + return []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: len(lines) - 1, + }, + } +} + +// CommentSyntax returns no comment syntax for fallback. +// Truncation indicators will use plaintext format without comment markers. +func (f FallbackLanguage) CommentSyntax() (single string, multiOpen string, multiClose string) { + return "", "", "" +} diff --git a/internal/truncate/fallback_test.go b/internal/truncate/fallback_test.go new file mode 100644 index 00000000..b1e7167d --- /dev/null +++ b/internal/truncate/fallback_test.go @@ -0,0 +1,143 @@ +package truncate + +import ( + "testing" +) + +func TestFallbackLanguage_DetectImportEnd(t *testing.T) { + fb := FallbackLanguage{} + + tests := []struct { + name string + lines []string + want int + }{ + { + name: "empty file", + lines: []string{}, + want: 0, + }, + { + name: "single line", + lines: []string{"line1"}, + want: 0, + }, + { + name: "multiple lines", + lines: []string{ + "import something", + "", + "code here", + }, + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := fb.DetectImportEnd(tt.lines) + if got != tt.want { + t.Errorf("DetectImportEnd() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestFallbackLanguage_DetectBlocks(t *testing.T) { + fb := FallbackLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "empty content", + content: "", + want: []Block{}, + }, + { + name: "single line", + content: "single line", + want: []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: 0, + }, + }, + }, + { + name: "multiple lines", + content: "line1\nline2\nline3", + want: []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "content with trailing newline", + content: "line1\nline2\n", + want: []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: 2, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := fb.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestFallbackLanguage_CommentSyntax(t *testing.T) { + fb := FallbackLanguage{} + + single, multiOpen, multiClose := fb.CommentSyntax() + + if single != "" { + t.Errorf("CommentSyntax() single = %q, want empty string", single) + } + if multiOpen != "" { + t.Errorf("CommentSyntax() multiOpen = %q, want empty string", multiOpen) + } + if multiClose != "" { + t.Errorf("CommentSyntax() multiClose = %q, want empty string", multiClose) + } +} + +func TestFallbackLanguage_ImplementsInterface(t *testing.T) { + // Compile-time check that FallbackLanguage implements Language + var _ Language = FallbackLanguage{} +} diff --git a/internal/truncate/lang_go.go b/internal/truncate/lang_go.go new file mode 100644 index 00000000..2a41ca89 --- /dev/null +++ b/internal/truncate/lang_go.go @@ -0,0 +1,308 @@ +package truncate + +import ( + "strings" +) + +// GoLanguage implements the Language interface for Go source files. +type GoLanguage struct{} + +func init() { + RegisterLanguage("go", GoLanguage{}) +} + +// CommentSyntax returns Go's comment syntax: // for single-line, /* */ for multi-line. +func (g GoLanguage) CommentSyntax() (single string, multiOpen string, multiClose string) { + return "//", "/*", "*/" +} + +// DetectImportEnd returns the line index where imports end. +// Returns the first non-import, non-comment, non-blank line after the import section. +// Returns 0 if there are no imports. +func (g GoLanguage) DetectImportEnd(lines []string) int { + inImportBlock := false + sawImport := false + lastImportLine := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Check for import statement + if strings.HasPrefix(trimmed, "import") { + sawImport = true + lastImportLine = i + // Check if it's a grouped import block + if strings.Contains(trimmed, "(") { + inImportBlock = true + } + continue + } + + // Check if we're in an import block (lines between import ( and )) + if inImportBlock { + lastImportLine = i + if strings.Contains(trimmed, ")") { + inImportBlock = false + } + continue + } + + // If we've seen imports, skip blank lines and comments after them + if sawImport { + if trimmed == "" || strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "/*") { + continue + } + // Hit first non-blank, non-comment, non-import line after imports + return i + } + } + + // If we only saw imports until EOF, return the line after the last import + if sawImport { + return lastImportLine + 1 + } + + return 0 +} + +// DetectBlocks returns all function, method, and type boundaries in the content. +// Uses bracket counting after detecting func/type keywords. +func (g GoLanguage) DetectBlocks(content string) []Block { + if content == "" { + return []Block{} + } + + // Strip strings and comments for accurate bracket counting + single, multiOpen, multiClose := g.CommentSyntax() + stripper := NewStripper(single, multiOpen, multiClose) + stripped, err := stripper.Strip(content) + if err != nil { + // On error, fall back to treating entire content as one block + lines := strings.Split(content, "\n") + return []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: len(lines) - 1, + }, + } + } + + lines := strings.Split(content, "\n") + strippedLines := strings.Split(stripped, "\n") + + var blocks []Block + + for i := 0; i < len(strippedLines); i++ { + trimmed := strings.TrimSpace(strippedLines[i]) + + // Detect function declarations + if strings.HasPrefix(trimmed, "func ") { + block := g.detectFunctionBlock(lines, strippedLines, i) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine // Skip to end of this block + } + continue + } + + // Detect type declarations + if strings.HasPrefix(trimmed, "type ") { + block := g.detectTypeBlock(lines, strippedLines, i) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine // Skip to end of this block + } + continue + } + } + + return blocks +} + +// detectFunctionBlock detects a function or method block starting at the given line. +// Handles both functions and methods with receivers: func (r *Type) Method() +func (g GoLanguage) detectFunctionBlock(lines []string, strippedLines []string, startLine int) *Block { + if startLine >= len(strippedLines) { + return nil + } + + // Extract function name + line := strings.TrimSpace(strippedLines[startLine]) + name := g.extractFunctionName(line) + + // Find the opening brace + braceStart := -1 + for i := startLine; i < len(strippedLines); i++ { + if strings.Contains(strippedLines[i], "{") { + braceStart = i + break + } + // If we hit a line with just a semicolon or another func, it's a declaration without body + trimmed := strings.TrimSpace(strippedLines[i]) + if strings.HasSuffix(trimmed, ";") || (i > startLine && strings.HasPrefix(trimmed, "func ")) { + return nil + } + } + + if braceStart == -1 { + // No opening brace found - might be an interface method or forward declaration + return nil + } + + // Track brace depth to find the closing brace + depth := 0 + endLine := braceStart + + for i := braceStart; i < len(strippedLines); i++ { + for _, ch := range strippedLines[i] { + if ch == '{' { + depth++ + } else if ch == '}' { + depth-- + if depth == 0 { + endLine = i + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: endLine, + } + } + } + } + } + + // Unclosed brace - treat as extending to end of file + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: len(strippedLines) - 1, + } +} + +// detectTypeBlock detects a type declaration starting at the given line. +// Handles struct types with braces. +func (g GoLanguage) detectTypeBlock(lines []string, strippedLines []string, startLine int) *Block { + if startLine >= len(strippedLines) { + return nil + } + + // Extract type name + line := strings.TrimSpace(strippedLines[startLine]) + name := g.extractTypeName(line) + + // Check if this is a struct with braces + if !strings.Contains(line, "struct") { + // Simple type alias - single line + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: startLine, + } + } + + // Find the opening brace for struct + braceStart := -1 + for i := startLine; i < len(strippedLines); i++ { + if strings.Contains(strippedLines[i], "{") { + braceStart = i + break + } + // If we hit another type or func, stop + trimmed := strings.TrimSpace(strippedLines[i]) + if i > startLine && (strings.HasPrefix(trimmed, "type ") || strings.HasPrefix(trimmed, "func ")) { + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: i - 1, + } + } + } + + if braceStart == -1 { + // No opening brace found + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: startLine, + } + } + + // Track brace depth to find the closing brace + depth := 0 + endLine := braceStart + + for i := braceStart; i < len(strippedLines); i++ { + for _, ch := range strippedLines[i] { + if ch == '{' { + depth++ + } else if ch == '}' { + depth-- + if depth == 0 { + endLine = i + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: endLine, + } + } + } + } + } + + // Unclosed brace - treat as extending to end of file + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: len(strippedLines) - 1, + } +} + +// extractFunctionName extracts the function or method name from a func declaration line. +// Handles: func Name(), func (r Receiver) Name(), func (r *Receiver) Name() +func (g GoLanguage) extractFunctionName(line string) string { + // Remove "func " prefix + line = strings.TrimPrefix(strings.TrimSpace(line), "func ") + + // Check for method receiver: (r Receiver) or (r *Receiver) + if strings.HasPrefix(line, "(") { + // Find the closing parenthesis of the receiver + closeIdx := strings.Index(line, ")") + if closeIdx > 0 && closeIdx < len(line)-1 { + line = line[closeIdx+1:] + } + } + + // Now extract the function name (everything before the opening parenthesis) + line = strings.TrimSpace(line) + parenIdx := strings.Index(line, "(") + if parenIdx > 0 { + return strings.TrimSpace(line[:parenIdx]) + } + + // No parameters found - might be malformed, return what we have + return strings.Fields(line)[0] +} + +// extractTypeName extracts the type name from a type declaration line. +// Handles: type Name struct, type Name interface, type Name = OtherType +func (g GoLanguage) extractTypeName(line string) string { + // Remove "type " prefix + line = strings.TrimPrefix(strings.TrimSpace(line), "type ") + + // Extract the first word (the type name) + fields := strings.Fields(line) + if len(fields) > 0 { + return fields[0] + } + + return "" +} diff --git a/internal/truncate/lang_go_test.go b/internal/truncate/lang_go_test.go new file mode 100644 index 00000000..c8d9bdfb --- /dev/null +++ b/internal/truncate/lang_go_test.go @@ -0,0 +1,850 @@ +package truncate + +import ( + "testing" +) + +func TestGoLanguage_CommentSyntax(t *testing.T) { + g := GoLanguage{} + + single, multiOpen, multiClose := g.CommentSyntax() + + if single != "//" { + t.Errorf("CommentSyntax() single = %q, want %q", single, "//") + } + if multiOpen != "/*" { + t.Errorf("CommentSyntax() multiOpen = %q, want %q", multiOpen, "/*") + } + if multiClose != "*/" { + t.Errorf("CommentSyntax() multiClose = %q, want %q", multiClose, "*/") + } +} + +func TestGoLanguage_DetectImportEnd(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + lines []string + want int + }{ + { + name: "no imports", + lines: []string{"package main", "", "func main() {}"}, + want: 0, + }, + { + name: "single import", + lines: []string{ + "package main", + "", + "import \"fmt\"", + "", + "func main() {}", + }, + want: 4, + }, + { + name: "multiple single imports", + lines: []string{ + "package main", + "", + "import \"fmt\"", + "import \"os\"", + "import \"strings\"", + "", + "func main() {}", + }, + want: 6, + }, + { + name: "grouped import block", + lines: []string{ + "package main", + "", + "import (", + "\t\"fmt\"", + "\t\"os\"", + "\t\"strings\"", + ")", + "", + "func main() {}", + }, + want: 8, + }, + { + name: "grouped import with comments", + lines: []string{ + "package main", + "", + "import (", + "\t\"fmt\" // for printing", + "\t\"os\"", + "\t// strings package", + "\t\"strings\"", + ")", + "", + "func main() {}", + }, + want: 9, + }, + { + name: "import followed by const", + lines: []string{ + "package main", + "", + "import \"fmt\"", + "", + "const Version = \"1.0\"", + "", + "func main() {}", + }, + want: 4, + }, + { + name: "import followed by var", + lines: []string{ + "package main", + "", + "import \"fmt\"", + "", + "var logger = fmt.Println", + }, + want: 4, + }, + { + name: "import followed by type", + lines: []string{ + "package main", + "", + "import \"fmt\"", + "", + "type MyType struct {}", + }, + want: 4, + }, + { + name: "import at EOF", + lines: []string{ + "package main", + "", + "import \"fmt\"", + }, + want: 3, + }, + { + name: "grouped import at EOF", + lines: []string{ + "package main", + "", + "import (", + "\t\"fmt\"", + ")", + }, + want: 5, + }, + { + name: "imports with blank lines", + lines: []string{ + "package main", + "", + "import (", + "\t\"fmt\"", + "", + "\t\"os\"", + ")", + "", + "func main() {}", + }, + want: 8, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.DetectImportEnd(tt.lines) + if got != tt.want { + t.Errorf("DetectImportEnd() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestGoLanguage_DetectBlocks_Functions(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "empty content", + content: "", + want: []Block{}, + }, + { + name: "single function", + content: `package main + +func main() { + fmt.Println("hello") +}`, + want: []Block{ + { + Type: "function", + Name: "main", + StartLine: 2, + EndLine: 4, + }, + }, + }, + { + name: "multiple functions", + content: `package main + +func first() { + return +} + +func second() { + return +} + +func third() { + return +}`, + want: []Block{ + { + Type: "function", + Name: "first", + StartLine: 2, + EndLine: 4, + }, + { + Type: "function", + Name: "second", + StartLine: 6, + EndLine: 8, + }, + { + Type: "function", + Name: "third", + StartLine: 10, + EndLine: 12, + }, + }, + }, + { + name: "function with parameters", + content: `package main + +func add(a int, b int) int { + return a + b +}`, + want: []Block{ + { + Type: "function", + Name: "add", + StartLine: 2, + EndLine: 4, + }, + }, + }, + { + name: "function with return type", + content: `package main + +func getName() string { + return "test" +}`, + want: []Block{ + { + Type: "function", + Name: "getName", + StartLine: 2, + EndLine: 4, + }, + }, + }, + { + name: "nested braces in function", + content: `package main + +func process() { + if true { + for i := 0; i < 10; i++ { + fmt.Println(i) + } + } +}`, + want: []Block{ + { + Type: "function", + Name: "process", + StartLine: 2, + EndLine: 8, + }, + }, + }, + { + name: "function with string containing braces", + content: `package main + +func template() { + s := "text { with braces }" + fmt.Println(s) +}`, + want: []Block{ + { + Type: "function", + Name: "template", + StartLine: 2, + EndLine: 5, + }, + }, + }, + { + name: "function with comment containing braces", + content: `package main + +func example() { + // comment { with braces } + x := 42 +}`, + want: []Block{ + { + Type: "function", + Name: "example", + StartLine: 2, + EndLine: 5, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestGoLanguage_DetectBlocks_Methods(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "method with value receiver", + content: `package main + +func (s Server) Start() { + fmt.Println("starting") +}`, + want: []Block{ + { + Type: "function", + Name: "Start", + StartLine: 2, + EndLine: 4, + }, + }, + }, + { + name: "method with pointer receiver", + content: `package main + +func (s *Server) Stop() { + fmt.Println("stopping") +}`, + want: []Block{ + { + Type: "function", + Name: "Stop", + StartLine: 2, + EndLine: 4, + }, + }, + }, + { + name: "multiple methods on same type", + content: `package main + +func (s *Server) Start() { + return +} + +func (s *Server) Stop() { + return +} + +func (s *Server) Restart() { + s.Stop() + s.Start() +}`, + want: []Block{ + { + Type: "function", + Name: "Start", + StartLine: 2, + EndLine: 4, + }, + { + Type: "function", + Name: "Stop", + StartLine: 6, + EndLine: 8, + }, + { + Type: "function", + Name: "Restart", + StartLine: 10, + EndLine: 13, + }, + }, + }, + { + name: "methods and functions mixed", + content: `package main + +func NewServer() *Server { + return &Server{} +} + +func (s *Server) Start() { + fmt.Println("starting") +} + +func main() { + s := NewServer() + s.Start() +}`, + want: []Block{ + { + Type: "function", + Name: "NewServer", + StartLine: 2, + EndLine: 4, + }, + { + Type: "function", + Name: "Start", + StartLine: 6, + EndLine: 8, + }, + { + Type: "function", + Name: "main", + StartLine: 10, + EndLine: 13, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestGoLanguage_DetectBlocks_Types(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple type alias", + content: `package main + +type MyInt int`, + want: []Block{ + { + Type: "type", + Name: "MyInt", + StartLine: 2, + EndLine: 2, + }, + }, + }, + { + name: "struct type", + content: `package main + +type Server struct { + Port int + Host string +}`, + want: []Block{ + { + Type: "type", + Name: "Server", + StartLine: 2, + EndLine: 5, + }, + }, + }, + { + name: "interface type", + content: `package main + +type Handler interface { + Handle() error + Stop() +}`, + want: []Block{ + { + Type: "type", + Name: "Handler", + StartLine: 2, + EndLine: 2, + }, + }, + }, + { + name: "multiple types and functions", + content: `package main + +type Config struct { + Port int +} + +func NewConfig() *Config { + return &Config{Port: 8080} +} + +type Server struct { + config *Config +} + +func (s *Server) Start() { + fmt.Printf("Starting on port %d\n", s.config.Port) +}`, + want: []Block{ + { + Type: "type", + Name: "Config", + StartLine: 2, + EndLine: 4, + }, + { + Type: "function", + Name: "NewConfig", + StartLine: 6, + EndLine: 8, + }, + { + Type: "type", + Name: "Server", + StartLine: 10, + EndLine: 12, + }, + { + Type: "function", + Name: "Start", + StartLine: 14, + EndLine: 16, + }, + }, + }, + { + name: "nested struct", + content: `package main + +type Outer struct { + Inner struct { + Value int + } + Name string +}`, + want: []Block{ + { + Type: "type", + Name: "Outer", + StartLine: 2, + EndLine: 7, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestGoLanguage_DetectBlocks_RealWorldCode(t *testing.T) { + g := GoLanguage{} + + content := `package main + +import ( + "fmt" + "os" +) + +const Version = "1.0.0" + +type Config struct { + Host string + Port int +} + +func NewConfig() *Config { + return &Config{ + Host: "localhost", + Port: 8080, + } +} + +type Server struct { + config *Config +} + +func (s *Server) Start() error { + addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port) + fmt.Printf("Server starting on %s\n", addr) + return nil +} + +func (s *Server) Stop() error { + fmt.Println("Server stopping") + return nil +} + +func main() { + config := NewConfig() + server := &Server{config: config} + if err := server.Start(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +}` + + want := []Block{ + { + Type: "type", + Name: "Config", + StartLine: 9, + EndLine: 12, + }, + { + Type: "function", + Name: "NewConfig", + StartLine: 14, + EndLine: 19, + }, + { + Type: "type", + Name: "Server", + StartLine: 21, + EndLine: 23, + }, + { + Type: "function", + Name: "Start", + StartLine: 25, + EndLine: 29, + }, + { + Type: "function", + Name: "Stop", + StartLine: 31, + EndLine: 34, + }, + { + Type: "function", + Name: "main", + StartLine: 36, + EndLine: 43, + }, + } + + got := g.DetectBlocks(content) + + if len(got) != len(want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, want[i].Type) + } + if got[i].Name != want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, want[i].Name) + } + if got[i].StartLine != want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, want[i].StartLine) + } + if got[i].EndLine != want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, want[i].EndLine) + } + } +} + +func TestGoLanguage_ImplementsInterface(t *testing.T) { + // Compile-time check that GoLanguage implements Language + var _ Language = GoLanguage{} +} + +func TestGoLanguage_ExtractFunctionName(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + input string + want string + }{ + { + name: "simple function", + input: "func main() {", + want: "main", + }, + { + name: "function with parameters", + input: "func add(a int, b int) int {", + want: "add", + }, + { + name: "method with value receiver", + input: "func (s Server) Start() {", + want: "Start", + }, + { + name: "method with pointer receiver", + input: "func (s *Server) Stop() error {", + want: "Stop", + }, + { + name: "method with complex receiver", + input: "func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {", + want: "ServeHTTP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.extractFunctionName(tt.input) + if got != tt.want { + t.Errorf("extractFunctionName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestGoLanguage_ExtractTypeName(t *testing.T) { + g := GoLanguage{} + + tests := []struct { + name string + input string + want string + }{ + { + name: "simple type alias", + input: "type MyInt int", + want: "MyInt", + }, + { + name: "struct type", + input: "type Server struct {", + want: "Server", + }, + { + name: "interface type", + input: "type Handler interface {", + want: "Handler", + }, + { + name: "type alias with equals", + input: "type StringAlias = string", + want: "StringAlias", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.extractTypeName(tt.input) + if got != tt.want { + t.Errorf("extractTypeName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/truncate/lang_javascript.go b/internal/truncate/lang_javascript.go new file mode 100644 index 00000000..f03a7da5 --- /dev/null +++ b/internal/truncate/lang_javascript.go @@ -0,0 +1,64 @@ +package truncate + +import ( + "strings" +) + +// JavaScriptLanguage implements Language interface for JavaScript files. +// JavaScript is essentially TypeScript without type annotations and interfaces. +// This implementation wraps TypeScriptLanguage and filters out TypeScript-specific constructs. +type JavaScriptLanguage struct { + ts TypeScriptLanguage +} + +func init() { + RegisterLanguage("javascript", JavaScriptLanguage{}) +} + +// CommentSyntax returns JavaScript's comment syntax. +// JavaScript uses the same comment syntax as TypeScript. +func (js JavaScriptLanguage) CommentSyntax() (single string, multiOpen string, multiClose string) { + return js.ts.CommentSyntax() +} + +// DetectImportEnd returns the line index where the import section ends. +// JavaScript uses the same import/export syntax as TypeScript (ES6 modules). +func (js JavaScriptLanguage) DetectImportEnd(lines []string) int { + return js.ts.DetectImportEnd(lines) +} + +// DetectBlocks identifies function and class boundaries in JavaScript. +// Filters out TypeScript-specific constructs (interface, type) from the TypeScript parser. +func (js JavaScriptLanguage) DetectBlocks(content string) []Block { + // Use TypeScript parser to detect all blocks + blocks := js.ts.DetectBlocks(content) + + // Filter out TypeScript-specific block types + var jsBlocks []Block + for _, block := range blocks { + // Exclude interface and type blocks (TypeScript-only) + if block.Type != "interface" && block.Type != "type" { + jsBlocks = append(jsBlocks, block) + } + } + + return jsBlocks +} + +// isTypeScriptOnlyLine checks if a line contains TypeScript-specific syntax. +// Used to filter out type annotations and interfaces when parsing as JavaScript. +func (js JavaScriptLanguage) isTypeScriptOnlyLine(line string) bool { + trimmed := strings.TrimSpace(line) + + // Check for interface declarations + if strings.HasPrefix(trimmed, "interface ") || strings.Contains(trimmed, " interface ") { + return true + } + + // Check for type alias declarations (but not typeof operator) + if strings.HasPrefix(trimmed, "type ") && !strings.Contains(trimmed, "typeof") { + return true + } + + return false +} diff --git a/internal/truncate/lang_javascript_test.go b/internal/truncate/lang_javascript_test.go new file mode 100644 index 00000000..25883a0f --- /dev/null +++ b/internal/truncate/lang_javascript_test.go @@ -0,0 +1,403 @@ +package truncate + +import ( + "testing" +) + +func TestJavaScriptLanguage_CommentSyntax(t *testing.T) { + js := JavaScriptLanguage{} + + single, multiOpen, multiClose := js.CommentSyntax() + + if single != "//" { + t.Errorf("CommentSyntax() single = %q, want %q", single, "//") + } + if multiOpen != "/*" { + t.Errorf("CommentSyntax() multiOpen = %q, want %q", multiOpen, "/*") + } + if multiClose != "*/" { + t.Errorf("CommentSyntax() multiClose = %q, want %q", multiClose, "*/") + } +} + +func TestJavaScriptLanguage_DetectImportEnd(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + lines []string + want int + }{ + { + name: "no imports", + lines: []string{"const x = 1;", "function foo() {}"}, + want: 0, + }, + { + name: "single import", + lines: []string{ + "import { foo } from 'bar';", + "", + "const x = 1;", + }, + want: 1, + }, + { + name: "multiple imports", + lines: []string{ + "import { foo } from 'bar';", + "import { baz } from 'qux';", + "", + "const x = 1;", + }, + want: 2, + }, + { + name: "export statements", + lines: []string{ + "export { foo } from 'bar';", + "export const x = 1;", + "", + "function doStuff() {}", + }, + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectImportEnd(tt.lines) + if got != tt.want { + t.Errorf("DetectImportEnd() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_Classes(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple class", + content: `class MyClass { + constructor() {} + method() {} +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "class with export", + content: `export class MyClass { + method() {} +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple classes", + content: `class First { + method1() {} +} + +class Second { + method2() {} +}`, + want: []Block{ + {Type: "class", Name: "First", StartLine: 0, EndLine: 2}, + {Type: "class", Name: "Second", StartLine: 4, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_Functions(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple function", + content: `function myFunc() { + return 42; +}`, + want: []Block{ + {Type: "function", Name: "myFunc", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "async function", + content: `async function fetchData() { + return await fetch('/api'); +}`, + want: []Block{ + {Type: "function", Name: "fetchData", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple functions", + content: `function first() { + return 1; +} + +function second() { + return 2; +}`, + want: []Block{ + {Type: "function", Name: "first", StartLine: 0, EndLine: 2}, + {Type: "function", Name: "second", StartLine: 4, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_ArrowFunctions(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple arrow function expression", + content: `const add = (a, b) => a + b;`, + want: []Block{ + {Type: "function", Name: "add", StartLine: 0, EndLine: 0}, + }, + }, + { + name: "arrow function with block body", + content: `const calculate = (x) => { + const result = x * 2; + return result; +};`, + want: []Block{ + {Type: "function", Name: "calculate", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "exported arrow function", + content: `export const handler = async (req) => { + return { status: 200 }; +};`, + want: []Block{ + {Type: "function", Name: "handler", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple arrow functions", + content: `const first = () => 1; +const second = () => { + return 2; +};`, + want: []Block{ + {Type: "function", Name: "first", StartLine: 0, EndLine: 0}, + {Type: "function", Name: "second", StartLine: 1, EndLine: 3}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_NoTypeScriptConstructs(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "should filter out interfaces", + content: `interface User { + name: string; +} + +class MyClass { + method() {} +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 4, EndLine: 6}, + }, + }, + { + name: "should filter out type aliases", + content: `type ID = string; + +function helper() { + return 42; +}`, + want: []Block{ + {Type: "function", Name: "helper", StartLine: 2, EndLine: 4}, + }, + }, + { + name: "should keep only JavaScript constructs", + content: `interface Config { + port: number; +} + +type Status = "active" | "inactive"; + +class Service { + start() {} +} + +function init() { + return true; +} + +const process = () => { + return null; +};`, + want: []Block{ + {Type: "class", Name: "Service", StartLine: 6, EndLine: 8}, + {Type: "function", Name: "init", StartLine: 10, EndLine: 12}, + {Type: "function", Name: "process", StartLine: 14, EndLine: 16}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_Mixed(t *testing.T) { + js := JavaScriptLanguage{} + + content := `import { foo } from 'bar'; + +class UserService { + getUser(id) { + return { name: "test" }; + } +} + +function helper() { + return 42; +} + +const process = (data) => { + return data; +};` + + got := js.DetectBlocks(content) + + // Should detect class, function, and arrow function (no interface or type) + if len(got) != 3 { + t.Errorf("DetectBlocks() found %d blocks, want 3", len(got)) + } + + // Check that we got the right types + expectedTypes := []string{"class", "function", "function"} + for i, block := range got { + if block.Type != expectedTypes[i] { + t.Errorf("Block %d: got type %s, want %s", i, block.Type, expectedTypes[i]) + } + } +} + +func TestJavaScriptLanguage_DetectBlocks_StringsAndComments(t *testing.T) { + js := JavaScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "braces in strings should not affect detection", + content: `function test() { + const str = "this { has } braces"; + return str; +}`, + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "braces in comments should not affect detection", + content: `function test() { + // This comment has { braces } + /* And this one too { } */ + return 42; +}`, + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 4}, + }, + }, + { + name: "template literals with braces", + content: "function test() {\n const str = `value: ${x}`;\n return str;\n}", + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 3}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := js.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestJavaScriptLanguage_DetectBlocks_EmptyContent(t *testing.T) { + js := JavaScriptLanguage{} + + got := js.DetectBlocks("") + if len(got) != 0 { + t.Errorf("DetectBlocks(\"\") = %v, want empty slice", got) + } +} + +func TestJavaScriptLanguage_ImplementsInterface(t *testing.T) { + // Compile-time check that JavaScriptLanguage implements Language + var _ Language = JavaScriptLanguage{} +} diff --git a/internal/truncate/lang_python.go b/internal/truncate/lang_python.go new file mode 100644 index 00000000..0f377438 --- /dev/null +++ b/internal/truncate/lang_python.go @@ -0,0 +1,390 @@ +package truncate + +import ( + "strings" +) + +// PythonLanguage implements the Language interface for Python source files. +type PythonLanguage struct{} + +func init() { + RegisterLanguage("python", PythonLanguage{}) +} + +// CommentSyntax returns Python's comment syntax: # for single-line, """ and ''' for multi-line. +func (p PythonLanguage) CommentSyntax() (single string, multiOpen string, multiClose string) { + return "#", `"""`, `"""` +} + +// DetectImportEnd returns the line index where imports end. +// Returns the first non-import, non-comment, non-blank line after the import section. +// Returns 0 if there are no imports. +func (p PythonLanguage) DetectImportEnd(lines []string) int { + sawImport := false + lastImportLine := -1 + inMultilineImport := false + usingParens := false + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines and comments before we've seen imports + if !sawImport && (trimmed == "" || strings.HasPrefix(trimmed, "#")) { + continue + } + + // Check for import statements + if strings.HasPrefix(trimmed, "import ") || strings.HasPrefix(trimmed, "from ") { + sawImport = true + lastImportLine = i + + // Check if it's a multiline import (ends with backslash or has opening paren) + if strings.Contains(trimmed, "(") && !strings.Contains(trimmed, ")") { + inMultilineImport = true + usingParens = true + } else if strings.HasSuffix(trimmed, "\\") { + inMultilineImport = true + usingParens = false + } + continue + } + + // Handle continuation of multiline imports + if inMultilineImport { + lastImportLine = i + // Check if the multiline import ends + if usingParens { + if strings.Contains(trimmed, ")") { + inMultilineImport = false + } + } else { + // Backslash continuation - ends when no backslash at end + if !strings.HasSuffix(trimmed, "\\") { + inMultilineImport = false + } + } + continue + } + + // After seeing imports, skip blank lines and comments + if sawImport { + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + // Hit first non-blank, non-comment, non-import line + return i + } + } + + // If we only saw imports until EOF, return the line after the last import + if sawImport { + return lastImportLine + 1 + } + + return 0 +} + +// DetectBlocks returns all function and class boundaries in the content. +// Uses indentation tracking for Python's block structure. +func (p PythonLanguage) DetectBlocks(content string) []Block { + if content == "" { + return []Block{} + } + + // Strip strings and comments for accurate detection + single, multiOpen, multiClose := p.CommentSyntax() + stripper := NewStripper(single, multiOpen, multiClose) + stripped, err := stripper.Strip(content) + if err != nil { + // On error, fall back to treating entire content as one block + lines := strings.Split(content, "\n") + return []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: len(lines) - 1, + }, + } + } + + lines := strings.Split(content, "\n") + strippedLines := strings.Split(stripped, "\n") + + var blocks []Block + i := 0 + + for i < len(strippedLines) { + trimmed := strings.TrimSpace(strippedLines[i]) + + // Check for decorators (@decorator) - they're part of the function/class block + if strings.HasPrefix(trimmed, "@") { + decoratorStart := i + // Find the def/class that follows the decorator(s) + i++ + for i < len(strippedLines) { + nextTrimmed := strings.TrimSpace(strippedLines[i]) + if nextTrimmed == "" || strings.HasPrefix(nextTrimmed, "@") { + i++ + continue + } + // Check if this is a def or class + if strings.HasPrefix(nextTrimmed, "def ") || strings.HasPrefix(nextTrimmed, "async def ") { + block := p.detectFunctionBlock(lines, strippedLines, decoratorStart) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine + 1 + } + break + } + if strings.HasPrefix(nextTrimmed, "class ") { + block := p.detectClassBlock(lines, strippedLines, decoratorStart) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine + 1 + } + break + } + // Not a def or class after decorator, skip the decorator + i = decoratorStart + 1 + break + } + continue + } + + // Detect function definitions (including async) + if strings.HasPrefix(trimmed, "def ") || strings.HasPrefix(trimmed, "async def ") { + block := p.detectFunctionBlock(lines, strippedLines, i) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine + 1 + continue + } + } + + // Detect class definitions + if strings.HasPrefix(trimmed, "class ") { + block := p.detectClassBlock(lines, strippedLines, i) + if block != nil { + blocks = append(blocks, *block) + i = block.EndLine + 1 + continue + } + } + + i++ + } + + return blocks +} + +// detectFunctionBlock detects a function or method block starting at or before the given line. +// The startLine may point to a decorator; we'll find the actual def line. +func (p PythonLanguage) detectFunctionBlock(lines []string, strippedLines []string, startLine int) *Block { + if startLine >= len(strippedLines) { + return nil + } + + // Find the actual def line (might be after decorators) + defLine := startLine + for defLine < len(strippedLines) { + trimmed := strings.TrimSpace(strippedLines[defLine]) + if strings.HasPrefix(trimmed, "def ") || strings.HasPrefix(trimmed, "async def ") { + break + } + if trimmed != "" && !strings.HasPrefix(trimmed, "@") { + // Not a decorator or def, bail + return nil + } + defLine++ + } + + if defLine >= len(strippedLines) { + return nil + } + + // Extract function name + line := strings.TrimSpace(strippedLines[defLine]) + name := p.extractFunctionName(line) + + // Get base indentation of the def line + baseIndent := p.getIndentation(strippedLines[defLine]) + + // Find the end of the function by tracking indentation + endLine := defLine + foundBody := false + + for i := defLine + 1; i < len(strippedLines); i++ { + trimmed := strings.TrimSpace(strippedLines[i]) + + // Skip empty lines and comments (don't update endLine - might be trailing) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + indent := p.getIndentation(strippedLines[i]) + + // If indentation is greater than base, we're inside the function + if indent > baseIndent { + foundBody = true + endLine = i + continue + } + + // If indentation is <= base and we've seen the body, the function ends + if foundBody && indent <= baseIndent { + break + } + + // If we haven't seen a body yet and hit same/lower indent, might be empty function + if !foundBody && indent <= baseIndent { + break + } + } + + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: endLine, + } +} + +// detectClassBlock detects a class block starting at or before the given line. +// The startLine may point to a decorator; we'll find the actual class line. +func (p PythonLanguage) detectClassBlock(lines []string, strippedLines []string, startLine int) *Block { + if startLine >= len(strippedLines) { + return nil + } + + // Find the actual class line (might be after decorators) + classLine := startLine + for classLine < len(strippedLines) { + trimmed := strings.TrimSpace(strippedLines[classLine]) + if strings.HasPrefix(trimmed, "class ") { + break + } + if trimmed != "" && !strings.HasPrefix(trimmed, "@") { + // Not a decorator or class, bail + return nil + } + classLine++ + } + + if classLine >= len(strippedLines) { + return nil + } + + // Extract class name + line := strings.TrimSpace(strippedLines[classLine]) + name := p.extractClassName(line) + + // Get base indentation of the class line + baseIndent := p.getIndentation(strippedLines[classLine]) + + // Find the end of the class by tracking indentation + endLine := classLine + foundBody := false + + for i := classLine + 1; i < len(strippedLines); i++ { + trimmed := strings.TrimSpace(strippedLines[i]) + + // Skip empty lines and comments (don't update endLine - might be trailing) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + indent := p.getIndentation(strippedLines[i]) + + // If indentation is greater than base, we're inside the class + if indent > baseIndent { + foundBody = true + endLine = i + continue + } + + // If indentation is <= base and we've seen the body, the class ends + if foundBody && indent <= baseIndent { + break + } + + // If we haven't seen a body yet and hit same/lower indent, might be empty class + if !foundBody && indent <= baseIndent { + break + } + } + + return &Block{ + Type: "class", + Name: name, + StartLine: startLine, + EndLine: endLine, + } +} + +// getIndentation returns the number of leading spaces/tabs in a line. +// Tabs are counted as 4 spaces for consistency. +func (p PythonLanguage) getIndentation(line string) int { + indent := 0 + for _, ch := range line { + if ch == ' ' { + indent++ + } else if ch == '\t' { + indent += 4 + } else { + break + } + } + return indent +} + +// extractFunctionName extracts the function name from a def line. +// Handles: def name(), async def name() +func (p PythonLanguage) extractFunctionName(line string) string { + // Remove "async def " or "def " prefix + line = strings.TrimPrefix(strings.TrimSpace(line), "async ") + line = strings.TrimPrefix(strings.TrimSpace(line), "def ") + + // Extract name (everything before the opening parenthesis) + parenIdx := strings.Index(line, "(") + if parenIdx > 0 { + return strings.TrimSpace(line[:parenIdx]) + } + + // No parameters found - might be malformed, return what we have + fields := strings.Fields(line) + if len(fields) > 0 { + return fields[0] + } + + return "" +} + +// extractClassName extracts the class name from a class line. +// Handles: class Name:, class Name(Base):, class Name(Base1, Base2): +func (p PythonLanguage) extractClassName(line string) string { + // Remove "class " prefix + line = strings.TrimPrefix(strings.TrimSpace(line), "class ") + + // Extract name (everything before : or ( ) + colonIdx := strings.Index(line, ":") + parenIdx := strings.Index(line, "(") + + if parenIdx > 0 && (colonIdx == -1 || parenIdx < colonIdx) { + // Has base classes + return strings.TrimSpace(line[:parenIdx]) + } + + if colonIdx > 0 { + // No base classes, just name: + return strings.TrimSpace(line[:colonIdx]) + } + + // No : or ( found - might be malformed, return what we have + fields := strings.Fields(line) + if len(fields) > 0 { + return fields[0] + } + + return "" +} diff --git a/internal/truncate/lang_python_test.go b/internal/truncate/lang_python_test.go new file mode 100644 index 00000000..ca51cd4b --- /dev/null +++ b/internal/truncate/lang_python_test.go @@ -0,0 +1,856 @@ +package truncate + +import ( + "testing" +) + +func TestPythonLanguage_CommentSyntax(t *testing.T) { + p := PythonLanguage{} + + single, multiOpen, multiClose := p.CommentSyntax() + + if single != "#" { + t.Errorf("CommentSyntax() single = %q, want %q", single, "#") + } + if multiOpen != `"""` { + t.Errorf("CommentSyntax() multiOpen = %q, want %q", multiOpen, `"""`) + } + if multiClose != `"""` { + t.Errorf("CommentSyntax() multiClose = %q, want %q", multiClose, `"""`) + } +} + +func TestPythonLanguage_DetectImportEnd(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + lines []string + want int + }{ + { + name: "no imports", + lines: []string{"# comment", "", "def main():", " pass"}, + want: 0, + }, + { + name: "single import", + lines: []string{ + "import os", + "", + "def main():", + " pass", + }, + want: 2, + }, + { + name: "multiple imports", + lines: []string{ + "import os", + "import sys", + "import json", + "", + "def main():", + " pass", + }, + want: 4, + }, + { + name: "from imports", + lines: []string{ + "from os import path", + "from typing import Dict, List", + "", + "def main():", + " pass", + }, + want: 3, + }, + { + name: "mixed import and from", + lines: []string{ + "import os", + "from typing import Dict", + "import sys", + "", + "class Handler:", + " pass", + }, + want: 4, + }, + { + name: "multiline import with backslash", + lines: []string{ + "from package import \\", + " module1, \\", + " module2", + "", + "def main():", + " pass", + }, + want: 4, + }, + { + name: "multiline import with parentheses", + lines: []string{ + "from package import (", + " module1,", + " module2,", + ")", + "", + "def main():", + " pass", + }, + want: 5, + }, + { + name: "import with comment after", + lines: []string{ + "import os", + "import sys # system stuff", + "# This is a comment", + "", + "def main():", + " pass", + }, + want: 4, + }, + { + name: "import at EOF", + lines: []string{ + "import os", + "import sys", + }, + want: 2, + }, + { + name: "docstring before import", + lines: []string{ + `"""Module docstring"""`, + "", + "import os", + "", + "def main():", + " pass", + }, + want: 4, + }, + { + name: "comment before import", + lines: []string{ + "# File header", + "# Copyright notice", + "", + "import os", + "from typing import List", + "", + "def main():", + " pass", + }, + want: 6, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.DetectImportEnd(tt.lines) + if got != tt.want { + t.Errorf("DetectImportEnd() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestPythonLanguage_DetectBlocks_Functions(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "empty content", + content: "", + want: []Block{}, + }, + { + name: "single function", + content: `def main(): + print("hello") + return`, + want: []Block{ + { + Type: "function", + Name: "main", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "multiple functions", + content: `def first(): + return 1 + +def second(): + return 2 + +def third(): + return 3`, + want: []Block{ + { + Type: "function", + Name: "first", + StartLine: 0, + EndLine: 1, + }, + { + Type: "function", + Name: "second", + StartLine: 3, + EndLine: 4, + }, + { + Type: "function", + Name: "third", + StartLine: 6, + EndLine: 7, + }, + }, + }, + { + name: "function with parameters", + content: `def add(a, b): + return a + b`, + want: []Block{ + { + Type: "function", + Name: "add", + StartLine: 0, + EndLine: 1, + }, + }, + }, + { + name: "async function", + content: `async def fetch_data(): + await some_call() + return data`, + want: []Block{ + { + Type: "function", + Name: "fetch_data", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "decorated function", + content: `@decorator +def process(): + return "processed"`, + want: []Block{ + { + Type: "function", + Name: "process", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "multiple decorators", + content: `@decorator1 +@decorator2 +@decorator3 +def complex_function(): + return "result"`, + want: []Block{ + { + Type: "function", + Name: "complex_function", + StartLine: 0, + EndLine: 4, + }, + }, + }, + { + name: "decorated async function", + content: `@app.route('/api') +async def handler(): + return {"status": "ok"}`, + want: []Block{ + { + Type: "function", + Name: "handler", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "nested function", + content: `def outer(): + def inner(): + return 1 + return inner()`, + want: []Block{ + { + Type: "function", + Name: "outer", + StartLine: 0, + EndLine: 3, + }, + }, + }, + { + name: "function with nested structures", + content: `def process(): + if True: + for i in range(10): + print(i) + return`, + want: []Block{ + { + Type: "function", + Name: "process", + StartLine: 0, + EndLine: 4, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestPythonLanguage_DetectBlocks_Classes(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple class", + content: `class MyClass: + pass`, + want: []Block{ + { + Type: "class", + Name: "MyClass", + StartLine: 0, + EndLine: 1, + }, + }, + }, + { + name: "class with methods", + content: `class Server: + def __init__(self): + self.port = 8080 + + def start(self): + print("starting")`, + want: []Block{ + { + Type: "class", + Name: "Server", + StartLine: 0, + EndLine: 5, + }, + }, + }, + { + name: "class with inheritance", + content: `class Child(Parent): + def method(self): + return "child"`, + want: []Block{ + { + Type: "class", + Name: "Child", + StartLine: 0, + EndLine: 2, + }, + }, + }, + { + name: "class with multiple inheritance", + content: `class Multi(Base1, Base2, Base3): + pass`, + want: []Block{ + { + Type: "class", + Name: "Multi", + StartLine: 0, + EndLine: 1, + }, + }, + }, + { + name: "decorated class", + content: `@dataclass +class Config: + host: str + port: int`, + want: []Block{ + { + Type: "class", + Name: "Config", + StartLine: 0, + EndLine: 3, + }, + }, + }, + { + name: "nested class", + content: `class Outer: + class Inner: + pass + + def method(self): + return`, + want: []Block{ + { + Type: "class", + Name: "Outer", + StartLine: 0, + EndLine: 5, + }, + }, + }, + { + name: "multiple classes", + content: `class First: + pass + +class Second: + pass + +class Third: + pass`, + want: []Block{ + { + Type: "class", + Name: "First", + StartLine: 0, + EndLine: 1, + }, + { + Type: "class", + Name: "Second", + StartLine: 3, + EndLine: 4, + }, + { + Type: "class", + Name: "Third", + StartLine: 6, + EndLine: 7, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestPythonLanguage_DetectBlocks_Mixed(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "class and functions", + content: `def helper(): + return "help" + +class MyClass: + def method(self): + return "method" + +def another_function(): + return "function"`, + want: []Block{ + { + Type: "function", + Name: "helper", + StartLine: 0, + EndLine: 1, + }, + { + Type: "class", + Name: "MyClass", + StartLine: 3, + EndLine: 5, + }, + { + Type: "function", + Name: "another_function", + StartLine: 7, + EndLine: 8, + }, + }, + }, + { + name: "decorated class with decorated methods", + content: `@dataclass +class Config: + @property + def host(self): + return self._host + + @host.setter + def host(self, value): + self._host = value`, + want: []Block{ + { + Type: "class", + Name: "Config", + StartLine: 0, + EndLine: 8, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.DetectBlocks(tt.content) + + if len(got) != len(tt.want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(tt.want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, tt.want[i].Type) + } + if got[i].Name != tt.want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, tt.want[i].Name) + } + if got[i].StartLine != tt.want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, tt.want[i].StartLine) + } + if got[i].EndLine != tt.want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, tt.want[i].EndLine) + } + } + }) + } +} + +func TestPythonLanguage_DetectBlocks_RealWorld(t *testing.T) { + p := PythonLanguage{} + + content := `""" +Module for handling HTTP requests +""" + +import os +from typing import Dict, List +from dataclasses import dataclass + +@dataclass +class Config: + host: str + port: int + debug: bool = False + +class Handler: + def __init__(self, config: Config): + self.config = config + + async def handle_request(self, request): + """Process incoming request""" + if self.config.debug: + print(f"Handling request: {request}") + return {"status": "ok"} + + def shutdown(self): + print("Shutting down") + +def create_handler(host: str, port: int) -> Handler: + config = Config(host=host, port=port) + return Handler(config) + +@app.route('/health') +async def health_check(): + return {"status": "healthy"} + +if __name__ == "__main__": + handler = create_handler("localhost", 8080) + print("Server started")` + + want := []Block{ + { + Type: "class", + Name: "Config", + StartLine: 8, + EndLine: 12, + }, + { + Type: "class", + Name: "Handler", + StartLine: 14, + EndLine: 25, + }, + { + Type: "function", + Name: "create_handler", + StartLine: 27, + EndLine: 29, + }, + { + Type: "function", + Name: "health_check", + StartLine: 31, + EndLine: 33, + }, + } + + got := p.DetectBlocks(content) + + if len(got) != len(want) { + t.Errorf("DetectBlocks() returned %d blocks, want %d", len(got), len(want)) + for i, b := range got { + t.Logf(" got[%d]: %+v", i, b) + } + return + } + + for i := range got { + if got[i].Type != want[i].Type { + t.Errorf("block[%d].Type = %q, want %q", i, got[i].Type, want[i].Type) + } + if got[i].Name != want[i].Name { + t.Errorf("block[%d].Name = %q, want %q", i, got[i].Name, want[i].Name) + } + if got[i].StartLine != want[i].StartLine { + t.Errorf("block[%d].StartLine = %d, want %d", i, got[i].StartLine, want[i].StartLine) + } + if got[i].EndLine != want[i].EndLine { + t.Errorf("block[%d].EndLine = %d, want %d", i, got[i].EndLine, want[i].EndLine) + } + } +} + +func TestPythonLanguage_ImplementsInterface(t *testing.T) { + // Compile-time check that PythonLanguage implements Language + var _ Language = PythonLanguage{} +} + +func TestPythonLanguage_ExtractFunctionName(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + input string + want string + }{ + { + name: "simple function", + input: "def main():", + want: "main", + }, + { + name: "function with parameters", + input: "def add(a, b):", + want: "add", + }, + { + name: "async function", + input: "async def fetch_data():", + want: "fetch_data", + }, + { + name: "function with type hints", + input: "def process(data: str) -> bool:", + want: "process", + }, + { + name: "async function with type hints", + input: "async def fetch(url: str) -> Dict[str, Any]:", + want: "fetch", + }, + { + name: "method with self", + input: "def method(self, arg):", + want: "method", + }, + { + name: "class method", + input: "def method(cls, arg):", + want: "method", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.extractFunctionName(tt.input) + if got != tt.want { + t.Errorf("extractFunctionName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestPythonLanguage_ExtractClassName(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + input string + want string + }{ + { + name: "simple class", + input: "class MyClass:", + want: "MyClass", + }, + { + name: "class with inheritance", + input: "class Child(Parent):", + want: "Child", + }, + { + name: "class with multiple inheritance", + input: "class Multi(Base1, Base2):", + want: "Multi", + }, + { + name: "class with generic types", + input: "class Handler(Generic[T]):", + want: "Handler", + }, + { + name: "class with complex bases", + input: "class Server(BaseServer, metaclass=ABCMeta):", + want: "Server", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.extractClassName(tt.input) + if got != tt.want { + t.Errorf("extractClassName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestPythonLanguage_GetIndentation(t *testing.T) { + p := PythonLanguage{} + + tests := []struct { + name string + input string + want int + }{ + { + name: "no indentation", + input: "def main():", + want: 0, + }, + { + name: "4 spaces", + input: " return", + want: 4, + }, + { + name: "8 spaces", + input: " print('nested')", + want: 8, + }, + { + name: "1 tab", + input: "\treturn", + want: 4, + }, + { + name: "2 tabs", + input: "\t\tprint('nested')", + want: 8, + }, + { + name: "mixed tabs and spaces", + input: "\t return", + want: 6, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.getIndentation(tt.input) + if got != tt.want { + t.Errorf("getIndentation(%q) = %d, want %d", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/truncate/lang_typescript.go b/internal/truncate/lang_typescript.go new file mode 100644 index 00000000..d148c61f --- /dev/null +++ b/internal/truncate/lang_typescript.go @@ -0,0 +1,363 @@ +package truncate + +import ( + "strings" + "unicode" +) + +// TypeScriptLanguage implements Language interface for TypeScript files. +// Recognizes TypeScript-specific constructs including interfaces, type aliases, +// classes, functions (both traditional and arrow), and ES6 module syntax. +type TypeScriptLanguage struct{} + +func init() { + RegisterLanguage("typescript", TypeScriptLanguage{}) +} + +// CommentSyntax returns TypeScript's comment syntax. +// TypeScript uses // for single-line and /* */ for multi-line comments. +func (ts TypeScriptLanguage) CommentSyntax() (single string, multiOpen string, multiClose string) { + return "//", "/*", "*/" +} + +// DetectImportEnd returns the line index where the import section ends. +// Includes import statements and export declarations at the top of the file. +// Returns the first non-import, non-export, non-comment, non-blank line. +func (ts TypeScriptLanguage) DetectImportEnd(lines []string) int { + lastImportLine := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip blank lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "/*") || strings.HasPrefix(trimmed, "*") { + continue + } + + // Check for import or export statements + if strings.HasPrefix(trimmed, "import ") || strings.HasPrefix(trimmed, "import{") || + strings.HasPrefix(trimmed, "export ") || strings.HasPrefix(trimmed, "export{") { + lastImportLine = i + continue + } + + // If we hit a non-import/export line, stop + break + } + + // Return the line after the last import/export + if lastImportLine >= 0 { + return lastImportLine + 1 + } + + return 0 +} + +// DetectBlocks identifies function, class, interface, and type boundaries in TypeScript. +// Uses brace depth tracking after stripping strings and comments for accuracy. +// Arrow functions with simple expressions (no braces) are treated as single-line blocks. +func (ts TypeScriptLanguage) DetectBlocks(content string) []Block { + if content == "" { + return []Block{} + } + + // Strip strings and comments for accurate brace counting + single, multiOpen, multiClose := ts.CommentSyntax() + stripper := NewStripper(single, multiOpen, multiClose) + stripped, err := stripper.Strip(content) + if err != nil { + // If stripping fails, return the entire content as one block + lines := strings.Split(content, "\n") + return []Block{ + { + Type: "block", + Name: "", + StartLine: 0, + EndLine: len(lines) - 1, + }, + } + } + + // Keep both original and stripped lines for name extraction + originalLines := strings.Split(content, "\n") + strippedLines := strings.Split(stripped, "\n") + var blocks []Block + + for i := 0; i < len(strippedLines); i++ { + strippedLine := strings.TrimSpace(strippedLines[i]) + originalLine := strings.TrimSpace(originalLines[i]) + + // Check for class declarations (use stripped line for detection) + if strings.HasPrefix(strippedLine, "class ") || strings.Contains(strippedLine, " class ") { + if block := ts.detectBraceBlock(originalLines, strippedLines, i, "class", stripped); block != nil { + blocks = append(blocks, *block) + } + continue + } + + // Check for interface declarations (use stripped line for detection) + if strings.HasPrefix(strippedLine, "interface ") || strings.Contains(strippedLine, " interface ") { + if block := ts.detectBraceBlock(originalLines, strippedLines, i, "interface", stripped); block != nil { + blocks = append(blocks, *block) + } + continue + } + + // Check for type declarations (use stripped line for detection) + if strings.HasPrefix(strippedLine, "type ") || strings.Contains(strippedLine, " type ") { + // Type aliases may or may not have braces + if block := ts.detectTypeBlock(originalLines, strippedLines, i, stripped); block != nil { + blocks = append(blocks, *block) + } + continue + } + + // Check for function declarations (use stripped line for detection) + if strings.HasPrefix(strippedLine, "function ") || strings.Contains(strippedLine, " function ") || + strings.HasPrefix(strippedLine, "async function ") { + if block := ts.detectBraceBlock(originalLines, strippedLines, i, "function", stripped); block != nil { + blocks = append(blocks, *block) + } + continue + } + + // Check for arrow functions - use ORIGINAL line for detection since names are stripped + if (strings.HasPrefix(originalLine, "const ") || strings.HasPrefix(originalLine, "let ") || + strings.HasPrefix(originalLine, "var ") || strings.HasPrefix(originalLine, "export const ") || + strings.HasPrefix(originalLine, "export let ")) && strings.Contains(strippedLine, "=>") { + if block := ts.detectArrowFunction(originalLines, strippedLines, i, stripped); block != nil { + blocks = append(blocks, *block) + } + continue + } + } + + return blocks +} + +// detectBraceBlock finds a block that starts with a keyword and is delimited by braces. +// Used for classes, interfaces, and traditional functions. +func (ts TypeScriptLanguage) detectBraceBlock(originalLines []string, strippedLines []string, startLine int, blockType string, stripped string) *Block { + name := ts.extractName(originalLines[startLine], blockType) + + // Find the opening brace in stripped lines + openBraceLine := -1 + for i := startLine; i < len(strippedLines); i++ { + if strings.Contains(strippedLines[i], "{") { + openBraceLine = i + break + } + // If we hit a semicolon or another statement before finding a brace, this isn't a block + if strings.Contains(strippedLines[i], ";") { + return nil + } + } + + if openBraceLine == -1 { + return nil + } + + // Track brace depth from the opening brace + depth := 0 + allStrippedLines := strings.Split(stripped, "\n") + + for i := openBraceLine; i < len(allStrippedLines); i++ { + for _, ch := range allStrippedLines[i] { + if ch == '{' { + depth++ + } else if ch == '}' { + depth-- + if depth == 0 { + return &Block{ + Type: blockType, + Name: name, + StartLine: startLine, + EndLine: i, + } + } + } + } + } + + // If we never found the closing brace, treat up to the end as the block + return &Block{ + Type: blockType, + Name: name, + StartLine: startLine, + EndLine: len(originalLines) - 1, + } +} + +// detectTypeBlock handles type alias declarations which may span multiple lines. +// Type aliases can be simple (type X = Y;) or complex with braces (type X = { ... }). +func (ts TypeScriptLanguage) detectTypeBlock(originalLines []string, strippedLines []string, startLine int, stripped string) *Block { + name := ts.extractName(originalLines[startLine], "type") + + // Check if the type has braces (object type) in stripped line + if strings.Contains(strippedLines[startLine], "{") { + return ts.detectBraceBlock(originalLines, strippedLines, startLine, "type", stripped) + } + + // Simple type alias - find the semicolon in stripped lines + for i := startLine; i < len(strippedLines); i++ { + if strings.Contains(strippedLines[i], ";") { + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: i, + } + } + // If we see another statement start, the previous line was the end + trimmed := strings.TrimSpace(strippedLines[i]) + if i > startLine && (strings.HasPrefix(trimmed, "import ") || + strings.HasPrefix(trimmed, "export ") || + strings.HasPrefix(trimmed, "const ") || + strings.HasPrefix(trimmed, "let ") || + strings.HasPrefix(trimmed, "var ") || + strings.HasPrefix(trimmed, "function ") || + strings.HasPrefix(trimmed, "class ") || + strings.HasPrefix(trimmed, "interface ") || + strings.HasPrefix(trimmed, "type ")) { + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: i - 1, + } + } + } + + // If no terminator found, treat as single line + return &Block{ + Type: "type", + Name: name, + StartLine: startLine, + EndLine: startLine, + } +} + +// detectArrowFunction handles arrow function expressions. +// Arrow functions can be simple expressions (x => x + 1) or block bodies (x => { ... }). +// Only block-body arrows are tracked as separate blocks. +func (ts TypeScriptLanguage) detectArrowFunction(originalLines []string, strippedLines []string, startLine int, stripped string) *Block { + originalLine := originalLines[startLine] + strippedLine := strippedLines[startLine] + name := ts.extractArrowFunctionName(originalLine) + + // Check if this is a block-body arrow function (has braces after =>) using stripped line + arrowIdx := strings.Index(strippedLine, "=>") + openBraceLine := startLine + + if arrowIdx == -1 { + // Multi-line arrow definition; look ahead in stripped lines + for i := startLine; i < len(strippedLines) && i < startLine+5; i++ { + if strings.Contains(strippedLines[i], "=>") { + arrowIdx = strings.Index(strippedLines[i], "=>") + strippedLine = strippedLines[i] + openBraceLine = i + break + } + } + if arrowIdx == -1 { + return nil + } + } + + afterArrow := strings.TrimSpace(strippedLine[arrowIdx+2:]) + + // If there's an opening brace after the arrow, it's a block body + if strings.HasPrefix(afterArrow, "{") || strings.Contains(strippedLine, "=> {") { + // Track brace depth to find the end of the function + depth := 0 + allStrippedLines := strings.Split(stripped, "\n") + + for i := openBraceLine; i < len(allStrippedLines); i++ { + for _, ch := range allStrippedLines[i] { + if ch == '{' { + depth++ + } else if ch == '}' { + depth-- + if depth == 0 { + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: i, + } + } + } + } + } + + // If we never found the closing brace, treat up to the end as the block + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: len(originalLines) - 1, + } + } + + // Simple expression arrow function - treat as single line + return &Block{ + Type: "function", + Name: name, + StartLine: startLine, + EndLine: startLine, + } +} + +// extractName extracts the identifier name from a declaration line. +// Handles class, interface, type, and function declarations. +func (ts TypeScriptLanguage) extractName(line string, blockType string) string { + // Remove leading keywords and modifiers + line = strings.TrimSpace(line) + tokens := strings.Fields(line) + + // Find the keyword and extract the next token as the name + for i, token := range tokens { + if token == blockType { + if i+1 < len(tokens) { + name := tokens[i+1] + // Clean up any trailing characters + name = strings.TrimFunc(name, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' + }) + return name + } + } + } + + return "" +} + +// extractArrowFunctionName extracts the variable name from an arrow function declaration. +// Handles const/let/var declarations with export modifiers. +func (ts TypeScriptLanguage) extractArrowFunctionName(line string) string { + line = strings.TrimSpace(line) + + // Remove export modifier if present + line = strings.TrimPrefix(line, "export ") + line = strings.TrimSpace(line) + + // Find the variable name after const/let/var + for _, keyword := range []string{"const ", "let ", "var "} { + if strings.HasPrefix(line, keyword) { + rest := strings.TrimSpace(line[len(keyword):]) + // Split by whitespace and get the first token + tokens := strings.Fields(rest) + if len(tokens) > 0 { + // The first token should be the variable name + // Remove any trailing punctuation like : or = + name := tokens[0] + name = strings.TrimFunc(name, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' + }) + return name + } + } + } + + return "" +} diff --git a/internal/truncate/lang_typescript_test.go b/internal/truncate/lang_typescript_test.go new file mode 100644 index 00000000..16e5a5d4 --- /dev/null +++ b/internal/truncate/lang_typescript_test.go @@ -0,0 +1,558 @@ +package truncate + +import ( + "testing" +) + +func TestTypeScriptLanguage_CommentSyntax(t *testing.T) { + ts := TypeScriptLanguage{} + + single, multiOpen, multiClose := ts.CommentSyntax() + + if single != "//" { + t.Errorf("CommentSyntax() single = %q, want %q", single, "//") + } + if multiOpen != "/*" { + t.Errorf("CommentSyntax() multiOpen = %q, want %q", multiOpen, "/*") + } + if multiClose != "*/" { + t.Errorf("CommentSyntax() multiClose = %q, want %q", multiClose, "*/") + } +} + +func TestTypeScriptLanguage_DetectImportEnd(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + lines []string + want int + }{ + { + name: "no imports", + lines: []string{"const x = 1;", "function foo() {}"}, + want: 0, + }, + { + name: "single import", + lines: []string{ + "import { foo } from 'bar';", + "", + "const x = 1;", + }, + want: 1, + }, + { + name: "multiple imports", + lines: []string{ + "import { foo } from 'bar';", + "import { baz } from 'qux';", + "", + "const x = 1;", + }, + want: 2, + }, + { + name: "imports with comments", + lines: []string{ + "// This is a comment", + "import { foo } from 'bar';", + "/* Multi-line", + " * comment */", + "import { baz } from 'qux';", + "", + "const x = 1;", + }, + want: 5, + }, + { + name: "export statements", + lines: []string{ + "export { foo } from 'bar';", + "export const x = 1;", + "", + "function doStuff() {}", + }, + want: 2, + }, + { + name: "mixed import and export", + lines: []string{ + "import { foo } from 'bar';", + "export { baz } from 'qux';", + "", + "const x = 1;", + }, + want: 2, + }, + { + name: "import without space", + lines: []string{ + "import{foo}from'bar';", + "", + "const x = 1;", + }, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectImportEnd(tt.lines) + if got != tt.want { + t.Errorf("DetectImportEnd() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_Classes(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple class", + content: `class MyClass { + constructor() {} + method() {} +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "class with export", + content: `export class MyClass { + method() {} +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple classes", + content: `class First { + method1() {} +} + +class Second { + method2() {} +}`, + want: []Block{ + {Type: "class", Name: "First", StartLine: 0, EndLine: 2}, + {Type: "class", Name: "Second", StartLine: 4, EndLine: 6}, + }, + }, + { + name: "class with nested braces", + content: `class MyClass { + method() { + if (true) { + return { key: "value" }; + } + } +}`, + want: []Block{ + {Type: "class", Name: "MyClass", StartLine: 0, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_Functions(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple function", + content: `function myFunc() { + return 42; +}`, + want: []Block{ + {Type: "function", Name: "myFunc", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "async function", + content: `async function fetchData() { + return await fetch('/api'); +}`, + want: []Block{ + {Type: "function", Name: "fetchData", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple functions", + content: `function first() { + return 1; +} + +function second() { + return 2; +}`, + want: []Block{ + {Type: "function", Name: "first", StartLine: 0, EndLine: 2}, + {Type: "function", Name: "second", StartLine: 4, EndLine: 6}, + }, + }, + { + name: "function with nested braces", + content: `function complex() { + if (condition) { + const obj = { key: "value" }; + return obj; + } + return null; +}`, + want: []Block{ + {Type: "function", Name: "complex", StartLine: 0, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_ArrowFunctions(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple arrow function expression", + content: `const add = (a, b) => a + b;`, + want: []Block{ + {Type: "function", Name: "add", StartLine: 0, EndLine: 0}, + }, + }, + { + name: "arrow function with block body", + content: `const calculate = (x) => { + const result = x * 2; + return result; +};`, + want: []Block{ + {Type: "function", Name: "calculate", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "exported arrow function", + content: `export const handler = async (req) => { + return { status: 200 }; +};`, + want: []Block{ + {Type: "function", Name: "handler", StartLine: 0, EndLine: 2}, + }, + }, + { + name: "multiple arrow functions", + content: `const first = () => 1; +const second = () => { + return 2; +};`, + want: []Block{ + {Type: "function", Name: "first", StartLine: 0, EndLine: 0}, + {Type: "function", Name: "second", StartLine: 1, EndLine: 3}, + }, + }, + { + name: "arrow function with nested braces", + content: `const process = (data) => { + const obj = { key: "value" }; + if (data) { + return obj; + } + return null; +};`, + want: []Block{ + {Type: "function", Name: "process", StartLine: 0, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_Interfaces(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple interface", + content: `interface User { + name: string; + age: number; +}`, + want: []Block{ + {Type: "interface", Name: "User", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "exported interface", + content: `export interface Config { + port: number; + host: string; +}`, + want: []Block{ + {Type: "interface", Name: "Config", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "multiple interfaces", + content: `interface First { + id: number; +} + +interface Second { + name: string; +}`, + want: []Block{ + {Type: "interface", Name: "First", StartLine: 0, EndLine: 2}, + {Type: "interface", Name: "Second", StartLine: 4, EndLine: 6}, + }, + }, + { + name: "interface with nested object", + content: `interface Complex { + data: { + nested: { + value: string; + }; + }; +}`, + want: []Block{ + {Type: "interface", Name: "Complex", StartLine: 0, EndLine: 6}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_Types(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "simple type alias", + content: `type ID = string | number;`, + want: []Block{ + {Type: "type", Name: "ID", StartLine: 0, EndLine: 0}, + }, + }, + { + name: "object type", + content: `type User = { + name: string; + age: number; +};`, + want: []Block{ + {Type: "type", Name: "User", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "exported type", + content: `export type Status = "active" | "inactive";`, + want: []Block{ + {Type: "type", Name: "Status", StartLine: 0, EndLine: 0}, + }, + }, + { + name: "multiple types", + content: `type First = string; + +type Second = { + value: number; +};`, + want: []Block{ + {Type: "type", Name: "First", StartLine: 0, EndLine: 0}, + {Type: "type", Name: "Second", StartLine: 2, EndLine: 4}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_Mixed(t *testing.T) { + ts := TypeScriptLanguage{} + + content := `import { foo } from 'bar'; + +interface User { + name: string; +} + +type ID = string; + +class UserService { + getUser(id: ID): User { + return { name: "test" }; + } +} + +function helper() { + return 42; +} + +const process = (data) => { + return data; +};` + + got := ts.DetectBlocks(content) + + // Should detect interface, type, class, function, and arrow function + if len(got) != 5 { + t.Errorf("DetectBlocks() found %d blocks, want 5", len(got)) + } + + // Check that we got the right types + expectedTypes := map[string]bool{ + "interface": true, + "type": true, + "class": true, + "function": true, // both regular and arrow + } + + for _, block := range got { + if !expectedTypes[block.Type] { + t.Errorf("Unexpected block type: %s", block.Type) + } + } +} + +func TestTypeScriptLanguage_DetectBlocks_StringsAndComments(t *testing.T) { + ts := TypeScriptLanguage{} + + tests := []struct { + name string + content string + want []Block + }{ + { + name: "braces in strings should not affect detection", + content: `function test() { + const str = "this { has } braces"; + return str; +}`, + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 3}, + }, + }, + { + name: "braces in comments should not affect detection", + content: `function test() { + // This comment has { braces } + /* And this one too { } */ + return 42; +}`, + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 4}, + }, + }, + { + name: "template literals with braces", + content: "function test() {\n const str = `value: ${x}`;\n return str;\n}", + want: []Block{ + {Type: "function", Name: "test", StartLine: 0, EndLine: 3}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ts.DetectBlocks(tt.content) + if !compareBlocks(got, tt.want) { + t.Errorf("DetectBlocks() =\n%+v\nwant\n%+v", got, tt.want) + } + }) + } +} + +func TestTypeScriptLanguage_DetectBlocks_EmptyContent(t *testing.T) { + ts := TypeScriptLanguage{} + + got := ts.DetectBlocks("") + if len(got) != 0 { + t.Errorf("DetectBlocks(\"\") = %v, want empty slice", got) + } +} + +func TestTypeScriptLanguage_ImplementsInterface(t *testing.T) { + // Compile-time check that TypeScriptLanguage implements Language + var _ Language = TypeScriptLanguage{} +} + +// compareBlocks compares two slices of blocks for equality. +func compareBlocks(got, want []Block) bool { + if len(got) != len(want) { + return false + } + + for i := range got { + if got[i].Type != want[i].Type || + got[i].Name != want[i].Name || + got[i].StartLine != want[i].StartLine || + got[i].EndLine != want[i].EndLine { + return false + } + } + + return true +} diff --git a/internal/truncate/language.go b/internal/truncate/language.go new file mode 100644 index 00000000..6274f9c2 --- /dev/null +++ b/internal/truncate/language.go @@ -0,0 +1,32 @@ +package truncate + +// Block represents a code block (function, class, method, etc.) with its boundaries. +type Block struct { + Type string // "function", "class", "method", "block" + Name string // Identifier name if available + StartLine int // Line number where block starts (0-indexed) + EndLine int // Line number where block ends (0-indexed, inclusive) +} + +// Language defines the interface for language-specific parsing. +// Implementations provide language-aware detection of imports and code blocks +// to enable intelligent truncation at natural boundaries. +type Language interface { + // DetectImportEnd returns the line index (0-indexed) where the import section ends. + // Returns the first non-import, non-comment, non-blank line after imports. + // Returns 0 if there are no imports. + DetectImportEnd(lines []string) int + + // DetectBlocks returns all function/class/method boundaries in the content. + // The content string contains the full file content. + // Returns blocks in order of appearance (by StartLine). + DetectBlocks(content string) []Block + + // CommentSyntax returns the comment syntax for this language. + // Returns: + // - single: single-line comment prefix (e.g., "//", "#") + // - multiOpen: multi-line comment opening (e.g., "/*", `"""`) + // - multiClose: multi-line comment closing (e.g., "*/", `"""`) + // Empty strings indicate no support for that comment type. + CommentSyntax() (single string, multiOpen string, multiClose string) +} diff --git a/internal/truncate/registry.go b/internal/truncate/registry.go new file mode 100644 index 00000000..ba89f629 --- /dev/null +++ b/internal/truncate/registry.go @@ -0,0 +1,43 @@ +package truncate + +import ( + "strings" + "sync" +) + +var ( + // registry stores language parsers by normalized language name + registry = make(map[string]Language) + // registryMu protects concurrent access to the registry + registryMu sync.RWMutex +) + +// RegisterLanguage registers a language parser for the given language identifier. +// The language name is case-insensitive and will be normalized to lowercase. +// If a parser already exists for this language, it will be replaced. +// Thread-safe for concurrent registration. +func RegisterLanguage(language string, parser Language) { + registryMu.Lock() + defer registryMu.Unlock() + + normalized := normalizeLanguage(language) + registry[normalized] = parser +} + +// GetLanguage retrieves the language parser for the given language identifier. +// Returns nil if no parser is registered for this language. +// The language name is case-insensitive. +// Thread-safe for concurrent access. +func GetLanguage(language string) Language { + registryMu.RLock() + defer registryMu.RUnlock() + + normalized := normalizeLanguage(language) + return registry[normalized] +} + +// normalizeLanguage converts a language identifier to its canonical form. +// Converts to lowercase and trims whitespace for consistent lookups. +func normalizeLanguage(language string) string { + return strings.ToLower(strings.TrimSpace(language)) +} diff --git a/internal/truncate/registry_test.go b/internal/truncate/registry_test.go new file mode 100644 index 00000000..3c227554 --- /dev/null +++ b/internal/truncate/registry_test.go @@ -0,0 +1,283 @@ +package truncate + +import ( + "sync" + "testing" +) + +// mockLanguage is a test implementation of the Language interface +type mockLanguage struct { + name string +} + +func (m mockLanguage) DetectImportEnd(lines []string) int { + return 0 +} + +func (m mockLanguage) DetectBlocks(content string) []Block { + return []Block{} +} + +func (m mockLanguage) CommentSyntax() (string, string, string) { + return "//", "/*", "*/" +} + +func TestRegisterLanguage(t *testing.T) { + // Save and restore registry for test isolation + registryMu.Lock() + saved := registry + registry = make(map[string]Language) + registryMu.Unlock() + + defer func() { + registryMu.Lock() + registry = saved + registryMu.Unlock() + }() + + tests := []struct { + name string + language string + parser Language + }{ + { + name: "register go", + language: "go", + parser: mockLanguage{name: "go"}, + }, + { + name: "register python", + language: "python", + parser: mockLanguage{name: "python"}, + }, + { + name: "register with uppercase", + language: "TYPESCRIPT", + parser: mockLanguage{name: "typescript"}, + }, + { + name: "register with spaces", + language: " javascript ", + parser: mockLanguage{name: "javascript"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + RegisterLanguage(tt.language, tt.parser) + + // Verify it was registered + got := GetLanguage(tt.language) + if got == nil { + t.Errorf("GetLanguage(%q) = nil, want non-nil", tt.language) + } + }) + } +} + +func TestGetLanguage(t *testing.T) { + // Save and restore registry for test isolation + registryMu.Lock() + saved := registry + registry = make(map[string]Language) + registryMu.Unlock() + + defer func() { + registryMu.Lock() + registry = saved + registryMu.Unlock() + }() + + goParser := mockLanguage{name: "go"} + pythonParser := mockLanguage{name: "python"} + + RegisterLanguage("go", goParser) + RegisterLanguage("python", pythonParser) + + tests := []struct { + name string + language string + wantNil bool + }{ + { + name: "get registered language", + language: "go", + wantNil: false, + }, + { + name: "get with uppercase", + language: "GO", + wantNil: false, + }, + { + name: "get with mixed case", + language: "Go", + wantNil: false, + }, + { + name: "get with spaces", + language: " go ", + wantNil: false, + }, + { + name: "get unregistered language", + language: "rust", + wantNil: true, + }, + { + name: "get empty string", + language: "", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetLanguage(tt.language) + if tt.wantNil && got != nil { + t.Errorf("GetLanguage(%q) = %v, want nil", tt.language, got) + } + if !tt.wantNil && got == nil { + t.Errorf("GetLanguage(%q) = nil, want non-nil", tt.language) + } + }) + } +} + +func TestGetLanguage_CaseInsensitive(t *testing.T) { + // Save and restore registry for test isolation + registryMu.Lock() + saved := registry + registry = make(map[string]Language) + registryMu.Unlock() + + defer func() { + registryMu.Lock() + registry = saved + registryMu.Unlock() + }() + + parser := mockLanguage{name: "test"} + RegisterLanguage("TypeScript", parser) + + // All these should retrieve the same parser + variations := []string{"typescript", "TYPESCRIPT", "TypeScript", "tYpEsCrIpT"} + + for _, variation := range variations { + t.Run(variation, func(t *testing.T) { + got := GetLanguage(variation) + if got == nil { + t.Errorf("GetLanguage(%q) = nil, want non-nil", variation) + } + }) + } +} + +func TestRegisterLanguage_Replacement(t *testing.T) { + // Save and restore registry for test isolation + registryMu.Lock() + saved := registry + registry = make(map[string]Language) + registryMu.Unlock() + + defer func() { + registryMu.Lock() + registry = saved + registryMu.Unlock() + }() + + parser1 := mockLanguage{name: "first"} + parser2 := mockLanguage{name: "second"} + + RegisterLanguage("go", parser1) + RegisterLanguage("go", parser2) + + // Should get the second parser (replacement) + got := GetLanguage("go") + if got == nil { + t.Fatal("GetLanguage(\"go\") = nil, want non-nil") + } + + // Verify it's the second parser by checking the name field + if mock, ok := got.(mockLanguage); ok { + if mock.name != "second" { + t.Errorf("GetLanguage(\"go\") returned parser with name %q, want %q", mock.name, "second") + } + } +} + +func TestRegistry_ThreadSafety(t *testing.T) { + // Save and restore registry for test isolation + registryMu.Lock() + saved := registry + registry = make(map[string]Language) + registryMu.Unlock() + + defer func() { + registryMu.Lock() + registry = saved + registryMu.Unlock() + }() + + // Run concurrent registrations and lookups + const goroutines = 100 + const iterations = 100 + + var wg sync.WaitGroup + wg.Add(goroutines * 2) + + // Concurrent registrations + for i := 0; i < goroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < iterations; j++ { + parser := mockLanguage{name: "concurrent"} + RegisterLanguage("go", parser) + } + }(i) + } + + // Concurrent lookups + for i := 0; i < goroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < iterations; j++ { + _ = GetLanguage("go") + } + }(i) + } + + wg.Wait() + + // Verify registry is still functional + got := GetLanguage("go") + if got == nil { + t.Error("GetLanguage(\"go\") = nil after concurrent operations, want non-nil") + } +} + +func TestNormalizeLanguage(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "go", want: "go"}, + {input: "Go", want: "go"}, + {input: "GO", want: "go"}, + {input: "TypeScript", want: "typescript"}, + {input: "PYTHON", want: "python"}, + {input: " javascript ", want: "javascript"}, + {input: "\tgo\t", want: "go"}, + {input: "", want: ""}, + {input: " ", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := normalizeLanguage(tt.input) + if got != tt.want { + t.Errorf("normalizeLanguage(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/truncate/stripper.go b/internal/truncate/stripper.go new file mode 100644 index 00000000..a1f5b5f8 --- /dev/null +++ b/internal/truncate/stripper.go @@ -0,0 +1,283 @@ +package truncate + +import ( + "errors" + "strings" +) + +const ( + defaultMaxNestingDepth = 1000 +) + +var ( + ErrMaxNestingDepthExceeded = errors.New("maximum nesting depth exceeded") +) + +// stripState represents the current parsing state in the stripper state machine. +type stripState int + +const ( + stateNormal stripState = iota + stateSingleLineComment + stateMultiLineComment + stateDoubleQuoteString + stateSingleQuoteString + stateBacktickString + stateTripleDoubleQuoteString + stateTripleSingleQuoteString +) + +// Stripper removes string literals and comments from code while preserving +// structure (line breaks, character positions) for bracket counting. +type Stripper struct { + singleLineComment string + multiOpen string + multiClose string + maxDepth int +} + +// NewStripper creates a stripper for the given comment syntax. +// Pass empty strings for unsupported comment types. +func NewStripper(singleLineComment, multiOpen, multiClose string) *Stripper { + return &Stripper{ + singleLineComment: singleLineComment, + multiOpen: multiOpen, + multiClose: multiClose, + maxDepth: defaultMaxNestingDepth, + } +} + +// Strip removes strings and comments from content, replacing them with spaces. +// Line structure and character positions are preserved for accurate bracket counting. +// Returns error if nesting depth exceeds safety limit. +func (s *Stripper) Strip(content string) (string, error) { + if content == "" { + return "", nil + } + + var result strings.Builder + result.Grow(len(content)) + + state := stateNormal + escaped := false + depth := 0 + i := 0 + + for i < len(content) { + ch := content[i] + + // Check nesting depth to prevent stack overflow attacks + if depth > s.maxDepth { + return "", ErrMaxNestingDepthExceeded + } + + switch state { + case stateNormal: + // Check for triple-quoted strings first (Python) + if i+2 < len(content) && content[i:i+3] == `"""` { + result.WriteString(" ") + i += 3 + state = stateTripleDoubleQuoteString + depth++ + continue + } + if i+2 < len(content) && content[i:i+3] == "'''" { + result.WriteString(" ") + i += 3 + state = stateTripleSingleQuoteString + depth++ + continue + } + + // Check for multi-line comment + if s.multiOpen != "" && strings.HasPrefix(content[i:], s.multiOpen) { + result.WriteString(strings.Repeat(" ", len(s.multiOpen))) + i += len(s.multiOpen) + state = stateMultiLineComment + depth++ + continue + } + + // Check for single-line comment + if s.singleLineComment != "" && strings.HasPrefix(content[i:], s.singleLineComment) { + result.WriteString(strings.Repeat(" ", len(s.singleLineComment))) + i += len(s.singleLineComment) + state = stateSingleLineComment + depth++ + continue + } + + // Check for string literals + if ch == '"' { + result.WriteByte(' ') + i++ + state = stateDoubleQuoteString + depth++ + continue + } + if ch == '\'' { + result.WriteByte(' ') + i++ + state = stateSingleQuoteString + depth++ + continue + } + if ch == '`' { + result.WriteByte(' ') + i++ + state = stateBacktickString + depth++ + continue + } + + // Normal code character + result.WriteByte(ch) + i++ + + case stateSingleLineComment: + if ch == '\n' { + result.WriteByte('\n') + state = stateNormal + depth-- + } else { + result.WriteByte(' ') + } + i++ + + case stateMultiLineComment: + if s.multiClose != "" && strings.HasPrefix(content[i:], s.multiClose) { + result.WriteString(strings.Repeat(" ", len(s.multiClose))) + i += len(s.multiClose) + state = stateNormal + depth-- + continue + } + + // Preserve newlines for line structure + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + + case stateDoubleQuoteString: + if escaped { + result.WriteByte(' ') + escaped = false + i++ + continue + } + + if ch == '\\' { + result.WriteByte(' ') + escaped = true + i++ + continue + } + + if ch == '"' { + result.WriteByte(' ') + state = stateNormal + depth-- + i++ + continue + } + + // Preserve newlines (though uncommon in non-raw strings) + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + + case stateSingleQuoteString: + if escaped { + result.WriteByte(' ') + escaped = false + i++ + continue + } + + if ch == '\\' { + result.WriteByte(' ') + escaped = true + i++ + continue + } + + if ch == '\'' { + result.WriteByte(' ') + state = stateNormal + depth-- + i++ + continue + } + + // Preserve newlines + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + + case stateBacktickString: + // Backquote strings (Go raw strings, JS template literals) + // No escape sequences in Go raw strings + // JS template literals can have ${} but we're just stripping + if ch == '`' { + result.WriteByte(' ') + state = stateNormal + depth-- + i++ + continue + } + + // Preserve newlines + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + + case stateTripleDoubleQuoteString: + if strings.HasPrefix(content[i:], `"""`) { + result.WriteString(" ") + i += 3 + state = stateNormal + depth-- + continue + } + + // Preserve newlines + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + + case stateTripleSingleQuoteString: + if strings.HasPrefix(content[i:], "'''") { + result.WriteString(" ") + i += 3 + state = stateNormal + depth-- + continue + } + + // Preserve newlines + if ch == '\n' { + result.WriteByte('\n') + } else { + result.WriteByte(' ') + } + i++ + } + } + + return result.String(), nil +} diff --git a/internal/truncate/stripper_test.go b/internal/truncate/stripper_test.go new file mode 100644 index 00000000..772b6dd5 --- /dev/null +++ b/internal/truncate/stripper_test.go @@ -0,0 +1,444 @@ +package truncate + +import ( + "strings" + "testing" +) + +func TestStripper_Strip_Go(t *testing.T) { + // Go uses //, /* */, double-quoted strings, and backtick raw strings + s := NewStripper("//", "/*", "*/") + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "empty string", + input: "", + want: "", + }, + { + name: "code without strings or comments", + input: "func main() {\n\tx := 42\n}", + want: "func main() {\n\tx := 42\n}", + }, + { + name: "single line comment", + input: "x := 42 // comment here\ny := 10", + want: "x := 42 \ny := 10", + }, + { + name: "multi-line comment", + input: "x := 42 /* comment\nacross lines */ y := 10", + want: "x := 42 \n y := 10", + }, + { + name: "double-quoted string", + input: `s := "hello world"`, + want: `s := `, + }, + { + name: "escaped quote in string", + input: `s := "hello \"quoted\" world"`, + want: `s := `, + }, + { + name: "backtick raw string", + input: "s := `raw\nstring`", + want: "s := \n ", + }, + { + name: "comment marker inside string - preserved", + input: `s := "// not a comment"`, + want: `s := `, + }, + { + name: "string inside comment - stripped", + input: `// comment with "string"`, + want: ` `, + }, + { + name: "bracket in string - preserved", + input: `s := "text { bracket }"`, + want: `s := `, + }, + { + name: "bracket in comment - stripped", + input: `// comment { with bracket }`, + want: ` `, + }, + { + name: "multiple strings and comments", + input: "s1 := \"hello\"\n// comment\ns2 := \"world\"", + want: "s1 := \n \ns2 := ", + }, + { + name: "nested quotes with escapes", + input: `s := "outer \"inner \\\" nested\" outer"`, + want: `s := `, + }, + { + name: "comment at EOF", + input: "x := 42 // comment", + want: "x := 42 ", + }, + { + name: "string at EOF", + input: `s := "text"`, + want: `s := `, + }, + { + name: "unclosed string", + input: `s := "unclosed`, + want: `s := `, + }, + { + name: "unclosed comment", + input: "x := 42 /* unclosed", + want: "x := 42 ", + }, + { + name: "empty string literal", + input: `s := ""`, + want: `s := `, + }, + { + name: "empty comment", + input: "x := 42 //\ny := 10", + want: "x := 42 \ny := 10", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Strip(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Strip() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Strip() mismatch:\ninput: %q\ngot: %q\nwant: %q", tt.input, got, tt.want) + } + + // Verify length preservation + if len(got) != len(tt.input) { + t.Errorf("Strip() length mismatch: got %d, want %d", len(got), len(tt.input)) + } + + // Verify newline preservation + if strings.Count(got, "\n") != strings.Count(tt.input, "\n") { + t.Errorf("Strip() newline count mismatch: got %d, want %d", + strings.Count(got, "\n"), strings.Count(tt.input, "\n")) + } + }) + } +} + +func TestStripper_Strip_Python(t *testing.T) { + // Python uses #, triple-quoted strings + s := NewStripper("#", "", "") + + tests := []struct { + name string + input string + want string + }{ + { + name: "hash comment", + input: "x = 42 # comment here\ny = 10", + want: "x = 42 \ny = 10", + }, + { + name: "triple-double-quote docstring", + input: "def f():\n \"\"\"docstring\n here\"\"\"\n pass", + want: "def f():\n \n \n pass", + }, + { + name: "triple-single-quote docstring", + input: "def f():\n '''docstring\n here'''\n pass", + want: "def f():\n \n \n pass", + }, + { + name: "single-quoted string", + input: "s = 'hello world'", + want: "s = ", + }, + { + name: "double-quoted string", + input: `s = "hello world"`, + want: `s = `, + }, + { + name: "escaped quote in python string", + input: `s = "test \" quote"`, + want: `s = `, + }, + { + name: "comment marker in string", + input: `s = "# not a comment"`, + want: `s = `, + }, + { + name: "string in comment", + input: `# comment with "string"`, + want: ` `, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Strip(tt.input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + if got != tt.want { + t.Errorf("Strip() mismatch:\ninput: %q\ngot: %q\nwant: %q", tt.input, got, tt.want) + } + }) + } +} + +func TestStripper_Strip_JavaScript(t *testing.T) { + // JavaScript uses //, /* */, template literals with backticks + s := NewStripper("//", "/*", "*/") + + tests := []struct { + name string + input string + want string + }{ + { + name: "template literal", + input: "const s = `template ${x} string`", + want: "const s = ", + }, + { + name: "template literal multiline", + input: "const s = `line1\nline2`", + want: "const s = \n ", + }, + { + name: "single and double quotes", + input: `const s1 = "double"; const s2 = 'single'`, + want: `const s1 = ; const s2 = `, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Strip(tt.input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + if got != tt.want { + t.Errorf("Strip() mismatch:\ninput: %q\ngot: %q\nwant: %q", tt.input, got, tt.want) + } + }) + } +} + +func TestStripper_Strip_ComplexNesting(t *testing.T) { + s := NewStripper("//", "/*", "*/") + + tests := []struct { + name string + input string + want string + }{ + { + name: "multiple nested structures", + input: `s1 := "string1" // comment` + "\n" + `s2 := "string2" /* comment */`, + want: `s1 := ` + "\n" + `s2 := `, + }, + { + name: "bracket counting scenario", + input: "func f() {\n\ts := \"text { here }\"\n\t// comment { here }\n\tx := 42\n}", + want: "func f() {\n\ts := \n\t \n\tx := 42\n}", + }, + { + name: "real world Go code", + input: `package main + +import "fmt" + +func main() { + // This is a comment + s := "hello {world}" /* block comment */ + fmt.Println(s) +}`, + // Each stripped string/comment is replaced with spaces + // Line 3: import "fmt" -> import (space + 5 spaces for "fmt") + // Line 6: \t// This is a comment -> \t (21 spaces for comment) + // Line 7: \ts := "hello {world}" /* block comment */ + // -> \ts := (15 + 1 + 20 spaces) + want: "package main\n\nimport \n\nfunc main() {\n\t \n\ts := \n\tfmt.Println(s)\n}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Strip(tt.input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + if got != tt.want { + t.Errorf("Strip() mismatch:\ninput:\n%s\n\ngot:\n%s\n\nwant:\n%s", tt.input, got, tt.want) + } + }) + } +} + +func TestStripper_Strip_MaxDepth(t *testing.T) { + // The depth check is incremented per string/comment entry + // not per character, so this test needs adjustment + // The depth tracks nesting level (string inside comment), not sequential items + + // Better approach: create input with legitimate depth tracking + // We enter comment (depth=1), and process many chars. + // Each char checks if depth > maxDepth + // So even at depth=1, if maxDepth=0, it would fail + + s := NewStripper("//", "", "") + s.maxDepth = 0 + + _, err := s.Strip("// comment") + if err != ErrMaxNestingDepthExceeded { + t.Errorf("Strip() with maxDepth=0: got error %v, want %v", err, ErrMaxNestingDepthExceeded) + } +} + +func TestStripper_Strip_NoCommentSyntax(t *testing.T) { + // Stripper with no comment syntax (e.g., for plaintext) + s := NewStripper("", "", "") + + input := `some text with "quotes" and // things` + // Only strings should be stripped, not comment-like patterns (no comment syntax configured) + want := `some text with and // things` + + got, err := s.Strip(input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + if got != want { + t.Errorf("Strip() mismatch:\ninput: %q\ngot: %q\nwant: %q", input, got, want) + } +} + +func TestStripper_Strip_EdgeCases(t *testing.T) { + s := NewStripper("//", "/*", "*/") + + tests := []struct { + name string + input string + want string + }{ + { + name: "backslash at end of string", + input: `s := "text\\"`, + want: `s := `, + }, + { + name: "multiple backslashes", + input: `s := "text\\\\"`, + want: `s := `, + }, + { + name: "escaped backslash before quote", + input: `s := "text\\\""`, + want: `s := `, + }, + { + name: "only comments", + input: "// line1\n// line2\n// line3", + want: " \n \n ", + }, + { + name: "only strings", + input: `"string1" "string2" "string3"`, + want: ` `, + }, + { + name: "consecutive quotes", + input: `s := """`, + want: `s := `, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Strip(tt.input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + if got != tt.want { + t.Errorf("Strip() mismatch:\ninput: %q\ngot: %q\nwant: %q", tt.input, got, tt.want) + } + }) + } +} + +func TestStripper_Strip_PreservesStructure(t *testing.T) { + s := NewStripper("//", "/*", "*/") + + input := `func calculate() { + // Initialize variables + x := 42 + y := "test { bracket }" + /* + * Multi-line comment + * with { brackets } + */ + return x +}` + + got, err := s.Strip(input) + if err != nil { + t.Errorf("Strip() error = %v", err) + return + } + + // Count brackets in stripped version - only code brackets should remain + openBrackets := strings.Count(got, "{") + closeBrackets := strings.Count(got, "}") + + // Should have exactly 2 { and 2 } from the actual code structure + // (function body has 1 pair, the brackets in string/comment are stripped) + if openBrackets != 1 { + t.Errorf("Strip() open brackets in result = %d, want 1 (brackets in string/comment should be stripped)", openBrackets) + } + if closeBrackets != 1 { + t.Errorf("Strip() close brackets in result = %d, want 1 (brackets in string/comment should be stripped)", closeBrackets) + } + + // Verify line count is preserved + inputLines := strings.Count(input, "\n") + gotLines := strings.Count(got, "\n") + if gotLines != inputLines { + t.Errorf("Strip() line count = %d, want %d", gotLines, inputLines) + } +} + +func TestNewStripper(t *testing.T) { + s := NewStripper("//", "/*", "*/") + + if s.singleLineComment != "//" { + t.Errorf("NewStripper() singleLineComment = %q, want //", s.singleLineComment) + } + if s.multiOpen != "/*" { + t.Errorf("NewStripper() multiOpen = %q, want /*", s.multiOpen) + } + if s.multiClose != "*/" { + t.Errorf("NewStripper() multiClose = %q, want */", s.multiClose) + } + if s.maxDepth != defaultMaxNestingDepth { + t.Errorf("NewStripper() maxDepth = %d, want %d", s.maxDepth, defaultMaxNestingDepth) + } +} diff --git a/internal/truncate/truncate_bench_test.go b/internal/truncate/truncate_bench_test.go new file mode 100644 index 00000000..3b0d9700 --- /dev/null +++ b/internal/truncate/truncate_bench_test.go @@ -0,0 +1,342 @@ +package truncate + +import ( + "fmt" + "strings" + "testing" +) + +// BenchmarkStripper benchmarks the string/comment stripper performance. +func BenchmarkStripper(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generateGoCode(100), + }, + { + name: "MediumFile", + content: generateGoCode(1000), + }, + { + name: "LargeFile", + content: generateGoCode(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + stripper := NewStripper("//", "/*", "*/") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := stripper.Strip(tt.content) + if err != nil { + b.Fatalf("Strip failed: %v", err) + } + } + }) + } +} + +// BenchmarkGoParser benchmarks the Go language parser. +func BenchmarkGoParser(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generateGoCode(100), + }, + { + name: "MediumFile", + content: generateGoCode(1000), + }, + { + name: "LargeFile", + content: generateGoCode(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + parser := GoLanguage{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(tt.content) + } + }) + } +} + +// BenchmarkTypeScriptParser benchmarks the TypeScript language parser. +func BenchmarkTypeScriptParser(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generateTypeScriptCode(100), + }, + { + name: "MediumFile", + content: generateTypeScriptCode(1000), + }, + { + name: "LargeFile", + content: generateTypeScriptCode(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + parser := TypeScriptLanguage{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(tt.content) + } + }) + } +} + +// BenchmarkPythonParser benchmarks the Python language parser. +func BenchmarkPythonParser(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generatePythonCode(100), + }, + { + name: "MediumFile", + content: generatePythonCode(1000), + }, + { + name: "LargeFile", + content: generatePythonCode(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + parser := PythonLanguage{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(tt.content) + } + }) + } +} + +// BenchmarkJavaScriptParser benchmarks the JavaScript language parser. +func BenchmarkJavaScriptParser(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generateJavaScriptCode(100), + }, + { + name: "MediumFile", + content: generateJavaScriptCode(1000), + }, + { + name: "LargeFile", + content: generateJavaScriptCode(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + parser := JavaScriptLanguage{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(tt.content) + } + }) + } +} + +// BenchmarkFallbackParser benchmarks the fallback line-based parser. +func BenchmarkFallbackParser(b *testing.B) { + tests := []struct { + name string + content string + }{ + { + name: "SmallFile", + content: generatePlainText(100), + }, + { + name: "MediumFile", + content: generatePlainText(1000), + }, + { + name: "LargeFile", + content: generatePlainText(10000), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + parser := FallbackLanguage{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(tt.content) + } + }) + } +} + +// BenchmarkMemoryUsage measures memory allocation for large file processing. +func BenchmarkMemoryUsage(b *testing.B) { + content := generateGoCode(10000) + parser := GoLanguage{} + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = parser.DetectBlocks(content) + } +} + +// Helper functions to generate test content + +func generateGoCode(lines int) string { + var builder strings.Builder + + // Package and imports + builder.WriteString("package main\n\n") + builder.WriteString("import (\n") + builder.WriteString("\t\"fmt\"\n") + builder.WriteString("\t\"strings\"\n") + builder.WriteString(")\n\n") + + // Generate functions to fill the target line count + funcCount := (lines - 10) / 15 // Approximate 15 lines per function + if funcCount < 1 { + funcCount = 1 + } + + for i := 0; i < funcCount; i++ { + builder.WriteString(fmt.Sprintf("// Function%d performs operation %d.\n", i, i)) + builder.WriteString(fmt.Sprintf("func Function%d(arg string) string {\n", i)) + builder.WriteString("\t// Comment inside function\n") + builder.WriteString("\tif arg == \"\" {\n") + builder.WriteString("\t\treturn \"default\"\n") + builder.WriteString("\t}\n") + builder.WriteString("\tresult := strings.ToUpper(arg)\n") + builder.WriteString("\tfor i := 0; i < 10; i++ {\n") + builder.WriteString("\t\tresult += fmt.Sprintf(\"_%d\", i)\n") + builder.WriteString("\t}\n") + builder.WriteString("\treturn result\n") + builder.WriteString("}\n\n") + } + + return builder.String() +} + +func generateTypeScriptCode(lines int) string { + var builder strings.Builder + + // Imports + builder.WriteString("import { Component } from '@angular/core';\n") + builder.WriteString("import { HttpClient } from '@angular/common/http';\n\n") + + // Generate classes/functions + funcCount := (lines - 10) / 15 + if funcCount < 1 { + funcCount = 1 + } + + for i := 0; i < funcCount; i++ { + builder.WriteString(fmt.Sprintf("// Function%d documentation\n", i)) + builder.WriteString(fmt.Sprintf("function function%d(arg: string): string {\n", i)) + builder.WriteString("\t// Implementation\n") + builder.WriteString("\tif (arg === '') {\n") + builder.WriteString("\t\treturn 'default';\n") + builder.WriteString("\t}\n") + builder.WriteString("\tconst result = arg.toUpperCase();\n") + builder.WriteString("\tfor (let i = 0; i < 10; i++) {\n") + builder.WriteString("\t\tresult += `_${i}`;\n") + builder.WriteString("\t}\n") + builder.WriteString("\treturn result;\n") + builder.WriteString("}\n\n") + } + + return builder.String() +} + +func generatePythonCode(lines int) string { + var builder strings.Builder + + // Imports + builder.WriteString("import os\n") + builder.WriteString("import sys\n") + builder.WriteString("from typing import Optional\n\n") + + // Generate functions + funcCount := (lines - 10) / 12 + if funcCount < 1 { + funcCount = 1 + } + + for i := 0; i < funcCount; i++ { + builder.WriteString(fmt.Sprintf("def function_%d(arg: str) -> str:\n", i)) + builder.WriteString(fmt.Sprintf(" \"\"\"Function %d documentation.\"\"\"\n", i)) + builder.WriteString(" if not arg:\n") + builder.WriteString(" return 'default'\n") + builder.WriteString(" result = arg.upper()\n") + builder.WriteString(" for i in range(10):\n") + builder.WriteString(" result += f'_{i}'\n") + builder.WriteString(" return result\n\n") + } + + return builder.String() +} + +func generateJavaScriptCode(lines int) string { + var builder strings.Builder + + // Imports + builder.WriteString("const fs = require('fs');\n") + builder.WriteString("const path = require('path');\n\n") + + // Generate functions + funcCount := (lines - 10) / 12 + if funcCount < 1 { + funcCount = 1 + } + + for i := 0; i < funcCount; i++ { + builder.WriteString(fmt.Sprintf("// Function%d implementation\n", i)) + builder.WriteString(fmt.Sprintf("function function%d(arg) {\n", i)) + builder.WriteString("\tif (arg === '') {\n") + builder.WriteString("\t\treturn 'default';\n") + builder.WriteString("\t}\n") + builder.WriteString("\tlet result = arg.toUpperCase();\n") + builder.WriteString("\tfor (let i = 0; i < 10; i++) {\n") + builder.WriteString("\t\tresult += `_${i}`;\n") + builder.WriteString("\t}\n") + builder.WriteString("\treturn result;\n") + builder.WriteString("}\n\n") + } + + return builder.String() +} + +func generatePlainText(lines int) string { + var builder strings.Builder + for i := 0; i < lines; i++ { + builder.WriteString(fmt.Sprintf("Line %d of plain text content.\n", i+1)) + } + return builder.String() +} diff --git a/internal/truncate/truncate_fuzz_test.go b/internal/truncate/truncate_fuzz_test.go new file mode 100644 index 00000000..4fbece8e --- /dev/null +++ b/internal/truncate/truncate_fuzz_test.go @@ -0,0 +1,302 @@ +package truncate + +import ( + "strings" + "testing" + "unicode/utf8" +) + +// FuzzStripper tests the stripper against random inputs. +// Ensures no panics occur regardless of input. +func FuzzStripper(f *testing.F) { + // Seed corpus with interesting test cases + f.Add("func main() { /* comment */ }") + f.Add("\"string with { bracket\"") + f.Add("// comment with } bracket\n") + f.Add("/* multi\nline\ncomment */") + f.Add("`raw string with \" quote`") + f.Add("\\\"escaped quote\\\"") + f.Add(strings.Repeat("{", 100)) + f.Add(strings.Repeat("}", 100)) + f.Add("") + f.Add("\n\n\n") + + f.Fuzz(func(t *testing.T, input string) { + // Ensure input is valid UTF-8 to avoid noise + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + // Test Go comment syntax + stripper := NewStripper("//", "/*", "*/") + result, err := stripper.Strip(input) + + // Should either succeed or return known error + if err != nil && err != ErrMaxNestingDepthExceeded { + t.Errorf("unexpected error: %v", err) + } + + // Result should be same length as input (spaces replace stripped content) + if err == nil && len(result) != len(input) { + t.Errorf("length mismatch: got %d, want %d", len(result), len(input)) + } + + // Test Python comment syntax + stripperPy := NewStripper("#", `"""`, `"""`) + resultPy, errPy := stripperPy.Strip(input) + + if errPy != nil && errPy != ErrMaxNestingDepthExceeded { + t.Errorf("unexpected error (Python): %v", errPy) + } + + if errPy == nil && len(resultPy) != len(input) { + t.Errorf("length mismatch (Python): got %d, want %d", len(resultPy), len(input)) + } + }) +} + +// FuzzGoParser tests the Go parser against random inputs. +// Ensures no panics occur regardless of input structure. +func FuzzGoParser(f *testing.F) { + // Seed corpus + f.Add("package main\nfunc main() {}") + f.Add("type MyStruct struct { Field string }") + f.Add("func (r *Receiver) Method() {}") + f.Add("import \"fmt\"") + f.Add("import (\n\"os\"\n\"io\"\n)") + f.Add("// comment\n") + f.Add("") + f.Add("func unclosed() {") + f.Add("func malformed(") + f.Add(strings.Repeat("func f() {}\n", 100)) + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + parser := GoLanguage{} + + // DetectBlocks should never panic + blocks := parser.DetectBlocks(input) + + // Verify blocks are ordered and valid + for i, block := range blocks { + if block.StartLine < 0 { + t.Errorf("block %d has negative StartLine: %d", i, block.StartLine) + } + if block.EndLine < block.StartLine { + t.Errorf("block %d has EndLine < StartLine: %d < %d", i, block.EndLine, block.StartLine) + } + } + + // DetectImportEnd should never panic + lines := strings.Split(input, "\n") + importEnd := parser.DetectImportEnd(lines) + + if importEnd < 0 { + t.Errorf("DetectImportEnd returned negative value: %d", importEnd) + } + if importEnd > len(lines) { + t.Errorf("DetectImportEnd returned value > line count: %d > %d", importEnd, len(lines)) + } + }) +} + +// FuzzTypeScriptParser tests the TypeScript parser against random inputs. +func FuzzTypeScriptParser(f *testing.F) { + // Seed corpus + f.Add("import { Component } from '@angular/core';") + f.Add("export class MyClass {}") + f.Add("function myFunc() {}") + f.Add("const x = () => {}") + f.Add("interface MyInterface { field: string }") + f.Add("") + f.Add("class unclosed {") + f.Add(strings.Repeat("function f() {}\n", 100)) + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + parser := TypeScriptLanguage{} + + // Should never panic + blocks := parser.DetectBlocks(input) + + for i, block := range blocks { + if block.StartLine < 0 { + t.Errorf("block %d has negative StartLine: %d", i, block.StartLine) + } + if block.EndLine < block.StartLine { + t.Errorf("block %d has EndLine < StartLine: %d < %d", i, block.EndLine, block.StartLine) + } + } + + lines := strings.Split(input, "\n") + importEnd := parser.DetectImportEnd(lines) + + if importEnd < 0 || importEnd > len(lines) { + t.Errorf("DetectImportEnd out of bounds: %d (line count: %d)", importEnd, len(lines)) + } + }) +} + +// FuzzPythonParser tests the Python parser against random inputs. +func FuzzPythonParser(f *testing.F) { + // Seed corpus + f.Add("import os") + f.Add("from typing import Optional") + f.Add("def my_func():\n pass") + f.Add("class MyClass:\n def method(self):\n pass") + f.Add("@decorator\ndef func():\n pass") + f.Add("") + f.Add("def unclosed():") + f.Add(strings.Repeat("def f():\n pass\n", 100)) + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + parser := PythonLanguage{} + + // Should never panic + blocks := parser.DetectBlocks(input) + + for i, block := range blocks { + if block.StartLine < 0 { + t.Errorf("block %d has negative StartLine: %d", i, block.StartLine) + } + if block.EndLine < block.StartLine { + t.Errorf("block %d has EndLine < StartLine: %d < %d", i, block.EndLine, block.StartLine) + } + } + + lines := strings.Split(input, "\n") + importEnd := parser.DetectImportEnd(lines) + + if importEnd < 0 || importEnd > len(lines) { + t.Errorf("DetectImportEnd out of bounds: %d (line count: %d)", importEnd, len(lines)) + } + }) +} + +// FuzzJavaScriptParser tests the JavaScript parser against random inputs. +func FuzzJavaScriptParser(f *testing.F) { + // Seed corpus + f.Add("const fs = require('fs');") + f.Add("export function myFunc() {}") + f.Add("class MyClass {}") + f.Add("const arrow = () => {}") + f.Add("") + f.Add("function unclosed() {") + f.Add(strings.Repeat("function f() {}\n", 100)) + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + parser := JavaScriptLanguage{} + + // Should never panic + blocks := parser.DetectBlocks(input) + + for i, block := range blocks { + if block.StartLine < 0 { + t.Errorf("block %d has negative StartLine: %d", i, block.StartLine) + } + if block.EndLine < block.StartLine { + t.Errorf("block %d has EndLine < StartLine: %d < %d", i, block.EndLine, block.StartLine) + } + } + + lines := strings.Split(input, "\n") + importEnd := parser.DetectImportEnd(lines) + + if importEnd < 0 || importEnd > len(lines) { + t.Errorf("DetectImportEnd out of bounds: %d (line count: %d)", importEnd, len(lines)) + } + }) +} + +// FuzzFallbackParser tests the fallback parser against random inputs. +func FuzzFallbackParser(f *testing.F) { + // Seed corpus + f.Add("plain text") + f.Add("") + f.Add("\n\n\n") + f.Add(strings.Repeat("line\n", 1000)) + f.Add("random { } [ ] ( ) content") + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + parser := FallbackLanguage{} + + // Should never panic + blocks := parser.DetectBlocks(input) + + // Fallback returns entire content as single block (or empty for empty input) + if input == "" { + if len(blocks) != 0 { + t.Errorf("FallbackLanguage should return empty blocks for empty input, got %d", len(blocks)) + } + } else { + if len(blocks) != 1 { + t.Errorf("FallbackLanguage should return 1 block for non-empty input, got %d", len(blocks)) + } + if len(blocks) == 1 { + if blocks[0].Type != "block" { + t.Errorf("expected block type 'block', got %q", blocks[0].Type) + } + if blocks[0].StartLine != 0 { + t.Errorf("expected StartLine 0, got %d", blocks[0].StartLine) + } + } + } + + lines := strings.Split(input, "\n") + importEnd := parser.DetectImportEnd(lines) + + // Fallback should always return 0 + if importEnd != 0 { + t.Errorf("FallbackLanguage DetectImportEnd should return 0, got %d", importEnd) + } + }) +} + +// FuzzBracketCounting tests bracket counting with malformed inputs. +func FuzzBracketCounting(f *testing.F) { + // Seed corpus with bracket-heavy inputs + f.Add("{{{}}}") + f.Add("}{") + f.Add(strings.Repeat("{", 500)) + f.Add(strings.Repeat("}", 500)) + f.Add("func f() { if true { for { }}}") + f.Add("{ /* { */ }") + f.Add("{ \"string with {\" }") + + f.Fuzz(func(t *testing.T, input string) { + if !utf8.ValidString(input) { + t.Skip("invalid UTF-8") + } + + // Wrap in a function to test bracket counting + content := "package test\nfunc test() {\n" + input + "\n}\n" + + parser := GoLanguage{} + + // Should not panic even with unbalanced brackets + blocks := parser.DetectBlocks(content) + + // Should detect at least the wrapping function + if len(blocks) == 0 { + t.Error("expected at least one block for wrapped function") + } + }) +} diff --git a/sdk/doc.go b/sdk/doc.go index baf4c8b0..54f21427 100644 --- a/sdk/doc.go +++ b/sdk/doc.go @@ -59,9 +59,39 @@ // - LLM Abstraction: Multi-provider support (Anthropic, OpenAI, Ollama) with cost tracking // - Action System: Built-in actions (file, shell, http) plus custom tool registration // - Agent Loops: ReAct-style agent execution with tool use +// - Code Truncation: Language-aware code truncation for context window optimization // - Event Streaming: Real-time events for UI integration // - Security: Credential handling, MCP server trust model // +// # Code Truncation +// +// TruncateCode intelligently shortens code files while preserving structural integrity. +// Instead of naive character or line truncation, it understands code structure and +// truncates at natural boundaries (between functions, after imports, at class boundaries). +// +// Example: Truncate a large file for LLM context: +// +// result, err := sdk.TruncateCode(sourceCode, sdk.TruncateOptions{ +// MaxLines: 500, // Limit to 500 lines +// Language: "go", // Structure-aware Go truncation +// PreserveTop: true, // Keep imports and package declaration +// PreserveFunc: true, // Don't cut mid-function +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// fmt.Printf("Truncated from %d to %d lines\n", +// result.OriginalLines, result.FinalLines) +// fmt.Printf("Omitted %d functions\n", len(result.OmittedItems)) +// +// Supported languages: Go, TypeScript, Python, JavaScript. Unknown languages fall back +// to line-based truncation. Token-based truncation is also supported using the +// MaxTokens option with a chars/4 estimation heuristic. +// +// The function is thread-safe, deterministic, and includes panic recovery for graceful +// error handling. +// // # Architecture // // The SDK wraps existing pkg/* packages with a stable public API: diff --git a/sdk/truncate.go b/sdk/truncate.go new file mode 100644 index 00000000..65690876 --- /dev/null +++ b/sdk/truncate.go @@ -0,0 +1,403 @@ +package sdk + +import ( + "fmt" + "strings" + + "github.com/tombee/conductor/internal/truncate" +) + +// TruncateCode truncates code content while preserving structural integrity. +// It intelligently shortens code files by understanding language structure +// (imports, complete functions, class boundaries) to maximize useful context +// within size constraints. +// +// The function supports Go, TypeScript, Python, and JavaScript with +// language-aware truncation. For unsupported or unspecified languages, +// it falls back to line-based truncation. +// +// TruncateCode is thread-safe and can be called concurrently. It is +// deterministic - the same inputs always produce the same output. +// +// Security Considerations: +// +// This function is designed to safely process untrusted code input with +// the following protections: +// +// - Input Size Protection: Enforces MaxBytes limit (default 10MB) to prevent +// memory exhaustion attacks. Inputs exceeding this limit are rejected +// before processing with ErrInputTooLarge. +// +// - No External I/O: The function operates purely on in-memory strings with +// no file system access, network calls, or external command execution. +// This isolation prevents path traversal, arbitrary file access, or +// command injection vulnerabilities. +// +// - Deterministic Output: All operations are deterministic with no randomness, +// time-based logic, or external state dependencies. The same input always +// produces the same output, preventing timing attacks or non-deterministic +// behavior that could leak information. +// +// - Panic Recovery: All panics are caught and returned as errors to prevent +// crashes. While panics should not occur in normal operation, this defense- +// in-depth approach ensures graceful degradation even with malformed input. +// +// - Nesting Depth Limits: The internal string/comment stripper enforces a +// maximum nesting depth (1000 levels) to prevent stack overflow attacks +// from deeply nested structures or excessive bracket depth. +// +// - No Information Leakage: Error messages never include code content, line +// numbers, or structural details to prevent information disclosure about +// the input or internal processing state. +// +// - Bounded Complexity: All parsing algorithms have linear or near-linear +// time complexity relative to input size. No unbounded loops, recursion, +// or backtracking that could enable algorithmic complexity attacks. +// +// Example usage: +// +// // Truncate a Go file to 500 lines, preserving imports and complete functions +// result, err := sdk.TruncateCode(sourceCode, sdk.TruncateOptions{ +// MaxLines: 500, +// Language: "go", +// PreserveTop: true, +// PreserveFunc: true, +// }) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("Truncated from %d to %d lines\n", +// result.OriginalLines, result.FinalLines) +// +// Returns an error if: +// - Input exceeds MaxBytes limit (defaults to 10MB) +// - Options contain negative values for MaxLines, MaxTokens, or MaxBytes +// +// The function returns metadata about what was removed including: +// - WasTruncated: Whether any content was removed +// - OmittedItems: Details of code blocks that were removed +// - EstimatedTokens: Token estimate using chars/4 heuristic +// - Indicator: Human-readable truncation comment +func TruncateCode(content string, opts TruncateOptions) (result TruncateResult, err error) { + // Recover from any panics and return as errors for graceful handling + defer func() { + if r := recover(); r != nil { + // This shouldn't happen in normal operation, but we want to + // handle it gracefully rather than crashing the caller + err = fmt.Errorf("truncation panic: %v", r) + result = TruncateResult{} + } + }() + + // Step 1: Input validation + if err := validateOptions(opts); err != nil { + return TruncateResult{}, err + } + + maxBytes := opts.MaxBytes + if maxBytes == 0 { + maxBytes = DefaultMaxBytes + } + + if len(content) > maxBytes { + return TruncateResult{}, NewInputTooLargeError() + } + + // Handle empty content + if content == "" { + return TruncateResult{ + Content: "", + WasTruncated: false, + OriginalLines: 0, + FinalLines: 0, + EstimatedTokens: 0, + OmittedItems: []OmittedItem{}, + Indicator: "", + }, nil + } + + // Step 2: Get language parser from internal registry + lang := truncate.GetLanguage(opts.Language) + if lang == nil { + lang = truncate.FallbackLanguage{} + } + + // Split content into lines for processing + lines := strings.Split(content, "\n") + originalLineCount := len(lines) + + // If no limits specified, return original content + if opts.MaxLines <= 0 && opts.MaxTokens <= 0 { + return TruncateResult{ + Content: content, + WasTruncated: false, + OriginalLines: originalLineCount, + FinalLines: originalLineCount, + EstimatedTokens: estimateTokens(content), + OmittedItems: []OmittedItem{}, + Indicator: "", + }, nil + } + + // Step 3: Detect block boundaries using language parser + blocks := lang.DetectBlocks(content) + + // Step 4: Apply PreserveTop logic + importEndLine := 0 + if opts.PreserveTop { + importEndLine = lang.DetectImportEnd(lines) + } + + // Step 5: Apply PreserveFunc logic and truncation + result = applyTruncation(content, lines, blocks, importEndLine, opts, lang) + return result, nil +} + +// validateOptions checks that all options are valid. +func validateOptions(opts TruncateOptions) error { + if opts.MaxLines < 0 { + return NewInvalidOptionsError("MaxLines cannot be negative") + } + if opts.MaxTokens < 0 { + return NewInvalidOptionsError("MaxTokens cannot be negative") + } + if opts.MaxBytes < 0 { + return NewInvalidOptionsError("MaxBytes cannot be negative") + } + return nil +} + +// applyTruncation performs the actual truncation logic. +func applyTruncation(content string, lines []string, blocks []truncate.Block, importEndLine int, opts TruncateOptions, lang truncate.Language) TruncateResult { + var selectedLines []string + var omittedItems []OmittedItem + wasTruncated := false + + // If PreserveFunc is disabled, do simple line-based truncation + if !opts.PreserveFunc { + truncLine := calculateTruncationPoint(lines, opts) + if truncLine < len(lines) { + wasTruncated = true + selectedLines = lines[:truncLine] + + // Calculate omitted content + omittedLineCount := len(lines) - truncLine + if len(blocks) > 0 { + // Count omitted blocks + for _, block := range blocks { + if block.StartLine >= truncLine { + omittedItems = append(omittedItems, OmittedItem{ + Type: block.Type, + Name: block.Name, + StartLine: block.StartLine + 1, // Convert to 1-indexed + EndLine: block.EndLine + 1, // Convert to 1-indexed + }) + } + } + } + + // Add truncation indicator + indicator := generateIndicator(omittedItems, omittedLineCount, lang) + selectedLines = append(selectedLines, indicator) + } else { + selectedLines = lines + } + } else { + // PreserveFunc is enabled - truncate at function boundaries + selectedLines, omittedItems, wasTruncated = preserveFuncTruncation(lines, blocks, importEndLine, opts, lang) + } + + // Build final content + finalContent := strings.Join(selectedLines, "\n") + + return TruncateResult{ + Content: finalContent, + WasTruncated: wasTruncated, + OriginalLines: len(lines), + FinalLines: len(selectedLines), + EstimatedTokens: estimateTokens(finalContent), + OmittedItems: omittedItems, + Indicator: "", + } +} + +// preserveFuncTruncation applies function-preserving truncation logic. +func preserveFuncTruncation(lines []string, blocks []truncate.Block, importEndLine int, opts TruncateOptions, lang truncate.Language) ([]string, []OmittedItem, bool) { + var selectedLines []string + var omittedItems []OmittedItem + currentLine := 0 + + // Step 1: Add imports/header if PreserveTop is enabled + if importEndLine > 0 { + for i := 0; i < importEndLine && i < len(lines); i++ { + selectedLines = append(selectedLines, lines[i]) + currentLine = i + 1 + } + } + + // Check if we've already exceeded limits with just imports + if exceedsLimits(selectedLines, opts) { + // Truncate even the imports + truncLine := calculateTruncationPoint(selectedLines, opts) + selectedLines = selectedLines[:truncLine] + + // Calculate what was omitted + omittedLineCount := len(lines) - len(selectedLines) + for _, block := range blocks { + omittedItems = append(omittedItems, OmittedItem{ + Type: block.Type, + Name: block.Name, + StartLine: block.StartLine + 1, + EndLine: block.EndLine + 1, + }) + } + + indicator := generateIndicator(omittedItems, omittedLineCount, lang) + selectedLines = append(selectedLines, indicator) + return selectedLines, omittedItems, true + } + + // Step 2: Add complete functions from the start until we hit the limit + for _, block := range blocks { + // Skip blocks that are before our current position (already included in imports) + if block.EndLine < currentLine { + continue + } + + // Try to include this block + blockLines := []string{} + + // Add any gap lines between current position and block start + for i := currentLine; i < block.StartLine && i < len(lines); i++ { + blockLines = append(blockLines, lines[i]) + } + + // Add the block itself + for i := block.StartLine; i <= block.EndLine && i < len(lines); i++ { + blockLines = append(blockLines, lines[i]) + } + + // Check if adding this block would exceed limits + testLines := append(selectedLines, blockLines...) + if exceedsLimits(testLines, opts) { + // This block doesn't fit - omit it and all remaining blocks + for _, b := range blocks { + if b.StartLine >= block.StartLine { + omittedItems = append(omittedItems, OmittedItem{ + Type: b.Type, + Name: b.Name, + StartLine: b.StartLine + 1, + EndLine: b.EndLine + 1, + }) + } + } + + // Add truncation indicator + omittedLineCount := len(lines) - len(selectedLines) + indicator := generateIndicator(omittedItems, omittedLineCount, lang) + selectedLines = append(selectedLines, indicator) + return selectedLines, omittedItems, true + } + + // Block fits - include it + selectedLines = testLines + currentLine = block.EndLine + 1 + } + + // Check if we included everything + wasTruncated := currentLine < len(lines) + if wasTruncated { + // There's content after the last block that we need to account for + omittedLineCount := len(lines) - currentLine + indicator := generateIndicator(omittedItems, omittedLineCount, lang) + selectedLines = append(selectedLines, indicator) + } + + return selectedLines, omittedItems, wasTruncated +} + +// exceedsLimits checks if the given lines exceed the specified limits. +func exceedsLimits(lines []string, opts TruncateOptions) bool { + if opts.MaxLines > 0 && len(lines) > opts.MaxLines { + return true + } + + if opts.MaxTokens > 0 { + content := strings.Join(lines, "\n") + tokens := estimateTokens(content) + if tokens > opts.MaxTokens { + return true + } + } + + return false +} + +// calculateTruncationPoint determines where to truncate for line-based truncation. +// Returns the line index where truncation should occur. +func calculateTruncationPoint(lines []string, opts TruncateOptions) int { + // Start with MaxLines if specified + truncPoint := len(lines) + + if opts.MaxLines > 0 && opts.MaxLines < truncPoint { + truncPoint = opts.MaxLines + } + + // Apply MaxTokens constraint if more restrictive + if opts.MaxTokens > 0 { + // Binary search for the line that fits within token limit + for i := 0; i < truncPoint; i++ { + testContent := strings.Join(lines[:i+1], "\n") + if estimateTokens(testContent) > opts.MaxTokens { + truncPoint = i + break + } + } + } + + // Ensure we don't exceed the content length + if truncPoint > len(lines) { + truncPoint = len(lines) + } + + return truncPoint +} + +// generateIndicator creates a truncation indicator comment. +func generateIndicator(omittedItems []OmittedItem, omittedLines int, lang truncate.Language) string { + single, _, _ := lang.CommentSyntax() + + // Count items by type + itemCounts := make(map[string]int) + for _, item := range omittedItems { + itemCounts[item.Type]++ + } + + // Build description + var parts []string + for itemType, count := range itemCounts { + if count == 1 { + parts = append(parts, fmt.Sprintf("1 %s", itemType)) + } else { + parts = append(parts, fmt.Sprintf("%d %ss", count, itemType)) + } + } + + description := strings.Join(parts, ", ") + if description == "" { + description = fmt.Sprintf("%d items", len(omittedItems)) + } + + // Format indicator based on language + if single != "" { + return fmt.Sprintf("%s ... %s omitted (%d lines)", single, description, omittedLines) + } + + // Fallback for unknown languages + return fmt.Sprintf("... %s omitted (%d lines)", description, omittedLines) +} + +// estimateTokens estimates the token count using the chars/4 heuristic. +func estimateTokens(content string) int { + return (len(content) + 3) / 4 // Round up +} diff --git a/sdk/truncate_errors.go b/sdk/truncate_errors.go new file mode 100644 index 00000000..dc59d4f0 --- /dev/null +++ b/sdk/truncate_errors.go @@ -0,0 +1,52 @@ +package sdk + +import "fmt" + +// TruncateError represents an error during code truncation. +type TruncateError struct { + Code string + Message string +} + +func (e *TruncateError) Error() string { + return fmt.Sprintf("truncate error [%s]: %s", e.Code, e.Message) +} + +// Error codes for truncation failures. +const ( + ErrCodeInputTooLarge = "INPUT_TOO_LARGE" + ErrCodeInvalidOptions = "INVALID_OPTIONS" +) + +// Common truncation errors. +var ( + // ErrInputTooLarge indicates the input exceeds MaxBytes limit. + ErrInputTooLarge = &TruncateError{ + Code: ErrCodeInputTooLarge, + Message: "input exceeds maximum size limit", + } + + // ErrInvalidOptions indicates invalid truncation options were provided. + ErrInvalidOptions = &TruncateError{ + Code: ErrCodeInvalidOptions, + Message: "invalid truncation options", + } +) + +// NewInputTooLargeError creates an error for inputs exceeding size limits. +// Does not include actual size to prevent information leakage. +func NewInputTooLargeError() *TruncateError { + return &TruncateError{ + Code: ErrCodeInputTooLarge, + Message: "input exceeds maximum size limit", + } +} + +// NewInvalidOptionsError creates an error for invalid option values. +// Does not include field names or values to prevent information leakage. +func NewInvalidOptionsError(reason string) *TruncateError { + return &TruncateError{ + Code: ErrCodeInvalidOptions, + Message: fmt.Sprintf("invalid options: %s", reason), + } +} diff --git a/sdk/truncate_errors_test.go b/sdk/truncate_errors_test.go new file mode 100644 index 00000000..39c0d38f --- /dev/null +++ b/sdk/truncate_errors_test.go @@ -0,0 +1,184 @@ +package sdk + +import ( + "strings" + "testing" +) + +func TestTruncateError_Error(t *testing.T) { + tests := []struct { + name string + err *TruncateError + wantCode string + wantMsg string + }{ + { + name: "input too large error", + err: &TruncateError{ + Code: ErrCodeInputTooLarge, + Message: "input exceeds maximum size limit", + }, + wantCode: "INPUT_TOO_LARGE", + wantMsg: "input exceeds maximum size limit", + }, + { + name: "invalid options error", + err: &TruncateError{ + Code: ErrCodeInvalidOptions, + Message: "invalid truncation options", + }, + wantCode: "INVALID_OPTIONS", + wantMsg: "invalid truncation options", + }, + { + name: "custom error", + err: &TruncateError{ + Code: "CUSTOM_ERROR", + Message: "custom message", + }, + wantCode: "CUSTOM_ERROR", + wantMsg: "custom message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.Error() + if !strings.Contains(got, tt.wantCode) { + t.Errorf("TruncateError.Error() = %v, want to contain code %v", got, tt.wantCode) + } + if !strings.Contains(got, tt.wantMsg) { + t.Errorf("TruncateError.Error() = %v, want to contain message %v", got, tt.wantMsg) + } + }) + } +} + +func TestNewInputTooLargeError(t *testing.T) { + err := NewInputTooLargeError() + + if err.Code != ErrCodeInputTooLarge { + t.Errorf("NewInputTooLargeError() code = %v, want %v", err.Code, ErrCodeInputTooLarge) + } + + errMsg := err.Error() + if !strings.Contains(errMsg, "input exceeds maximum size limit") { + t.Errorf("NewInputTooLargeError() message = %v, want to contain 'input exceeds maximum size limit'", errMsg) + } + + // Verify no sensitive information in error message + if strings.Contains(errMsg, "bytes") || strings.Contains(errMsg, "MB") { + t.Errorf("NewInputTooLargeError() message contains size information: %v", errMsg) + } +} + +func TestNewInvalidOptionsError(t *testing.T) { + tests := []struct { + name string + reason string + wantInMsg string + wantNotIn []string + }{ + { + name: "negative MaxLines", + reason: "MaxLines must be non-negative", + wantInMsg: "MaxLines must be non-negative", + }, + { + name: "negative MaxTokens", + reason: "MaxTokens must be non-negative", + wantInMsg: "MaxTokens must be non-negative", + }, + { + name: "negative MaxBytes", + reason: "MaxBytes must be non-negative", + wantInMsg: "MaxBytes must be non-negative", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := NewInvalidOptionsError(tt.reason) + + if err.Code != ErrCodeInvalidOptions { + t.Errorf("NewInvalidOptionsError() code = %v, want %v", err.Code, ErrCodeInvalidOptions) + } + + errMsg := err.Error() + if !strings.Contains(errMsg, tt.wantInMsg) { + t.Errorf("NewInvalidOptionsError() message = %v, want to contain %v", errMsg, tt.wantInMsg) + } + + for _, notWant := range tt.wantNotIn { + if strings.Contains(errMsg, notWant) { + t.Errorf("NewInvalidOptionsError() message = %v, should not contain %v", errMsg, notWant) + } + } + }) + } +} + +func TestErrorConstants(t *testing.T) { + // Verify ErrInputTooLarge constant + if ErrInputTooLarge.Code != ErrCodeInputTooLarge { + t.Errorf("ErrInputTooLarge.Code = %v, want %v", ErrInputTooLarge.Code, ErrCodeInputTooLarge) + } + if ErrInputTooLarge.Message == "" { + t.Error("ErrInputTooLarge.Message should not be empty") + } + + // Verify ErrInvalidOptions constant + if ErrInvalidOptions.Code != ErrCodeInvalidOptions { + t.Errorf("ErrInvalidOptions.Code = %v, want %v", ErrInvalidOptions.Code, ErrCodeInvalidOptions) + } + if ErrInvalidOptions.Message == "" { + t.Error("ErrInvalidOptions.Message should not be empty") + } +} + +func TestErrorCodesAreUnique(t *testing.T) { + codes := map[string]bool{ + ErrCodeInputTooLarge: true, + ErrCodeInvalidOptions: true, + } + + if len(codes) != 2 { + t.Errorf("Error codes are not unique, got %d unique codes, want 2", len(codes)) + } +} + +func TestErrorMessageSafety(t *testing.T) { + // Verify that error messages don't leak sensitive information + tests := []struct { + name string + err *TruncateError + forbiddenStrs []string + }{ + { + name: "input too large - no size leak", + err: NewInputTooLargeError(), + forbiddenStrs: []string{ + "10MB", "10485760", "bytes", "size:", "limit:", + }, + }, + { + name: "invalid options with reason", + err: NewInvalidOptionsError("MaxLines must be non-negative"), + // Reason is allowed to be in the message, but not actual values + forbiddenStrs: []string{ + "-1", "-100", "value:", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errMsg := tt.err.Error() + for _, forbidden := range tt.forbiddenStrs { + if strings.Contains(errMsg, forbidden) { + t.Errorf("Error message contains forbidden string %q: %v", forbidden, errMsg) + } + } + }) + } +} diff --git a/sdk/truncate_test.go b/sdk/truncate_test.go new file mode 100644 index 00000000..a8ed7714 --- /dev/null +++ b/sdk/truncate_test.go @@ -0,0 +1,630 @@ +package sdk_test + +import ( + "fmt" + "strings" + "sync" + "testing" + + "github.com/tombee/conductor/sdk" +) + +// TestTruncateCode_US1_GoFilePreservation tests US1: Code Review Agent Truncation +// A 2000-line Go file should truncate to 500 lines without cutting mid-function, +// with imports preserved and a truncation indicator showing what was removed. +func TestTruncateCode_US1_GoFilePreservation(t *testing.T) { + // Generate a large Go file with imports and many functions + var sb strings.Builder + sb.WriteString("package main\n\n") + sb.WriteString("import (\n") + sb.WriteString("\t\"fmt\"\n") + sb.WriteString("\t\"log\"\n") + sb.WriteString(")\n\n") + + // Add 400 functions (5 lines each = 2000 lines) + for i := 1; i <= 400; i++ { + sb.WriteString(fmt.Sprintf("func function%d() {\n", i)) + sb.WriteString(fmt.Sprintf("\tfmt.Println(\"function %d\")\n", i)) + sb.WriteString("}\n\n") + } + + content := sb.String() + originalLines := strings.Count(content, "\n") + 1 + + // Truncate to 500 lines with preservation options + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 500, + Language: "go", + PreserveTop: true, + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // AC: Can truncate a 2000-line Go file to 500 lines without cutting mid-function + if !result.WasTruncated { + t.Error("Expected content to be truncated") + } + + if result.FinalLines > 500 { + t.Errorf("Expected at most 500 lines, got %d", result.FinalLines) + } + + if result.OriginalLines != originalLines { + t.Errorf("Expected original line count %d, got %d", originalLines, result.OriginalLines) + } + + // AC: Import section is preserved when PreserveTop is enabled + if !strings.Contains(result.Content, "import (") { + t.Error("Expected imports to be preserved") + } + + // AC: Truncation indicator shows what was removed + if result.Indicator != "" { + t.Errorf("Indicator should be empty in result struct, but got: %s", result.Indicator) + } + + if !strings.Contains(result.Content, "//") && !strings.Contains(result.Content, "omitted") { + t.Error("Expected truncation indicator in content") + } + + // Verify no mid-function cuts by checking that all function declarations have closing braces + lines := strings.Split(result.Content, "\n") + openBraces := 0 + sawFunc := false + for _, line := range lines { + if strings.Contains(line, "func ") { + sawFunc = true + } + openBraces += strings.Count(line, "{") + openBraces -= strings.Count(line, "}") + } + + if sawFunc && openBraces != 0 { + t.Errorf("Expected balanced braces (no mid-function cut), got imbalance: %d", openBraces) + } + + // AC: Output is parseable - basic check for valid Go structure + if !strings.HasPrefix(result.Content, "package main") { + t.Error("Expected output to start with package declaration") + } +} + +// TestTruncateCode_US2_MultiLanguageSupport tests US2: Multi-Language Support +// The function should support TypeScript, Go, Python, and JavaScript with +// language-aware truncation, and fall back gracefully for unknown languages. +func TestTruncateCode_US2_MultiLanguageSupport(t *testing.T) { + tests := []struct { + name string + language string + content string + wantType string // Expected type in omitted items + }{ + { + name: "TypeScript with interfaces and classes", + language: "typescript", + content: `interface User { + name: string; + age: number; +} + +class UserService { + getUser(): User { + return { name: "test", age: 30 }; + } +} + +function processUser() { + console.log("processing"); +}`, + wantType: "function", + }, + { + name: "Go with functions and methods", + language: "go", + content: `package main + +func helper() { + fmt.Println("helper") +} + +type Service struct{} + +func (s Service) Method() { + fmt.Println("method") +}`, + wantType: "function", + }, + { + name: "Python with classes and decorators", + language: "python", + content: `import sys + +@decorator +def decorated_func(): + print("decorated") + +class MyClass: + def method(self): + print("method")`, + wantType: "function", + }, + { + name: "JavaScript with functions and classes", + language: "javascript", + content: `const helper = () => { + console.log("helper"); +}; + +class Service { + method() { + console.log("method"); + } +} + +function process() { + console.log("process"); +}`, + wantType: "function", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := sdk.TruncateCode(tt.content, sdk.TruncateOptions{ + MaxLines: 5, + Language: tt.language, + PreserveTop: true, + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Should truncate and preserve structure + if !result.WasTruncated { + t.Error("Expected content to be truncated") + } + + // Allow 6 lines (5 content + 1 indicator) + if result.FinalLines > 6 { + t.Errorf("Expected at most 6 lines (5 + indicator), got %d", result.FinalLines) + } + }) + } + + // AC: Falls back to line-based truncation for unsupported languages + t.Run("Unsupported language fallback", func(t *testing.T) { + content := strings.Repeat("line\n", 100) + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 50, + Language: "ruby", // Unsupported + }) + + if err != nil { + t.Fatalf("TruncateCode should not error for unsupported language: %v", err) + } + + if !result.WasTruncated { + t.Error("Expected content to be truncated") + } + + if result.FinalLines > 51 { // 50 lines + indicator + t.Errorf("Expected at most 51 lines (50 + indicator), got %d", result.FinalLines) + } + }) + + // AC: Language matching is case-insensitive + t.Run("Case-insensitive language", func(t *testing.T) { + content := `package main +func test() { + fmt.Println("test") +}` + + cases := []string{"GO", "Go", "go", "gO"} + for _, lang := range cases { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 2, + Language: lang, + }) + + if err != nil { + t.Errorf("TruncateCode failed for language %q: %v", lang, err) + } + + if !result.WasTruncated { + t.Errorf("Expected truncation for language %q", lang) + } + } + }) +} + +// TestTruncateCode_US3_TokenBasedTruncation tests US3: Token-Based Truncation +// The function should support MaxTokens with reasonable accuracy (within 15%). +func TestTruncateCode_US3_TokenBasedTruncation(t *testing.T) { + // Generate content with known character count + var sb strings.Builder + sb.WriteString("package main\n\n") + for i := 0; i < 100; i++ { + // Each function is approximately 60 characters + sb.WriteString(fmt.Sprintf("func f%d() {\n\tfmt.Println(%d)\n}\n\n", i, i)) + } + content := sb.String() + + // Request specific token limit (chars/4 heuristic) + maxTokens := 500 + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxTokens: maxTokens, + Language: "go", + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // AC: Can specify MaxTokens instead of MaxLines + if !result.WasTruncated { + t.Error("Expected content to be truncated") + } + + // AC: Token estimation is reasonably accurate (within 15%) + if result.EstimatedTokens > maxTokens { + // Allow small overage for the indicator line (15% tolerance) + maxAllowedWithOverage := maxTokens + (maxTokens * 15 / 100) + if result.EstimatedTokens > maxAllowedWithOverage { + t.Errorf("Token estimate %d exceeds MaxTokens %d by more than 15%%", + result.EstimatedTokens, maxTokens) + } + } + + // Test: When both limits specified, more restrictive one applies + t.Run("Both limits specified", func(t *testing.T) { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 100, // More permissive + MaxTokens: 200, // More restrictive + Language: "go", + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Should respect the more restrictive limit (MaxTokens) + // Allow 15% overage for indicator line + maxAllowedTokens := 200 + (200 * 15 / 100) + if result.EstimatedTokens > maxAllowedTokens { + t.Errorf("Expected tokens to respect MaxTokens=200 with 15%% overage, got %d", result.EstimatedTokens) + } + }) +} + +// TestTruncateCode_US4_FallbackForUnknownLanguages tests US4: Fallback behavior +// Unknown languages should fall back to line-based truncation without errors. +func TestTruncateCode_US4_FallbackForUnknownLanguages(t *testing.T) { + content := strings.Repeat("This is a line of text\n", 100) + + // AC: Unknown languages fall back to line-based truncation + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 50, + Language: "unknown-language", + }) + + if err != nil { + t.Errorf("TruncateCode should not error for unknown language: %v", err) + } + + if !result.WasTruncated { + t.Error("Expected content to be truncated") + } + + // AC: Plaintext files truncate at line boundaries (not mid-line) + lines := strings.Split(result.Content, "\n") + for i, line := range lines[:len(lines)-1] { // Skip last line (indicator) + if !strings.HasPrefix(line, "This is a line") && line != "" { + t.Errorf("Line %d was cut mid-line: %q", i, line) + } + } + + // AC: No errors thrown for unsupported language identifiers + unsupportedLangs := []string{"", "rust", "java", "c++", "ruby", "php"} + for _, lang := range unsupportedLangs { + _, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 10, + Language: lang, + }) + + if err != nil { + t.Errorf("TruncateCode should not error for language %q: %v", lang, err) + } + } +} + +// TestTruncateCode_ConcurrentCalls tests NFR5: Thread Safety +// The function should be safe for concurrent use. +func TestTruncateCode_ConcurrentCalls(t *testing.T) { + content := `package main + +import "fmt" + +func main() { + fmt.Println("hello") +} + +func helper() { + fmt.Println("helper") +} +` + + const numGoroutines = 50 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 5, + Language: "go", + PreserveTop: true, + PreserveFunc: true, + }) + + if err != nil { + errors <- fmt.Errorf("goroutine %d: %w", id, err) + return + } + + // Verify deterministic output + if !result.WasTruncated { + errors <- fmt.Errorf("goroutine %d: expected truncation", id) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors from goroutines + for err := range errors { + t.Error(err) + } +} + +// TestTruncateCode_MalformedInput tests error handling for invalid inputs. +func TestTruncateCode_MalformedInput(t *testing.T) { + tests := []struct { + name string + content string + opts sdk.TruncateOptions + wantErr bool + errContains string + }{ + { + name: "Negative MaxLines", + content: "test", + opts: sdk.TruncateOptions{MaxLines: -1}, + wantErr: true, + errContains: "MaxLines", + }, + { + name: "Negative MaxTokens", + content: "test", + opts: sdk.TruncateOptions{MaxTokens: -1}, + wantErr: true, + errContains: "MaxTokens", + }, + { + name: "Negative MaxBytes", + content: "test", + opts: sdk.TruncateOptions{MaxBytes: -1}, + wantErr: true, + errContains: "MaxBytes", + }, + { + name: "Input exceeds MaxBytes", + content: strings.Repeat("x", 1000), + opts: sdk.TruncateOptions{MaxBytes: 100}, + wantErr: true, + errContains: "INPUT_TOO_LARGE", + }, + { + name: "Empty content", + content: "", + opts: sdk.TruncateOptions{MaxLines: 10}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := sdk.TruncateCode(tt.content, tt.opts) + + if tt.wantErr { + if err == nil { + t.Error("Expected error but got none") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("Expected error containing %q, got: %v", tt.errContains, err) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // For empty content, verify empty result + if tt.content == "" { + if result.WasTruncated { + t.Error("Empty content should not be marked as truncated") + } + if result.Content != "" { + t.Errorf("Expected empty content, got: %q", result.Content) + } + } + } + }) + } +} + +// TestTruncateCode_Deterministic verifies that the function is deterministic. +func TestTruncateCode_Deterministic(t *testing.T) { + content := `package main + +func a() { fmt.Println("a") } +func b() { fmt.Println("b") } +func c() { fmt.Println("c") } +` + + opts := sdk.TruncateOptions{ + MaxLines: 3, + Language: "go", + PreserveFunc: true, + } + + // Run multiple times + var results []sdk.TruncateResult + for i := 0; i < 5; i++ { + result, err := sdk.TruncateCode(content, opts) + if err != nil { + t.Fatalf("Run %d failed: %v", i, err) + } + results = append(results, result) + } + + // All results should be identical + for i := 1; i < len(results); i++ { + if results[i].Content != results[0].Content { + t.Errorf("Run %d produced different content than run 0", i) + } + if results[i].FinalLines != results[0].FinalLines { + t.Errorf("Run %d: FinalLines=%d, expected %d", i, results[i].FinalLines, results[0].FinalLines) + } + } +} + +// TestTruncateCode_PreserveOptions tests the PreserveTop and PreserveFunc options. +func TestTruncateCode_PreserveOptions(t *testing.T) { + content := `package main + +import ( + "fmt" + "log" +) + +func first() { + fmt.Println("first") +} + +func second() { + log.Println("second") +} +` + + t.Run("PreserveTop only", func(t *testing.T) { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 8, + Language: "go", + PreserveTop: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Should include imports + if !strings.Contains(result.Content, "import") { + t.Error("Expected imports to be preserved") + } + }) + + t.Run("PreserveFunc only", func(t *testing.T) { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 8, + Language: "go", + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Should include at least one complete function + if !strings.Contains(result.Content, "func") { + t.Error("Expected at least one function") + } + }) + + t.Run("Both preserve options", func(t *testing.T) { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 15, + Language: "go", + PreserveTop: true, + PreserveFunc: true, + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Should include imports and complete functions + if !strings.Contains(result.Content, "import") { + t.Error("Expected imports to be preserved") + } + + if !strings.Contains(result.Content, "func first") { + t.Error("Expected first function to be preserved") + } + }) + + t.Run("Neither preserve option", func(t *testing.T) { + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 5, + Language: "go", + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + // Just simple line-based truncation + if result.FinalLines > 6 { // 5 + indicator + t.Errorf("Expected at most 6 lines, got %d", result.FinalLines) + } + }) +} + +// TestTruncateCode_NoTruncationNeeded tests that content within limits is returned unchanged. +func TestTruncateCode_NoTruncationNeeded(t *testing.T) { + content := `package main + +func small() { + fmt.Println("small") +} +` + + result, err := sdk.TruncateCode(content, sdk.TruncateOptions{ + MaxLines: 100, + Language: "go", + }) + + if err != nil { + t.Fatalf("TruncateCode failed: %v", err) + } + + if result.WasTruncated { + t.Error("Content should not be truncated when within limits") + } + + if result.Content != content { + t.Error("Content should be unchanged when within limits") + } + + if len(result.OmittedItems) > 0 { + t.Error("No items should be omitted when within limits") + } +} diff --git a/sdk/truncate_types.go b/sdk/truncate_types.go new file mode 100644 index 00000000..66c11768 --- /dev/null +++ b/sdk/truncate_types.go @@ -0,0 +1,73 @@ +package sdk + +// TruncateOptions configures code truncation behavior. +type TruncateOptions struct { + // MaxLines is the maximum number of lines in the output. + // If 0, no line limit is applied. + MaxLines int + + // MaxTokens is the maximum estimated token count. + // Uses chars/4 heuristic. If 0, no token limit is applied. + MaxTokens int + + // MaxBytes is the maximum input size in bytes. + // If 0, defaults to 10MB. Inputs exceeding this are rejected. + MaxBytes int + + // Language specifies the programming language for structure-aware truncation. + // Supported: "go", "typescript", "python", "javascript". + // Required for structure-aware truncation; empty string uses line-based fallback. + Language string + + // PreserveTop keeps import statements and file headers when true. + PreserveTop bool + + // PreserveFunc avoids cutting in the middle of functions when true. + // Functions are kept from the beginning; omitted from the end. + PreserveFunc bool +} + +// TruncateResult contains the truncation output and metadata. +type TruncateResult struct { + // Content is the truncated code. + Content string + + // WasTruncated indicates whether any content was removed. + WasTruncated bool + + // OriginalLines is the line count of the input. + OriginalLines int + + // FinalLines is the line count of the output. + FinalLines int + + // EstimatedTokens is the estimated token count of the output (chars/4). + EstimatedTokens int + + // OmittedItems lists the code blocks that were removed. + OmittedItems []OmittedItem + + // Indicator is the truncation comment added to the output. + Indicator string +} + +// OmittedItem describes a code block that was removed during truncation. +type OmittedItem struct { + // Type is the kind of block: "function", "method", "class", "interface", "type", "const", "var", "block". + Type string + + // Name is the identifier of the omitted block. + Name string + + // StartLine is the original line number where the block started. + StartLine int + + // EndLine is the original line number where the block ended. + EndLine int +} + +// DefaultMaxBytes is the default maximum input size (10MB). +const DefaultMaxBytes = 10 * 1024 * 1024 + +// DefaultMaxNestingDepth is the maximum bracket nesting depth for parsing. +const DefaultMaxNestingDepth = 1000