Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,14 @@ jobs:
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Mount bazel caches
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: |
"~/.cache/bazel"
"~/.cache/bazel-repo"
key: bazel-cache-${{ hashFiles('**/BUILD.bazel', '**/*.bzl', 'WORKSPACE', '**/*.js') }}
restore-keys: bazel-cache-
- name: bazel test //...
env:
# Bazelisk will download bazel to here
XDG_CACHE_HOME: ~/.cache/bazel-repo
run: bazel test //...
170 changes: 121 additions & 49 deletions python/file_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,57 @@ import (
"os"
"path/filepath"
"strings"
"sync"

sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/python"
)

// --- Pools & Pre-compiled Query ---

// Shared pools of tree-sitter helper objects, so callers can Get/Put an
// instance rather than constructing a new one each time.
var (
	// parserPool creates parsers with the Python language already set.
	parserPool = sync.Pool{
		New: func() any {
			p := sitter.NewParser()
			p.SetLanguage(python.GetLanguage())
			return p
		},
	}

	// cursorPool creates query cursors via sitter.NewQueryCursor.
	cursorPool = sync.Pool{
		New: func() any {
			return sitter.NewQueryCursor()
		},
	}
)

// importQuery is the compiled tree-sitter query that captures import
// statements, from-import statements, comments and if statements.
var importQuery *sitter.Query

// init compiles importQuery once at package load. A compile failure means the
// hard-coded query text is invalid, so it panics.
func init() {
	const pattern = `
(import_statement) @import
(import_from_statement) @import_from
(comment) @comment
(if_statement) @if_stmt
`
	compiled, err := sitter.NewQuery([]byte(pattern), python.GetLanguage())
	if err != nil {
		panic(fmt.Sprintf("failed to compile import query: %v", err))
	}
	importQuery = compiled
}

// --- Constants ---

// Tree-sitter node type names, compared against the value of Node.Type().
const (
	// Statement-level nodes.
	sitterNodeTypeIfStatement         = "if_statement"
	sitterNodeTypeImportStatement     = "import_statement"
	sitterNodeTypeImportFromStatement = "import_from_statement"

	// Forms an imported name can take inside an import statement.
	sitterNodeTypeDottedName     = "dotted_name"
	sitterNodeTypeAliasedImport  = "aliased_import"
	sitterNodeTypeWildcardImport = "wildcard_import"

	// Expression and token nodes.
	sitterNodeTypeString             = "string"
	sitterNodeTypeComment            = "comment"
	sitterNodeTypeIdentifier         = "identifier"
	sitterNodeTypeComparisonOperator = "comparison_operator"
)

// --- Types ---

type ParserOutput struct {
FileName string
Modules []module
Expand All @@ -55,9 +88,11 @@ func NewFileParser() *FileParser {
return &FileParser{}
}

func ParseCode(code []byte) (*sitter.Node, error) {
parser := sitter.NewParser()
parser.SetLanguage(python.GetLanguage())
// --- Parsing Logic ---

func parseCode(code []byte) (*sitter.Node, error) {
parser := parserPool.Get().(*sitter.Parser)
defer parserPool.Put(parser)

tree, err := parser.ParseCtx(context.Background(), nil, code)
if err != nil {
Expand All @@ -67,25 +102,21 @@ func ParseCode(code []byte) (*sitter.Node, error) {
return tree.RootNode(), nil
}

func (p *FileParser) parseMain(ctx context.Context, node *sitter.Node) bool {
func (p *FileParser) parseMain(node *sitter.Node) bool {
for i := 0; i < int(node.ChildCount()); i++ {
if err := ctx.Err(); err != nil {
return false
}
child := node.Child(i)
if child.Type() == sitterNodeTypeIfStatement &&
child.Child(1).Type() == sitterNodeTypeComparisonOperator && child.Child(1).Child(1).Type() == "==" {
if child.Type() == "if_statement" &&
child.ChildCount() > 1 &&
child.Child(1).Type() == sitterNodeTypeComparisonOperator &&
child.Child(1).ChildCount() > 2 &&
child.Child(1).Child(1).Type() == "==" {
statement := child.Child(1)
a, b := statement.Child(0), statement.Child(2)
// convert "'__main__' == __name__" to "__name__ == '__main__'"
if b.Type() == sitterNodeTypeIdentifier {
a, b = b, a
}
if a.Type() == sitterNodeTypeIdentifier && a.Content(p.code) == "__name__" &&
// at github.com/smacker/go-tree-sitter@latest (after v0.0.0-20240422154435-0628b34cbf9c we used)
// "__main__" is the second child of b. But now, it isn't.
// we cannot use the latest go-tree-sitter because of the top level reference in scanner.c.
// https://github.com/smacker/go-tree-sitter/blob/04d6b33fe138a98075210f5b770482ded024dc0f/python/scanner.c#L1
b.Type() == sitterNodeTypeString && string(p.code[b.StartByte()+1:b.EndByte()-1]) == "__main__" {
return true
}
Expand All @@ -94,6 +125,24 @@ func (p *FileParser) parseMain(ctx context.Context, node *sitter.Node) bool {
return false
}

// isTypeCheckingBlock reports whether node is an if statement whose condition
// (its second child) is either the bare identifier TYPE_CHECKING or the
// attribute expression typing.TYPE_CHECKING.
func (p *FileParser) isTypeCheckingBlock(node *sitter.Node) bool {
	if node.Type() != "if_statement" {
		return false
	}
	if node.ChildCount() < 2 {
		return false
	}

	// identNamed reports whether n is an identifier whose source text is want.
	identNamed := func(n *sitter.Node, want string) bool {
		return n.Type() == sitterNodeTypeIdentifier && n.Content(p.code) == want
	}

	cond := node.Child(1)
	switch cond.Type() {
	case sitterNodeTypeIdentifier:
		return cond.Content(p.code) == "TYPE_CHECKING"
	case "attribute":
		// An attribute needs at least three children so that child 0 (object)
		// and child 2 (attribute name) both exist.
		if cond.ChildCount() < 3 {
			return false
		}
		return identNamed(cond.Child(0), "typing") && identNamed(cond.Child(2), "TYPE_CHECKING")
	default:
		return false
	}
}

func parseImportStatement(node *sitter.Node, code []byte) (module, bool) {
switch node.Type() {
case sitterNodeTypeDottedName:
Expand All @@ -112,7 +161,7 @@ func parseImportStatement(node *sitter.Node, code []byte) (module, bool) {
return module{}, false
}

func (p *FileParser) parseImportStatements(node *sitter.Node) bool {
func (p *FileParser) parseImportStatements(node *sitter.Node) {
if node.Type() == sitterNodeTypeImportStatement {
for j := 1; j < int(node.ChildCount()); j++ {
m, ok := parseImportStatement(node.Child(j), p.code)
Expand All @@ -128,7 +177,7 @@ func (p *FileParser) parseImportStatements(node *sitter.Node) bool {
} else if node.Type() == sitterNodeTypeImportFromStatement {
from := node.Child(1).Content(p.code)
if strings.HasPrefix(from, ".") {
return true
return
}
for j := 3; j < int(node.ChildCount()); j++ {
m, ok := parseImportStatement(node.Child(j), p.code)
Expand All @@ -140,18 +189,7 @@ func (p *FileParser) parseImportStatements(node *sitter.Node) bool {
m.Name = fmt.Sprintf("%s.%s", from, m.Name)
p.output.Modules = append(p.output.Modules, m)
}
} else {
return false
}
return true
}

func (p *FileParser) parseComments(node *sitter.Node) bool {
if node.Type() == sitterNodeTypeComment {
p.output.Comments = append(p.output.Comments, comment(node.Content(p.code)))
return true
}
return false
}

func (p *FileParser) SetCodeAndFile(code []byte, relPackagePath, filename string) {
Expand All @@ -160,34 +198,68 @@ func (p *FileParser) SetCodeAndFile(code []byte, relPackagePath, filename string
p.output.FileName = filename
}

func (p *FileParser) parse(ctx context.Context, node *sitter.Node) {
if node == nil {
return
// Parse uses pre-compiled tree-sitter queries to extract imports and comments
// from the parsed AST, replacing the previous recursive traversal approach.
func (p *FileParser) Parse(ctx context.Context) (*ParserOutput, error) {
rootNode, err := parseCode(p.code)
if err != nil {
return nil, err
}
for i := 0; i < int(node.ChildCount()); i++ {

p.output.HasMain = p.parseMain(rootNode)

cursor := cursorPool.Get().(*sitter.QueryCursor)
defer cursorPool.Put(cursor)

cursor.Exec(importQuery, rootNode)

seenImports := make(map[uint32]bool)

for {
if err := ctx.Err(); err != nil {
return
}
child := node.Child(i)
if p.parseImportStatements(child) {
continue
return nil, err
}
if p.parseComments(child) {
continue
match, ok := cursor.NextMatch()
if !ok {
break
}
p.parse(ctx, child)
}
}

func (p *FileParser) Parse(ctx context.Context) (*ParserOutput, error) {
rootNode, err := ParseCode(p.code)
if err != nil {
return nil, err
}
for _, capture := range match.Captures {
captureName := importQuery.CaptureNameForId(capture.Index)
node := capture.Node

switch captureName {
case "import", "import_from":
if seenImports[node.StartByte()] {
continue
}
seenImports[node.StartByte()] = true
p.parseImportStatements(node)

case "comment":
p.output.Comments = append(p.output.Comments, comment(node.Content(p.code)))

p.output.HasMain = p.parseMain(ctx, rootNode)
case "if_stmt":
if p.isTypeCheckingBlock(node) {
for j := 0; j < int(node.ChildCount()); j++ {
subChild := node.Child(j)
if subChild.Type() == "block" {
for k := 0; k < int(subChild.ChildCount()); k++ {
stmt := subChild.Child(k)
if stmt.Type() == sitterNodeTypeImportStatement || stmt.Type() == sitterNodeTypeImportFromStatement {
if !seenImports[stmt.StartByte()] {
seenImports[stmt.StartByte()] = true
p.parseImportStatements(stmt)
}
}
}
}
}
}
}
}
}

p.parse(ctx, rootNode)
return &p.output, nil
}

Expand Down
34 changes: 26 additions & 8 deletions python/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"log"
"os"
"path/filepath"
"regexp"
"sort"
"strings"

Expand Down Expand Up @@ -207,13 +206,26 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes

parser := newPython3Parser(args.Config.RepoRoot, args.Rel, cfg.IgnoresDependency)

// Parse all Python files once upfront into a lookup table (LUT).
// This avoids re-parsing the same file multiple times when it appears
// in multiple targets (e.g. __init__.py in per-file mode).
allPyFiles := treeset.NewWith(godsutils.StringComparator)
pyLibraryFilenames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) })
pyTestFilenames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) })
djangoTestFilesNames.Each(func(_ int, v interface{}) { allPyFiles.Add(v) })

parseLUT, err := parser.parseAllToLUT(allPyFiles)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}

var result language.GenerateResult
result.Gen = make([]*rule.Rule, 0)

collisionErrors := singlylinkedlist.New()

appendPyLibrary := func(srcs *treeset.Set, pyLibraryTargetName string) {
allDeps, mainModules, annotations, err := parser.parse(srcs)
allDeps, mainModules, annotations, err := parser.parseFromLUT(srcs, parseLUT)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand Down Expand Up @@ -280,7 +292,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes

var pyTestTargets []*targetBuilder
newPyTestTargetBuilder := func(srcs *treeset.Set, pyTestTargetName string) *targetBuilder {
deps, _, annotations, err := parser.parse(srcs)
deps, _, annotations, err := parser.parseFromLUT(srcs, parseLUT)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand All @@ -302,7 +314,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
generateImportsAttribute()
}
newDjangoTestBuilder := func(srcs *treeset.Set, djangoTestTargetName string) *targetBuilder {
deps, _, annotations, err := parser.parse(srcs)
deps, _, annotations, err := parser.parseFromLUT(srcs, parseLUT)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand Down Expand Up @@ -442,10 +454,9 @@ func ensureNoCollision(file *rule.File, targetName, kind string) error {
return nil
}

// isDjangoTestFile returns whether the given path contains the following
// regex regexp.MustCompile(`from django\.test import.*TestCase|pytest\.mark\.django_db|gazelle: django_test`)
// isDjangoTestFile returns whether the given path contains django test markers:
// "from django.test import...TestCase", "pytest.mark.django_db", or "gazelle: django_test".
func isDjangoTestFile(path string) bool {
re := regexp.MustCompile(`from django\.test import.*TestCase|pytest\.mark\.django_db|gazelle: django_test`)
file, err := os.Open(path)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
Expand All @@ -454,7 +465,14 @@ func isDjangoTestFile(path string) bool {
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
if re.MatchString(scanner.Text()) {
line := scanner.Text()
if strings.Contains(line, "pytest.mark.django_db") {
return true
}
if strings.Contains(line, "gazelle: django_test") {
return true
}
if strings.Contains(line, "django.test") && strings.Contains(line, "TestCase") {
return true
}
}
Expand Down
Loading