diff --git a/v2/parser/parse.go b/v2/parser/parse.go index 4c1efa00..91220129 100644 --- a/v2/parser/parse.go +++ b/v2/parser/parse.go @@ -637,6 +637,10 @@ func (p *Parser) convertSignature(u types.Universe, t *gotypes.Signature) *types // walkType adds the type, and any necessary child types. func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type) *types.Type { + var out *types.Type + defer func() { + out.InitMultiverse(u) + }() // Most of the cases are underlying types of the named type. name := goNameToName(in.String()) if useName != nil { @@ -645,13 +649,13 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type // Handle alias types conditionally on go1.22+. // Inline this once the minimum supported version is go1.22 - if out := p.walkAliasType(u, in); out != nil { + if out = p.walkAliasType(u, in); out != nil { return out } switch t := in.(type) { case *gotypes.Struct: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -670,7 +674,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type } return out case *gotypes.Map: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -680,7 +684,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Key = p.walkType(u, nil, t.Key()) return out case *gotypes.Pointer: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -689,7 +693,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Elem = p.walkType(u, nil, t.Elem()) return out case *gotypes.Slice: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -698,7 +702,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Elem = p.walkType(u, nil, t.Elem()) return out case *gotypes.Array: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -708,7 +712,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Len = in.(*gotypes.Array).Len() return out case *gotypes.Chan: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -719,7 +723,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type // cannot be properly written. return out case *gotypes.Basic: - out := u.Type(types.Name{ + out = u.Type(types.Name{ Package: "", // This is a magic package name in the Universe. Name: t.Name(), }) @@ -730,7 +734,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Kind = types.Unsupported return out case *gotypes.Signature: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -739,7 +743,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type out.Signature = p.convertSignature(u, t) return out case *gotypes.Interface: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out @@ -758,7 +762,6 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type } return out case *gotypes.Named: - var out *types.Type switch t.Underlying().(type) { case *gotypes.Named, *gotypes.Basic, *gotypes.Map, *gotypes.Slice: name := goNameToName(t.String()) @@ -785,7 +788,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type name.Name = fmt.Sprintf("%s[%s]", strings.SplitN(name.Name, "[", 2)[0], strings.Join(tpNames, ",")) } - if out := u.Type(name); out.Kind != types.Unknown { + if out = u.Type(name); out.Kind != types.Unknown { out.GoType = in return out // short circuit if we've already made this. } @@ -797,7 +800,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type // "feature" for users. This flattens those types // together. name := goNameToName(t.String()) - if out := u.Type(name); out.Kind != types.Unknown { + if out = u.Type(name); out.Kind != types.Unknown { return out // short circuit if we've already made this. } out = p.walkType(u, &name, t.Underlying()) @@ -827,7 +830,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type Kind: types.TypeParam, } default: - out := u.Type(name) + out = u.Type(name) out.GoType = in if out.Kind != types.Unknown { return out diff --git a/v2/parser/parse_test.go b/v2/parser/parse_test.go index 99654252..ff762a7c 100644 --- a/v2/parser/parse_test.go +++ b/v2/parser/parse_test.go @@ -1124,6 +1124,7 @@ func TestStructParse(t *testing.T) { } opts := []cmp.Option{ cmpopts.IgnoreFields(types.Type{}, "GoType"), + cmpopts.IgnoreFields(types.Type{}, "multiverse"), } if e, a := expected, st; !cmp.Equal(e, a, opts...) { t.Errorf("wanted, got:\n%#v\n%#v\n%s", e, a, cmp.Diff(e, a, opts...)) diff --git a/v2/types/types.go b/v2/types/types.go index dab11d96..35b020eb 100644 --- a/v2/types/types.go +++ b/v2/types/types.go @@ -19,6 +19,7 @@ package types import ( gotypes "go/types" "strings" + "sync" ) // Ref makes a reference to the given type. It can only be used for e.g. @@ -364,6 +365,50 @@ type Type struct { // The underlying Go type. GoType gotypes.Type + + // The reference to Multiverse + multiverse *multiverse +} + +// multiverse holds Type definitions that were not found in the imported code but are needed during +// generation. For example the imported code my have T but not *T, while the generated code needs *T. This +// can't be part of the main Universe because that is a simple map, with no locking, and we all know what happens +// when you modify a map while it is being iterated. Storing the multiverse alongside a Type ensures that there's +// at most one *Type for every type, and maintains the invariant that PointerTo(String) == PointerTo(String). +type multiverse struct { + real Universe + mu sync.Mutex + synthetic map[string]*Type +} + +// InitMultiverse inits a multiverse for a Type. +// It panics if called twice. +func (t *Type) InitMultiverse(u Universe) { + if t.multiverse != nil { + panic("Can't initialize a non-empty multiverse on Type") + } + t.multiverse = &multiverse{ + real: u, + mu: sync.Mutex{}, + synthetic: map[string]*Type{}, + } +} + +// GetOrAddType searches a Type in the Universe and synthetic map. +// If there is a matching name, return the Type, otherwise, create the Type. +func (m *multiverse) GetOrAddType(t *Type) *Type { + if p, ok := m.real[t.Name.Package]; ok { + if t, ok := p.Types[t.Name.Name]; ok { + return t + } + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.synthetic[t.Name.Name]; ok { + return t + } + m.synthetic[t.Name.Name] = t + return t } // String returns the name of the type. @@ -556,13 +601,14 @@ var ( ) func PointerTo(t *Type) *Type { - return &Type{ + pt := &Type{ Name: Name{ Name: "*" + t.Name.String(), }, Kind: Pointer, Elem: t, } + return t.multiverse.GetOrAddType(pt) } func IsInteger(t *Type) bool { diff --git a/v2/types/types_test.go b/v2/types/types_test.go index 35cc4d61..8e6861bf 100644 --- a/v2/types/types_test.go +++ b/v2/types/types_test.go @@ -17,6 +17,8 @@ limitations under the License. package types import ( + "reflect" + "sync" "testing" ) @@ -37,6 +39,95 @@ func TestGetBuiltin(t *testing.T) { } } +func TestPointerTo(t *testing.T) { + type1 := &Type{ + Name: Name{Package: "pkgname", Name: "structname"}, + Kind: Struct, + } + type2 := &Type{ + Name: Name{Package: "pkgname", Name: "secondstructname"}, + Kind: Struct, + } + + u := Universe{ + "pkgname": &Package{ + Types: map[string]*Type{ + "structname": type1, + "secondstructname": type2, + }, + }, + "": &Package{ + Types: map[string]*Type{ + "*pkgname.structname": &Type{ + Name: Name{Name: "*pkgname.structname"}, + Kind: Pointer, + }, + }, + }, + } + + type3 := &Type{ + Name: Name{Package: "pkgname", Name: "thridstructname"}, + Kind: Struct, + multiverse: &multiverse{ + real: u, + synthetic: map[string]*Type{ + "*pkgname.thridstructname": &Type{ + Name: Name{Name: "*pkgname.thridstructname"}, + Kind: Pointer, + }, + }, + mu: sync.Mutex{}, + }, + } + + testCases := []struct { + name string + tp *Type + expected *Type + expectCreation bool + }{ + { + name: "universe has the pointer type", + tp: type1, + expected: u[""].Types["*pkgname.structname"], + expectCreation: false, + }, + { + name: "neither universe or cache has the pointer type", + tp: type2, + expected: &Type{ + Name: Name{Name: "*pkgname.secondstructname"}, + Kind: Pointer, + Elem: type2, + }, + expectCreation: true, + }, + { + name: "cache has the pointer type", + tp: type3, + expected: type3.multiverse.synthetic["*pkgname.thridstructname"], + expectCreation: false, + }, + } + for _, tc := range testCases { + if tc.tp.multiverse == nil { + tc.tp.multiverse = &multiverse{ + real: u, + synthetic: map[string]*Type{}, + mu: sync.Mutex{}, + } + } + tp := PointerTo(tc.tp) + if tc.expectCreation && !reflect.DeepEqual(tp, tc.expected) { + t.Errorf("PointerTo failed, expected %v, got : %v", tc.expected, tp) + } + if !tc.expectCreation && tp != tc.expected { + t.Errorf("PointerTo should not create a new pointer type, expected %v, got : %v", tc.expected, tp) + } + } +} + func TestGetMarker(t *testing.T) { u := Universe{} n := Name{Package: "path/to/package", Name: "Foo"}