Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ tests/
Go 1.23+: Follow standard conventions

## Recent Changes
- 013-message-include-filter: Added Go 1.25+ with `github.com/emicklei/proto` v1.14.3, `gopkg.in/yaml.v3`
- 012-fix-include-unannotated: Added Go 1.25+ with `github.com/emicklei/proto` v1.14.3, `gopkg.in/yaml.v3`
- 011-combined-include-exclude: Added Go 1.25+ (existing project) with `github.com/emicklei/proto` v1.14.3, `github.com/emicklei/proto-contrib` v0.18.3, `gopkg.in/yaml.v3`
- 010-substitution-placeholders: Added Go 1.25+ with `github.com/emicklei/proto` v1.14.3, `gopkg.in/yaml.v3`


<!-- MANUAL ADDITIONS START -->
Expand Down
169 changes: 168 additions & 1 deletion internal/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,165 @@ func IncludeMethodsByAnnotation(def *proto.Proto, annotations []string) int {
return removed
}

// CollectIncludeMessageRoots gathers the fully-qualified names of every
// top-level message and enum in def whose comment carries at least one of
// the given include annotations. The result is intended to be handed to
// RemoveOrphanedDefinitions as pinned roots so those definitions are not
// pruned as orphans.
//
// A nil map is returned when annotations is empty. Names are qualified with
// the name of the first package element found in def (empty if none).
func CollectIncludeMessageRoots(def *proto.Proto, annotations []string) map[string]bool {
	if len(annotations) == 0 {
		return nil
	}

	wanted := make(map[string]bool, len(annotations))
	for _, name := range annotations {
		wanted[name] = true
	}

	// The first package declaration supplies the qualification prefix.
	var pkgName string
	for _, e := range def.Elements {
		if decl, isPkg := e.(*proto.Package); isPkg {
			pkgName = decl.Name
			break
		}
	}

	roots := make(map[string]bool)
	for _, e := range def.Elements {
		// Messages and enums are handled identically; everything else is
		// irrelevant to root collection.
		var (
			name    string
			comment *proto.Comment
		)
		switch decl := e.(type) {
		case *proto.Message:
			name, comment = decl.Name, decl.Comment
		case *proto.Enum:
			name, comment = decl.Name, decl.Comment
		default:
			continue
		}

		for _, found := range ExtractAnnotations(comment) {
			if wanted[found] {
				roots[qualifiedName(pkgName, name)] = true
				break
			}
		}
	}
	return roots
}

// IncludeMessagesByAnnotation prunes top-level messages and enums from def
// that neither carry one of the given include annotations nor are reachable
// (directly or transitively) from a kept definition. Request and response
// types of RPCs in services still present in def are also treated as kept.
// Elements that are not messages or enums are left untouched.
//
// It returns the number of messages and enums removed; an empty annotations
// list removes nothing and returns 0.
func IncludeMessagesByAnnotation(def *proto.Proto, annotations []string) int {
	if len(annotations) == 0 {
		return 0
	}

	wanted := make(map[string]bool, len(annotations))
	for _, name := range annotations {
		wanted[name] = true
	}

	// The first package declaration supplies the qualification prefix.
	var pkgName string
	for _, e := range def.Elements {
		if decl, isPkg := e.(*proto.Package); isPkg {
			pkgName = decl.Name
			break
		}
	}

	// hasWanted reports whether a comment carries any requested annotation.
	hasWanted := func(c *proto.Comment) bool {
		for _, found := range ExtractAnnotations(c) {
			if wanted[found] {
				return true
			}
		}
		return false
	}

	// Step 1: seed the keep set with annotated messages/enums and with every
	// type named by an RPC of a service that is still in the AST (service
	// filtering happens elsewhere, so any service seen here is retained).
	keep := make(map[string]bool)
	for _, e := range def.Elements {
		switch decl := e.(type) {
		case *proto.Message:
			if hasWanted(decl.Comment) {
				keep[qualifiedName(pkgName, decl.Name)] = true
			}
		case *proto.Enum:
			if hasWanted(decl.Comment) {
				keep[qualifiedName(pkgName, decl.Name)] = true
			}
		case *proto.Service:
			for _, member := range decl.Elements {
				rpc, isRPC := member.(*proto.RPC)
				if !isRPC {
					continue
				}
				for _, typeName := range [2]string{rpc.RequestType, rpc.ReturnsType} {
					if typeName != "" {
						keep[qualifiedName(pkgName, typeName)] = true
					}
				}
			}
		}
	}

	// Step 2: expand the keep set until a pass over the kept messages
	// discovers no reference that is not already kept.
	for grew := true; grew; {
		grew = false
		found := make(map[string]bool)
		for _, e := range def.Elements {
			msg, isMsg := e.(*proto.Message)
			if !isMsg || !keep[qualifiedName(pkgName, msg.Name)] {
				continue
			}
			collectMessageRefs(found, pkgName, msg)
		}
		for ref := range found {
			if !keep[ref] {
				keep[ref] = true
				grew = true
			}
		}
	}

	// Step 3: rebuild the element list, dropping messages/enums outside the
	// keep set and counting them.
	survivors := make([]proto.Visitee, 0, len(def.Elements))
	removed := 0
	for _, e := range def.Elements {
		var name string
		switch decl := e.(type) {
		case *proto.Message:
			name = decl.Name
		case *proto.Enum:
			name = decl.Name
		default:
			survivors = append(survivors, e)
			continue
		}
		if !keep[qualifiedName(pkgName, name)] {
			removed++
			continue
		}
		survivors = append(survivors, e)
	}
	def.Elements = survivors
	return removed
}

// FilterFieldsByAnnotation removes individual message fields from the proto
// AST whose comments contain any of the specified annotations. Handles
// NormalField, MapField, and OneOfField (within Oneof containers). Also
Expand Down Expand Up @@ -519,11 +678,19 @@ func isUserType(typeName string) bool {

// RemoveOrphanedDefinitions iteratively removes messages and enums
// that are no longer referenced by any remaining RPC method or message.
// Optional pinned FQNs are always kept (treated as roots).
// Returns the total count of removed definitions.
func RemoveOrphanedDefinitions(def *proto.Proto, pkg string) int {
func RemoveOrphanedDefinitions(def *proto.Proto, pkg string, pinned ...map[string]bool) int {
var pinnedSet map[string]bool
if len(pinned) > 0 {
pinnedSet = pinned[0]
}
totalRemoved := 0
for {
refs := CollectReferencedTypes(def, pkg)
for fqn := range pinnedSet {
refs[fqn] = true
}

filtered := make([]proto.Visitee, 0, len(def.Elements))
removed := 0
Expand Down
96 changes: 96 additions & 0 deletions internal/filter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2989,6 +2989,102 @@ func TestCombinedSameAnnotationInBothLists(t *testing.T) {
}
}

// --- Include Message Filtering Tests (Feature 013) ---

// T003: Basic include message filtering. An annotated message and an
// annotated enum must survive, while the unannotated message is removed.
func TestIncludeMessagesByAnnotation(t *testing.T) {
	// published builds a fresh comment carrying the [PublishedApi] annotation.
	published := func() *proto.Comment {
		return &proto.Comment{
			Lines: []string{" [PublishedApi]"},
		}
	}

	def := &proto.Proto{
		Elements: []proto.Visitee{
			&proto.Message{Name: "AnnotatedMessage", Comment: published()},
			&proto.Message{Name: "UnannotatedMessage"},
			&proto.Enum{Name: "AnnotatedEnum", Comment: published()},
		},
	}

	if removed := IncludeMessagesByAnnotation(def, []string{"PublishedApi"}); removed != 1 {
		t.Errorf("expected 1 removed (UnannotatedMessage), got %d", removed)
	}

	// Index the surviving message/enum names for the membership checks below.
	remaining := make(map[string]bool)
	for _, e := range def.Elements {
		switch decl := e.(type) {
		case *proto.Message:
			remaining[decl.Name] = true
		case *proto.Enum:
			remaining[decl.Name] = true
		}
	}

	if !remaining["AnnotatedMessage"] {
		t.Error("AnnotatedMessage should remain (has [PublishedApi])")
	}
	if remaining["UnannotatedMessage"] {
		t.Error("UnannotatedMessage should be removed (no annotations)")
	}
	if !remaining["AnnotatedEnum"] {
		t.Error("AnnotatedEnum should remain (has [PublishedApi])")
	}
}

// T004: All messages removed when none match. Neither message carries an
// annotation, so filtering by PublishedApi must drop both.
func TestIncludeMessagesByAnnotation_NoAnnotations(t *testing.T) {
	def := &proto.Proto{
		Elements: []proto.Visitee{
			&proto.Message{Name: "Msg1"},
			&proto.Message{Name: "Msg2"},
		},
	}

	if removed := IncludeMessagesByAnnotation(def, []string{"PublishedApi"}); removed != 2 {
		t.Errorf("expected 2 removed, got %d", removed)
	}

	// Any message still present is a failure, reported once per survivor.
	for _, e := range def.Elements {
		_, isMsg := e.(*proto.Message)
		if !isMsg {
			continue
		}
		t.Error("no messages should remain")
	}
}

// T005: Empty annotation list returns 0 and leaves every message in place.
func TestIncludeMessagesByAnnotation_EmptyList(t *testing.T) {
	def := &proto.Proto{
		Elements: []proto.Visitee{
			&proto.Message{Name: "Msg1"},
			&proto.Message{Name: "Msg2"},
		},
	}

	if removed := IncludeMessagesByAnnotation(def, []string{}); removed != 0 {
		t.Errorf("expected 0 removed for empty annotation list, got %d", removed)
	}

	// Count the messages that are still in the AST after the no-op call.
	var surviving int
	for _, e := range def.Elements {
		switch e.(type) {
		case *proto.Message:
			surviving++
		}
	}
	if surviving != 2 {
		t.Errorf("expected 2 messages to remain, got %d", surviving)
	}
}

func testdataDir(t *testing.T, sub string) string {
t.Helper()
dir, err := filepath.Abs(filepath.Join("..", "..", "testdata", sub))
Expand Down
13 changes: 9 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ func run() int {

// Pass 1: Prune, filter, convert block comments, collect annotations
type processedFile struct {
pf parsedFile
pf parsedFile
skip bool // true if file has no remaining definitions after filtering
}
processed := make([]processedFile, 0, len(parsed))
servicesRemoved := 0
messagesRemoved := 0
methodsRemoved := 0
fieldsRemoved := 0
orphansRemoved := 0
Expand All @@ -190,9 +191,12 @@ func run() int {
skip := false
// Annotation-based filtering
if cfg != nil && cfg.HasAnnotations() {
var sr, mr, fr int
var sr, msgr, mr, fr int
var includeRoots map[string]bool
if cfg.HasAnnotationInclude() {
sr += filter.IncludeServicesByAnnotation(pf.def, cfg.Annotations.Include)
includeRoots = filter.CollectIncludeMessageRoots(pf.def, cfg.Annotations.Include)
msgr += filter.IncludeMessagesByAnnotation(pf.def, cfg.Annotations.Include)
if !cfg.HasAnnotationExclude() {
// Include-only mode: also filter methods by include annotations
mr += filter.IncludeMethodsByAnnotation(pf.def, cfg.Annotations.Include)
Expand All @@ -204,11 +208,12 @@ func run() int {
fr += filter.FilterFieldsByAnnotation(pf.def, cfg.Annotations.Exclude)
}
servicesRemoved += sr
messagesRemoved += msgr
methodsRemoved += mr
fieldsRemoved += fr
filter.RemoveEmptyServices(pf.def)
if sr > 0 || mr > 0 || fr > 0 {
orphansRemoved += filter.RemoveOrphanedDefinitions(pf.def, pf.pkg)
orphansRemoved += filter.RemoveOrphanedDefinitions(pf.def, pf.pkg, includeRoots)
}

if !filter.HasRemainingDefinitions(pf.def) {
Expand Down Expand Up @@ -290,7 +295,7 @@ func run() int {
fmt.Fprintf(os.Stderr, "proto-filter: processed %d files, %d definitions\n", len(files), totalDefs)
fmt.Fprintf(os.Stderr, "proto-filter: included %d definitions, excluded %d\n", includedCount, excludedCount)
if cfg != nil && cfg.HasAnnotations() {
fmt.Fprintf(os.Stderr, "proto-filter: removed %d services by annotation, %d methods by annotation, %d fields by annotation, %d orphaned definitions\n", servicesRemoved, methodsRemoved, fieldsRemoved, orphansRemoved)
fmt.Fprintf(os.Stderr, "proto-filter: removed %d services by annotation, %d messages by annotation, %d methods by annotation, %d fields by annotation, %d orphaned definitions\n", servicesRemoved, messagesRemoved, methodsRemoved, fieldsRemoved, orphansRemoved)
}
if cfg != nil && cfg.HasSubstitutions() {
fmt.Fprintf(os.Stderr, "proto-filter: substituted %d annotations\n", substitutionCount)
Expand Down
Loading