diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 998a912..f6bc568 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,9 +19,9 @@ jobs: # Steps represent a sequence of tasks that will be executed as part of the job steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Mount bazel caches - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: | "~/.cache/bazel" @@ -29,7 +29,4 @@ jobs: key: bazel-cache-${{ hashFiles('**/BUILD.bazel', '**/*.bzl', 'WORKSPACE', '**/*.js') }} restore-keys: bazel-cache- - name: bazel test //... - env: - # Bazelisk will download bazel to here - XDG_CACHE_HOME: ~/.cache/bazel-repo run: bazel test //... diff --git a/python/file_parser.go b/python/file_parser.go index a2b22c2..04bfed4 100644 --- a/python/file_parser.go +++ b/python/file_parser.go @@ -20,17 +20,48 @@ import ( "os" "path/filepath" "strings" + "sync" sitter "github.com/smacker/go-tree-sitter" "github.com/smacker/go-tree-sitter/python" ) +// --- Pools & Pre-compiled Query --- + +var parserPool = sync.Pool{ + New: func() any { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + return parser + }, +} + +var cursorPool = sync.Pool{ + New: func() any { return sitter.NewQueryCursor() }, +} + +var importQuery *sitter.Query + +func init() { + var err error + queryString := ` + (import_statement) @import + (import_from_statement) @import_from + (comment) @comment + (if_statement) @if_stmt + ` + importQuery, err = sitter.NewQuery([]byte(queryString), python.GetLanguage()) + if err != nil { + panic(fmt.Sprintf("failed to compile import query: %v", err)) + } +} + +// --- Constants --- + const ( sitterNodeTypeString = "string" - sitterNodeTypeComment = "comment" sitterNodeTypeIdentifier = "identifier" sitterNodeTypeDottedName = "dotted_name" - sitterNodeTypeIfStatement = "if_statement" sitterNodeTypeAliasedImport = 
"aliased_import" sitterNodeTypeWildcardImport = "wildcard_import" sitterNodeTypeImportStatement = "import_statement" @@ -38,6 +69,8 @@ const ( sitterNodeTypeImportFromStatement = "import_from_statement" ) +// --- Types --- + type ParserOutput struct { FileName string Modules []module @@ -55,9 +88,11 @@ func NewFileParser() *FileParser { return &FileParser{} } -func ParseCode(code []byte) (*sitter.Node, error) { - parser := sitter.NewParser() - parser.SetLanguage(python.GetLanguage()) +// --- Parsing Logic --- + +func parseCode(code []byte) (*sitter.Node, error) { + parser := parserPool.Get().(*sitter.Parser) + defer parserPool.Put(parser) tree, err := parser.ParseCtx(context.Background(), nil, code) if err != nil { @@ -67,14 +102,14 @@ func ParseCode(code []byte) (*sitter.Node, error) { return tree.RootNode(), nil } -func (p *FileParser) parseMain(ctx context.Context, node *sitter.Node) bool { +func (p *FileParser) parseMain(node *sitter.Node) bool { for i := 0; i < int(node.ChildCount()); i++ { - if err := ctx.Err(); err != nil { - return false - } child := node.Child(i) - if child.Type() == sitterNodeTypeIfStatement && - child.Child(1).Type() == sitterNodeTypeComparisonOperator && child.Child(1).Child(1).Type() == "==" { + if child.Type() == "if_statement" && + child.ChildCount() > 1 && + child.Child(1).Type() == sitterNodeTypeComparisonOperator && + child.Child(1).ChildCount() > 2 && + child.Child(1).Child(1).Type() == "==" { statement := child.Child(1) a, b := statement.Child(0), statement.Child(2) // convert "'__main__' == __name__" to "__name__ == '__main__'" @@ -82,10 +117,6 @@ func (p *FileParser) parseMain(ctx context.Context, node *sitter.Node) bool { a, b = b, a } if a.Type() == sitterNodeTypeIdentifier && a.Content(p.code) == "__name__" && - // at github.com/smacker/go-tree-sitter@latest (after v0.0.0-20240422154435-0628b34cbf9c we used) - // "__main__" is the second child of b. But now, it isn't. 
- // we cannot use the latest go-tree-sitter because of the top level reference in scanner.c. - // https://github.com/smacker/go-tree-sitter/blob/04d6b33fe138a98075210f5b770482ded024dc0f/python/scanner.c#L1 b.Type() == sitterNodeTypeString && string(p.code[b.StartByte()+1:b.EndByte()-1]) == "__main__" { return true } @@ -94,6 +125,24 @@ func (p *FileParser) parseMain(ctx context.Context, node *sitter.Node) bool { return false } +func (p *FileParser) isTypeCheckingBlock(node *sitter.Node) bool { + if node.Type() != "if_statement" || node.ChildCount() < 2 { + return false + } + condition := node.Child(1) + if condition.Type() == sitterNodeTypeIdentifier && condition.Content(p.code) == "TYPE_CHECKING" { + return true + } + if condition.Type() == "attribute" && condition.ChildCount() >= 3 { + obj, attr := condition.Child(0), condition.Child(2) + if obj.Type() == sitterNodeTypeIdentifier && obj.Content(p.code) == "typing" && + attr.Type() == sitterNodeTypeIdentifier && attr.Content(p.code) == "TYPE_CHECKING" { + return true + } + } + return false +} + func parseImportStatement(node *sitter.Node, code []byte) (module, bool) { switch node.Type() { case sitterNodeTypeDottedName: @@ -112,7 +161,7 @@ func parseImportStatement(node *sitter.Node, code []byte) (module, bool) { return module{}, false } -func (p *FileParser) parseImportStatements(node *sitter.Node) bool { +func (p *FileParser) parseImportStatements(node *sitter.Node) { if node.Type() == sitterNodeTypeImportStatement { for j := 1; j < int(node.ChildCount()); j++ { m, ok := parseImportStatement(node.Child(j), p.code) @@ -128,7 +177,7 @@ func (p *FileParser) parseImportStatements(node *sitter.Node) bool { } else if node.Type() == sitterNodeTypeImportFromStatement { from := node.Child(1).Content(p.code) if strings.HasPrefix(from, ".") { - return true + return } for j := 3; j < int(node.ChildCount()); j++ { m, ok := parseImportStatement(node.Child(j), p.code) @@ -140,18 +189,7 @@ func (p *FileParser) 
parseImportStatements(node *sitter.Node) bool { m.Name = fmt.Sprintf("%s.%s", from, m.Name) p.output.Modules = append(p.output.Modules, m) } - } else { - return false } - return true -} - -func (p *FileParser) parseComments(node *sitter.Node) bool { - if node.Type() == sitterNodeTypeComment { - p.output.Comments = append(p.output.Comments, comment(node.Content(p.code))) - return true - } - return false } func (p *FileParser) SetCodeAndFile(code []byte, relPackagePath, filename string) { @@ -160,34 +198,68 @@ func (p *FileParser) SetCodeAndFile(code []byte, relPackagePath, filename string p.output.FileName = filename } -func (p *FileParser) parse(ctx context.Context, node *sitter.Node) { - if node == nil { - return +// Parse uses pre-compiled tree-sitter queries to extract imports and comments +// from the parsed AST, replacing the previous recursive traversal approach. +func (p *FileParser) Parse(ctx context.Context) (*ParserOutput, error) { + rootNode, err := parseCode(p.code) + if err != nil { + return nil, err } - for i := 0; i < int(node.ChildCount()); i++ { + + p.output.HasMain = p.parseMain(rootNode) + + cursor := cursorPool.Get().(*sitter.QueryCursor) + defer cursorPool.Put(cursor) + + cursor.Exec(importQuery, rootNode) + + seenImports := make(map[uint32]bool) + + for { if err := ctx.Err(); err != nil { - return - } - child := node.Child(i) - if p.parseImportStatements(child) { - continue + return nil, err } - if p.parseComments(child) { - continue + match, ok := cursor.NextMatch() + if !ok { + break } - p.parse(ctx, child) - } -} -func (p *FileParser) Parse(ctx context.Context) (*ParserOutput, error) { - rootNode, err := ParseCode(p.code) - if err != nil { - return nil, err - } + for _, capture := range match.Captures { + captureName := importQuery.CaptureNameForId(capture.Index) + node := capture.Node + + switch captureName { + case "import", "import_from": + if seenImports[node.StartByte()] { + continue + } + seenImports[node.StartByte()] = true + 
p.parseImportStatements(node) + + case "comment": + p.output.Comments = append(p.output.Comments, comment(node.Content(p.code))) - p.output.HasMain = p.parseMain(ctx, rootNode) + case "if_stmt": + if p.isTypeCheckingBlock(node) { + for j := 0; j < int(node.ChildCount()); j++ { + subChild := node.Child(j) + if subChild.Type() == "block" { + for k := 0; k < int(subChild.ChildCount()); k++ { + stmt := subChild.Child(k) + if stmt.Type() == sitterNodeTypeImportStatement || stmt.Type() == sitterNodeTypeImportFromStatement { + if !seenImports[stmt.StartByte()] { + seenImports[stmt.StartByte()] = true + p.parseImportStatements(stmt) + } + } + } + } + } + } + } + } + } - p.parse(ctx, rootNode) return &p.output, nil } diff --git a/python/generate.go b/python/generate.go index cfcd738..2a67dfa 100644 --- a/python/generate.go +++ b/python/generate.go @@ -21,7 +21,6 @@ import ( "log" "os" "path/filepath" - "regexp" "sort" "strings" @@ -207,13 +206,26 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes parser := newPython3Parser(args.Config.RepoRoot, args.Rel, cfg.IgnoresDependency) + // Parse all Python files once upfront into a lookup table (LUT). + // This avoids re-parsing the same file multiple times when it appears + // in multiple targets (e.g. __init__.py in per-file mode). 
+ allPyFiles := treeset.NewWith(godsutils.StringComparator) + pyLibraryFilenames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) }) + pyTestFilenames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) }) + djangoTestFilesNames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) }) + + parseLUT, err := parser.parseAllToLUT(allPyFiles) + if err != nil { + log.Fatalf("ERROR: %v\n", err) + } + var result language.GenerateResult result.Gen = make([]*rule.Rule, 0) collisionErrors := singlylinkedlist.New() appendPyLibrary := func(srcs *treeset.Set, pyLibraryTargetName string) { - allDeps, mainModules, annotations, err := parser.parse(srcs) + allDeps, mainModules, annotations, err := parser.parseFromLUT(srcs, parseLUT) if err != nil { log.Fatalf("ERROR: %v\n", err) } @@ -280,7 +292,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes var pyTestTargets []*targetBuilder newPyTestTargetBuilder := func(srcs *treeset.Set, pyTestTargetName string) *targetBuilder { - deps, _, annotations, err := parser.parse(srcs) + deps, _, annotations, err := parser.parseFromLUT(srcs, parseLUT) if err != nil { log.Fatalf("ERROR: %v\n", err) } @@ -302,7 +314,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes generateImportsAttribute() } newDjangoTestBuilder := func(srcs *treeset.Set, djangoTestTargetName string) *targetBuilder { - deps, _, annotations, err := parser.parse(srcs) + deps, _, annotations, err := parser.parseFromLUT(srcs, parseLUT) if err != nil { log.Fatalf("ERROR: %v\n", err) } @@ -442,10 +454,9 @@ func ensureNoCollision(file *rule.File, targetName, kind string) error { return nil } -// isDjangoTestFile returns whether the given path contains the following -// regex regexp.MustCompile(`from django\.test import.*TestCase|pytest\.mark\.django_db|gazelle: django_test`) +// isDjangoTestFile returns whether the given path contains django test markers: +// "from django.test import...TestCase", 
"pytest.mark.django_db", or "gazelle: django_test". func isDjangoTestFile(path string) bool { - re := regexp.MustCompile(`from django\.test import.*TestCase|pytest\.mark\.django_db|gazelle: django_test`) file, err := os.Open(path) if err != nil { log.Fatalf("ERROR: %v\n", err) @@ -454,7 +465,14 @@ func isDjangoTestFile(path string) bool { defer file.Close() scanner := bufio.NewScanner(file) for scanner.Scan() { - if re.MatchString(scanner.Text()) { + line := scanner.Text() + if strings.Contains(line, "pytest.mark.django_db") { + return true + } + if strings.Contains(line, "gazelle: django_test") { + return true + } + if _, rest, ok := strings.Cut(line, "from django.test import"); ok && strings.Contains(rest, "TestCase") { return true } } diff --git a/python/parser.go b/python/parser.go index 1b2a90d..9a942cb 100644 --- a/python/parser.go +++ b/python/parser.go @@ -18,6 +18,7 @@ import ( "context" _ "embed" "fmt" + "runtime" "strings" "github.com/emirpasic/gods/sets/treeset" @@ -50,47 +51,61 @@ func newPython3Parser( } } -// parseSingle parses a single Python file and returns the extracted modules -// from the import statements as well as the parsed comments. -func (p *python3Parser) parseSingle(pyFilename string) (*treeset.Set, map[string]*treeset.Set, *annotations, error) { - pyFilenames := treeset.NewWith(godsutils.StringComparator) - pyFilenames.Add(pyFilename) - return p.parse(pyFilenames) -} - -// parse parses multiple Python files and returns the extracted modules from -// the import statements as well as the parsed comments. -func (p *python3Parser) parse(pyFilenames *treeset.Set) (*treeset.Set, map[string]*treeset.Set, *annotations, error) { - modules := treeset.NewWith(moduleComparator) +// parseAllToLUT parses all Python files concurrently and returns a lookup table +// mapping filename to its parsed output. This allows parsing each file exactly +// once, even when results are consumed by multiple targets. 
+func (p *python3Parser) parseAllToLUT(pyFilenames *treeset.Set) (map[string]*ParserOutput, error) { + values := pyFilenames.Values() + lut := make(map[string]*ParserOutput, len(values)) g, ctx := errgroup.WithContext(context.Background()) - ch := make(chan struct{}, 6) // Limit the number of concurrent parses. - chRes := make(chan *ParserOutput, len(pyFilenames.Values())) - for _, v := range pyFilenames.Values() { - ch <- struct{}{} - g.Go(func(filename string) func() error { - return func() error { - defer func() { - <-ch - }() - res, err := NewFileParser().ParseFile(ctx, p.repoRoot, p.relPackagePath, filename) - if err != nil { - return err - } - chRes <- res - return nil + g.SetLimit(runtime.GOMAXPROCS(0)) + + results := make([]*ParserOutput, len(values)) + filenames := make([]string, len(values)) + + for i, v := range values { + filenames[i] = v.(string) + } + + for i, filename := range filenames { + i, filename := i, filename + g.Go(func() error { + res, err := NewFileParser().ParseFile(ctx, p.repoRoot, p.relPackagePath, filename) + if err != nil { + return err } - }(v.(string))) + results[i] = res + return nil + }) } + if err := g.Wait(); err != nil { - return nil, nil, nil, err + return nil, err + } + + for i, filename := range filenames { + lut[filename] = results[i] } - close(ch) - close(chRes) - mainModules := make(map[string]*treeset.Set, len(chRes)) + + return lut, nil +} + +// parseFromLUT processes pre-parsed results from a LUT for the given filenames, +// applying annotation and dependency filtering. 
+func (p *python3Parser) parseFromLUT(pyFilenames *treeset.Set, lut map[string]*ParserOutput) (*treeset.Set, map[string]*treeset.Set, *annotations, error) { + modules := treeset.NewWith(moduleComparator) + mainModules := make(map[string]*treeset.Set) allAnnotations := new(annotations) allAnnotations.ignore = make(map[string]struct{}) - for res := range chRes { + + for _, v := range pyFilenames.Values() { + filename := v.(string) + res, ok := lut[filename] + if !ok { + return nil, nil, nil, fmt.Errorf("file %q not found in parse LUT", filename) + } + if res.HasMain { mainModules[res.FileName] = treeset.NewWith(moduleComparator) } @@ -100,25 +115,18 @@ func (p *python3Parser) parse(pyFilenames *treeset.Set) (*treeset.Set, map[strin } for _, m := range res.Modules { - // Check for ignored dependencies set via an annotation to the Python - // module. if annotations.ignores(m.Name) || annotations.ignores(m.From) { continue } - - // Check for ignored dependencies set via a Gazelle directive in a BUILD - // file. if p.ignoresDependency(m.Name) || p.ignoresDependency(m.From) { continue } - modules.Add(m) if res.HasMain { mainModules[res.FileName].Add(m) } } - // Collect all annotations from each file into a single annotations struct. for k, v := range annotations.ignore { allAnnotations.ignore[k] = v } @@ -130,6 +138,24 @@ func (p *python3Parser) parse(pyFilenames *treeset.Set) (*treeset.Set, map[strin return modules, mainModules, allAnnotations, nil } +// parseSingle parses a single Python file and returns the extracted modules +// from the import statements as well as the parsed comments. +func (p *python3Parser) parseSingle(pyFilename string) (*treeset.Set, map[string]*treeset.Set, *annotations, error) { + pyFilenames := treeset.NewWith(godsutils.StringComparator) + pyFilenames.Add(pyFilename) + return p.parse(pyFilenames) +} + +// parse parses multiple Python files and returns the extracted modules from +// the import statements as well as the parsed comments. 
+func (p *python3Parser) parse(pyFilenames *treeset.Set) (*treeset.Set, map[string]*treeset.Set, *annotations, error) { + lut, err := p.parseAllToLUT(pyFilenames) + if err != nil { + return nil, nil, nil, err + } + return p.parseFromLUT(pyFilenames, lut) +} + // removeDupesFromStringTreeSetSlice takes a []string, makes a set out of the // elements, and then returns a new []string with all duplicates removed. Order // is preserved.