Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions internal/implementations/extractor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package impls

import (
"fmt"
"go/types"
"log/slog"

"github.com/scip-code/scip-go/internal/loader"
"github.com/scip-code/scip-go/internal/lookup"
"github.com/scip-code/scip/bindings/go/scip"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/types/typeutil"
)

type Extractor struct {
global *lookup.Global
methodSetCache typeutil.MethodSetCache
}

func NewExtractor(global *lookup.Global) *Extractor {
return &Extractor{
global: global,
}
}

func (e *Extractor) Extract(pkgLookup loader.PackageLookup) (map[string]ImplDef, map[string]ImplDef) {
interfaces := map[string]ImplDef{}
concretes := map[string]ImplDef{}

for _, pkg := range pkgLookup {
if pkg.Name == "builtin" {
continue
}

if pkg.TypesInfo != nil {
e.extractLocal(pkg, interfaces, concretes)
} else if pkg.Types != nil {
e.extractRemote(pkg, interfaces, concretes)
} else {
slog.Warn("No types for package", "path", pkg.PkgPath)
}
}

return interfaces, concretes
}

func (e *Extractor) extractLocal(pkg *packages.Package, interfaces, concretes map[string]ImplDef) {
pkgSymbols := e.global.GetPackage(pkg)
if pkgSymbols == nil {
slog.Warn("No symbols for package", "path", pkg.PkgPath)
return
}

for ident, obj := range pkg.TypesInfo.Defs {
if obj == nil {
continue
}

typeName, ok := obj.(*types.TypeName)
if !ok {
continue
}

if pkg.Types != nil && typeName.Parent() != pkg.Types.Scope() {
continue
}

named, ok := obj.Type().(*types.Named)
if !ok {
continue
}

sym, ok := pkgSymbols.Get(typeName.Pos())
if !ok {
slog.Debug(
"No symbol for package-level named type",
"identifier", ident.Name,
"package", pkg.PkgPath,
"id", obj.Id(),
)
continue
}

e.classify(named, sym, pkg.PkgPath, interfaces, concretes)
}
}

func (e *Extractor) extractRemote(pkg *packages.Package, interfaces, concretes map[string]ImplDef) {
scope := pkg.Types.Scope()

for _, name := range scope.Names() {
typeName, ok := scope.Lookup(name).(*types.TypeName)
if !ok || !typeName.Exported() {
continue
}

named, ok := typeName.Type().(*types.Named)
if !ok {
continue
}

sym := e.global.Composer().Compose(pkg, typeName)
if sym == "" {
continue
}

e.classify(named, &scip.SymbolInformation{Symbol: sym}, pkg.PkgPath, interfaces, concretes)
}
}

func (e *Extractor) classify(
named *types.Named,
sym *scip.SymbolInformation,
pkgPath string,
interfaces, concretes map[string]ImplDef,
) {
methods := typeutil.IntuitiveMethodSet(named, &e.methodSetCache)
if len(methods) == 0 {
return
}

methodSymbols := map[methodID]*scip.SymbolInformation{}
for _, method := range methods {
sym, ok, err := e.global.GetSymbolOfObject(method.Obj())
if err != nil {
slog.Debug(fmt.Sprintf("Error while looking for symbol %s | %s", err, method.Obj()))
continue
}
if !ok {
continue
}

methodSymbols[methodID(method.Obj().Id())] = sym
}

impl := ImplDef{
Symbol: sym,
Named: named,
Methods: methodSymbols,
Mask: methodMask(methods),
MethodCount: len(methods),
HasUnexported: hasUnexportedMethods(methods),
PkgPath: pkgPath,
}
if types.IsInterface(named) {
interfaces[impl.Symbol.Symbol] = impl
} else {
concretes[impl.Symbol.Symbol] = impl
}
}
163 changes: 16 additions & 147 deletions internal/implementations/implementations.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,23 @@
package impls

import (
"fmt"
"go/ast"
"go/types"
"hash/crc32"
"log/slog"
"sync"
"sync/atomic"

"github.com/scip-code/scip-go/internal/implementations/fingerprint"
"github.com/scip-code/scip-go/internal/loader"
"github.com/scip-code/scip-go/internal/lookup"
"github.com/scip-code/scip-go/internal/output"
"github.com/scip-code/scip/bindings/go/scip"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/types/typeutil"
)

// methodID is a unique identifier for a method, using types.Id semantics
// (package-path-qualified for unexported methods, just the name for exported).
type methodID string

type ImplDef struct {
// The corresponding scip symbol, generated via previous iteration over the AST
Symbol *scip.SymbolInformation

Pkg *packages.Package
Ident *ast.Ident
Symbol *scip.SymbolInformation
Named *types.Named
Methods map[methodID]*scip.SymbolInformation

Expand Down Expand Up @@ -69,19 +59,9 @@ func hasUnexportedMethods(methods []*types.Selection) bool {
return false
}

func findImplementations(concreteTypes map[string]ImplDef, interfaces map[string]ImplDef, symbols *lookup.Global, count *uint64) {
func findImplementations(concreteTypes map[string]ImplDef, interfaces map[string]ImplDef, count *uint64) {
for _, ty := range concreteTypes {
pos := ty.Ident.Pos()
sym, ok := symbols.GetSymbolInformation(ty.Pkg, pos)
if !ok {
panic(fmt.Sprintf("Could not find symbol for %s", ty.Symbol))
}

for _, iface := range interfaces {
if iface.Ident == nil {
continue
}

ifaceType, ok := iface.Named.Underlying().(*types.Interface)
if !ok {
continue
Expand Down Expand Up @@ -114,7 +94,7 @@ func findImplementations(concreteTypes map[string]ImplDef, interfaces map[string
}

// Add implementation details for the struct & interface relationship
sym.Relationships = append(sym.Relationships, &scip.Relationship{
ty.Symbol.Relationships = append(ty.Symbol.Relationships, &scip.Relationship{
Symbol: iface.Symbol.Symbol,
IsImplementation: true,
})
Expand All @@ -133,23 +113,23 @@ func findImplementations(concreteTypes map[string]ImplDef, interfaces map[string
}
}

ty.Symbol.Relationships = scip.CanonicalizeRelationships(ty.Symbol.Relationships)
for _, method := range ty.Methods {
method.Relationships = scip.CanonicalizeRelationships(method.Relationships)
}

atomic.AddUint64(count, 1)
}
}

func AddImplementationRelationships(
pkgs loader.PackageLookup,
allPackages loader.PackageLookup,
symbols *lookup.Global,
extractor *Extractor,
) ([]*scip.SymbolInformation, error) {
var externalSymbols []*scip.SymbolInformation

var msCache typeutil.MethodSetCache
localInterfaces, localTypes, err := extractInterfacesAndConcreteTypes(
pkgs, symbols, &msCache)
if err != nil {
return nil, err
}
localInterfaces, localTypes := extractor.Extract(pkgs)

remotePackages := make(loader.PackageLookup)
for pkgID, pkg := range allPackages {
Expand All @@ -159,11 +139,7 @@ func AddImplementationRelationships(

remotePackages[pkgID] = pkg
}
remoteInterfaces, remoteTypes, err := extractInterfacesAndConcreteTypes(
remotePackages, symbols, &msCache)
if err != nil {
return nil, err
}
remoteInterfaces, remoteTypes := extractor.Extract(remotePackages)

// Total concrete types to check across the three passes.
total := uint64(len(localTypes)*2 + len(remoteTypes))
Expand All @@ -175,131 +151,24 @@ func AddImplementationRelationships(
defer wg.Done()

// local type -> local interface
findImplementations(localTypes, localInterfaces, symbols, &count)
findImplementations(localTypes, localInterfaces, &count)

// local type -> remote interface
findImplementations(localTypes, remoteInterfaces, symbols, &count)
findImplementations(localTypes, remoteInterfaces, &count)

// remote type -> local interface
// We emit these as external symbols so index consumer can merge them.
findImplementations(remoteTypes, localInterfaces, symbols, &count)
findImplementations(remoteTypes, localInterfaces, &count)
}()

output.WithProgressParallel(&wg, "Indexing Implementations", &count, total)

// Collect remote type symbols that gained relationships
for _, typ := range remoteTypes {
if sym, ok := symbols.GetSymbolInformation(typ.Pkg, typ.Ident.Pos()); ok {
if len(sym.Relationships) > 0 {
externalSymbols = append(externalSymbols, sym)
}
if len(typ.Symbol.Relationships) > 0 {
externalSymbols = append(externalSymbols, typ.Symbol)
}
}

return externalSymbols, nil
}

func extractInterfacesAndConcreteTypes(
pkgs loader.PackageLookup,
symbols *lookup.Global,
msCache *typeutil.MethodSetCache,
) (interfaces map[string]ImplDef, concreteTypes map[string]ImplDef, err error) {
interfaces = map[string]ImplDef{}
concreteTypes = map[string]ImplDef{}

for _, pkg := range pkgs {
// Builtin isn't the same as standard library, that is for builtin types
// We don't need to check those for implemenations.
if pkg.Name == "builtin" {
continue
}

if pkg.TypesInfo == nil {
slog.Warn("No types for package", "path", pkg.PkgPath)
continue
}

pkgSymbols := symbols.GetPackage(pkg)
if pkgSymbols == nil {
slog.Warn("No symbols for package", "path", pkg.PkgPath)
continue
}

for ident, obj := range pkg.TypesInfo.Defs {
if obj == nil {
continue
}

// We ignore aliases 'type M = N' to avoid duplicate reporting
// of the Named type N.
obj, ok := obj.(*types.TypeName)
if !ok {
continue
}

// Skip types declared inside function bodies — the type visitor
// only indexes package-level declarations, so local types will
// never have a symbol entry.
if pkg.Types != nil && obj.Parent() != pkg.Types.Scope() {
continue
}

objType, ok := obj.Type().(*types.Named)
if !ok {
continue
}

symbol, ok := pkgSymbols.Get(obj.Pos())
if !ok {
slog.Debug(
"No symbol for package-level named type",
"identifier", ident.Name, "package", pkg.PkgPath, "id", obj.Id())
continue
}

methods := typeutil.IntuitiveMethodSet(objType, msCache)

// ignore interfaces that are empty. they are too
// plentiful and don't provide useful intelligence.
if len(methods) == 0 {
continue
}

methodIds := map[methodID]*scip.SymbolInformation{}
for _, m := range methods {
sym, ok, err := symbols.GetSymbolOfObject(m.Obj())
if err != nil {
slog.Debug(fmt.Sprintf("Error while looking for symbol %s | %s", err, m.Obj()))
continue
}

if !ok {
continue
}

methodIds[methodID(m.Obj().Id())] = sym
}

d := ImplDef{
Symbol: symbol,
Pkg: pkg,
Ident: ident,
Named: objType,
Methods: methodIds,
Mask: methodMask(methods),
MethodCount: len(methods),
HasUnexported: hasUnexportedMethods(methods),
PkgPath: pkg.PkgPath,
}

if types.IsInterface(objType) {
interfaces[d.Symbol.Symbol] = d
} else {
concreteTypes[d.Symbol.Symbol] = d
}

}
}

return
}
Loading
Loading