diff --git a/README.md b/README.md index ca5e5ac..9bbbb57 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ An enum generator for Go that creates type-safe enumerations with useful methods ```go package main - + // ENUM(red, green, blue) type Color int ``` @@ -67,7 +67,7 @@ An enum generator for Go that creates type-safe enumerations with useful methods ```shell # Go 1.24+ (recommended) go tool go-enum -f your_file.go - + # Or for older Go versions go-enum -f your_file.go ``` @@ -77,7 +77,7 @@ An enum generator for Go that creates type-safe enumerations with useful methods ```go color := ColorRed fmt.Println(color.String()) // prints "red" - + parsed, err := ParseColor("green") if err == nil { fmt.Println(parsed) // prints "green" @@ -147,7 +147,7 @@ const ( ``` If you would like to get integer values in sql, but strings elsewhere, you can assign an int value in the declaration -like always, and specify the `--sqlint` flag. Those values will be then used to convey the int value to sql, while allowing you to use only strings elsewhere. +like always, and specify the `--sqlint` flag. Those values will be then used to convey the int value to sql, while allowing you to use only strings elsewhere. This might be helpful for things like swagger docs where you want the same type being used on the api layer, as you do in the sql layer, and not have swagger assume that your enumerations are integers, but are in fact strings! @@ -183,6 +183,62 @@ Change the default `_enum.go` suffix to something else: go-enum --output-suffix="_generated" -f your_file.go # Creates your_file_generated.go ``` +### Inline Annotations (v0.10.0+) + +You can now specify configuration options directly in the enum declaration using inline annotations. This allows you to override global command-line options on a per-enum basis. + +Annotations are specified as comments starting with `@` before the `ENUM` declaration: + +```go +// @marshal:true @sql:false @prefix:"My" +// ENUM(pending, running, completed, failed) +type AnnotationStatus string + +// @noprefix @nocase +// ENUM(annotation_red, annotation_green, annotation_blue) +type AnnotationColor string + +// @marshal @sql +// ENUM(one, two, three) +type AnnotationNumber int +``` + +**Available annotations:** + +| Annotation | Values | Description | +| ------------- | -------------- | -------------------------------------------------- | +| `@prefix` | `"string"` | Custom prefix for constants (e.g., `@prefix:"My"`) | +| `@marshal` | `true`/`false` | Enables/disables JSON/text marshaling methods | +| `@sql` | `true`/`false` | Enables/disables SQL Scan/Value methods | +| `@sqlint` | `true`/`false` | Stores string enums as integers in SQL | +| `@noprefix` | `true`/`false` | Disables prefixing constants with enum name | +| `@nocase` | `true`/`false` | Enables case-insensitive parsing | +| `@noparse` | `true`/`false` | Disables Parse method generation | +| `@mustparse` | `true`/`false` | Adds MustParse method that panics on failure | +| `@flag` | `true`/`false` | Adds flag.Value interface methods | +| `@ptr` | `true`/`false` | Adds Ptr() method | +| `@names` | `true`/`false` | Adds Names() []string method | +| `@values` | `true`/`false` | Adds Values() []Enum method | +| `@nocomments` | `true`/`false` | Disables auto-generated comments | +| `@noiota` | `true`/`false` | Disables iota usage | +| `@forcelower` | `true`/`false` | Forces lowercase constant names | +| `@forceupper` | `true`/`false` | Forces uppercase constant names | + +**Syntax notes:** + +- Boolean annotations can be specified as `@annotation` (defaults to `true`) or `@annotation:true`/`@annotation:false` +- String annotations use quotes: `@prefix:"My"` +- Multiple annotations can be specified on the same line or across multiple lines +- Inline annotations override global command-line options + +**Example with mixed annotations:** + +```go +// @marshal @sql:false @nocase @prefix:"App" +// ENUM(draft, review, published, archived) +type DocumentStatus string +``` + ## Goal The goal of go-enum is to create an easy to use enum generator that will take a decorated type declaration like `type EnumName int` and create the associated constant values and funcs that will make life a little easier for adding new values. @@ -302,7 +358,7 @@ For older Go versions: ## Command options -``` shell +```shell go-enum --help NAME: @@ -343,12 +399,13 @@ GLOBAL OPTIONS: --help, -h show help --version, -v print the version ``` +**Note:** Many command-line options can also be specified as inline annotations directly in your enum declarations. See the [Inline Annotations](#inline-annotations-v0100) section for details. ### Syntax The parser looks for comments on your type defs and parse the enum declarations from it. -The parser will look for `ENUM(` and continue to look for comma separated values until it finds a `)`. You can put values on the same line, or on multiple lines.\ -If you need to have a specific value jump in the enum, you can now specify that by adding `=numericValue` to the enum declaration. Keep in mind, this resets the data for all following values. So if you specify `50` in the middle of an enum, each value after that will be `51, 52, 53...` +The parser will look for `ENUM(` and continue to look for comma separated values until it finds a `)`. You can put values on the same line, or on multiple lines.\ +If you need to have a specific value jump in the enum, you can now specify that by adding `=numericValue` to the enum declaration. Keep in mind, this resets the data for all following values. So if you specify `50` in the middle of an enum, each value after that will be `51, 52, 53...` [Examples can be found in the example folder](./example/) @@ -391,7 +448,7 @@ const ( There are a few examples in the `example` [directory](./example/). I've included one here for easy access, but can't guarantee it's up to date. -``` go +```go // Color is an enumeration of colors that are allowed. /* ENUM( Black, White, Red @@ -410,7 +467,7 @@ type Color int32 The generated code will look something like: -``` go +```go // Code generated by go-enum DO NOT EDIT. // Version: example // Revision: example diff --git a/example/annotation.go b/example/annotation.go new file mode 100644 index 0000000..75b9e77 --- /dev/null +++ b/example/annotation.go @@ -0,0 +1,15 @@ +//go:generate ../bin/go-enum -b example + +package example + +// @marshal:true @sql:false @prefix:"My" +// ENUM(pending, running, completed, failed) +type AnnotationStatus string + +// @noprefix @nocase +// ENUM(annotation_red, annotation_green, annotation_blue) +type AnnotationColor string + +// @marshal @sql @marshal +// ENUM(one, two, three) +type AnnotationNumber int diff --git a/example/annotation_enum.go b/example/annotation_enum.go new file mode 100644 index 0000000..3291dfd --- /dev/null +++ b/example/annotation_enum.go @@ -0,0 +1,266 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: example +// Revision: example +// Build Date: example +// Built By: example + +//go:build example +// +build example + +package example + +import ( + "database/sql/driver" + "errors" + "fmt" + "strings" +) + +const ( + // AnnotationRed is a AnnotationColor of type annotation_red. + AnnotationRed AnnotationColor = "annotation_red" + // AnnotationGreen is a AnnotationColor of type annotation_green. + AnnotationGreen AnnotationColor = "annotation_green" + // AnnotationBlue is a AnnotationColor of type annotation_blue. + AnnotationBlue AnnotationColor = "annotation_blue" +) + +var ErrInvalidAnnotationColor = errors.New("not a valid AnnotationColor") + +// String implements the Stringer interface. +func (x AnnotationColor) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x AnnotationColor) IsValid() bool { + _, err := ParseAnnotationColor(string(x)) + return err == nil +} + +var _AnnotationColorValue = map[string]AnnotationColor{ + "annotation_red": AnnotationRed, + "annotation_green": AnnotationGreen, + "annotation_blue": AnnotationBlue, +} + +// ParseAnnotationColor attempts to convert a string to a AnnotationColor. +func ParseAnnotationColor(name string) (AnnotationColor, error) { + if x, ok := _AnnotationColorValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _AnnotationColorValue[strings.ToLower(name)]; ok { + return x, nil + } + return AnnotationColor(""), fmt.Errorf("%s is %w", name, ErrInvalidAnnotationColor) +} + +const ( + // AnnotationNumberOne is a AnnotationNumber of type One. + AnnotationNumberOne AnnotationNumber = iota + // AnnotationNumberTwo is a AnnotationNumber of type Two. + AnnotationNumberTwo + // AnnotationNumberThree is a AnnotationNumber of type Three. + AnnotationNumberThree +) + +var ErrInvalidAnnotationNumber = errors.New("not a valid AnnotationNumber") + +const _AnnotationNumberName = "onetwothree" + +var _AnnotationNumberMap = map[AnnotationNumber]string{ + AnnotationNumberOne: _AnnotationNumberName[0:3], + AnnotationNumberTwo: _AnnotationNumberName[3:6], + AnnotationNumberThree: _AnnotationNumberName[6:11], +} + +// String implements the Stringer interface. +func (x AnnotationNumber) String() string { + if str, ok := _AnnotationNumberMap[x]; ok { + return str + } + return fmt.Sprintf("AnnotationNumber(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x AnnotationNumber) IsValid() bool { + _, ok := _AnnotationNumberMap[x] + return ok +} + +var _AnnotationNumberValue = map[string]AnnotationNumber{ + _AnnotationNumberName[0:3]: AnnotationNumberOne, + _AnnotationNumberName[3:6]: AnnotationNumberTwo, + _AnnotationNumberName[6:11]: AnnotationNumberThree, +} + +// ParseAnnotationNumber attempts to convert a string to a AnnotationNumber. +func ParseAnnotationNumber(name string) (AnnotationNumber, error) { + if x, ok := _AnnotationNumberValue[name]; ok { + return x, nil + } + return AnnotationNumber(0), fmt.Errorf("%s is %w", name, ErrInvalidAnnotationNumber) +} + +// MarshalText implements the text marshaller method. +func (x AnnotationNumber) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *AnnotationNumber) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseAnnotationNumber(name) + if err != nil { + return err + } + *x = tmp + return nil +} + +// AppendText appends the textual representation of itself to the end of b +// (allocating a larger slice if necessary) and returns the updated slice. +// +// Implementations must not retain b, nor mutate any bytes within b[:len(b)]. +func (x *AnnotationNumber) AppendText(b []byte) ([]byte, error) { + return append(b, x.String()...), nil +} + +var errAnnotationNumberNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *AnnotationNumber) Scan(value interface{}) (err error) { + if value == nil { + *x = AnnotationNumber(0) + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x = AnnotationNumber(v) + case string: + *x, err = ParseAnnotationNumber(v) + case []byte: + *x, err = ParseAnnotationNumber(string(v)) + case AnnotationNumber: + *x = v + case int: + *x = AnnotationNumber(v) + case *AnnotationNumber: + if v == nil { + return errAnnotationNumberNilPtr + } + *x = *v + case uint: + *x = AnnotationNumber(v) + case uint64: + *x = AnnotationNumber(v) + case *int: + if v == nil { + return errAnnotationNumberNilPtr + } + *x = AnnotationNumber(*v) + case *int64: + if v == nil { + return errAnnotationNumberNilPtr + } + *x = AnnotationNumber(*v) + case float64: // json marshals everything as a float64 if it's a number + *x = AnnotationNumber(v) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil { + return errAnnotationNumberNilPtr + } + *x = AnnotationNumber(*v) + case *uint: + if v == nil { + return errAnnotationNumberNilPtr + } + *x = AnnotationNumber(*v) + case *uint64: + if v == nil { + return errAnnotationNumberNilPtr + } + *x = AnnotationNumber(*v) + case *string: + if v == nil { + return errAnnotationNumberNilPtr + } + *x, err = ParseAnnotationNumber(*v) + } + + return +} + +// Value implements the driver Valuer interface. +func (x AnnotationNumber) Value() (driver.Value, error) { + return x.String(), nil +} + +const ( + // MyAnnotationStatusPending is a AnnotationStatus of type pending. + MyAnnotationStatusPending AnnotationStatus = "pending" + // MyAnnotationStatusRunning is a AnnotationStatus of type running. + MyAnnotationStatusRunning AnnotationStatus = "running" + // MyAnnotationStatusCompleted is a AnnotationStatus of type completed. + MyAnnotationStatusCompleted AnnotationStatus = "completed" + // MyAnnotationStatusFailed is a AnnotationStatus of type failed. + MyAnnotationStatusFailed AnnotationStatus = "failed" +) + +var ErrInvalidAnnotationStatus = errors.New("not a valid AnnotationStatus") + +// String implements the Stringer interface. +func (x AnnotationStatus) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x AnnotationStatus) IsValid() bool { + _, err := ParseAnnotationStatus(string(x)) + return err == nil +} + +var _AnnotationStatusValue = map[string]AnnotationStatus{ + "pending": MyAnnotationStatusPending, + "running": MyAnnotationStatusRunning, + "completed": MyAnnotationStatusCompleted, + "failed": MyAnnotationStatusFailed, +} + +// ParseAnnotationStatus attempts to convert a string to a AnnotationStatus. +func ParseAnnotationStatus(name string) (AnnotationStatus, error) { + if x, ok := _AnnotationStatusValue[name]; ok { + return x, nil + } + return AnnotationStatus(""), fmt.Errorf("%s is %w", name, ErrInvalidAnnotationStatus) +} + +// MarshalText implements the text marshaller method. +func (x AnnotationStatus) MarshalText() ([]byte, error) { + return []byte(string(x)), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *AnnotationStatus) UnmarshalText(text []byte) error { + tmp, err := ParseAnnotationStatus(string(text)) + if err != nil { + return err + } + *x = tmp + return nil +} + +// AppendText appends the textual representation of itself to the end of b +// (allocating a larger slice if necessary) and returns the updated slice. +// +// Implementations must not retain b, nor mutate any bytes within b[:len(b)]. +func (x *AnnotationStatus) AppendText(b []byte) ([]byte, error) { + return append(b, x.String()...), nil +} diff --git a/example/annotation_test.go b/example/annotation_test.go new file mode 100644 index 0000000..2a48716 --- /dev/null +++ b/example/annotation_test.go @@ -0,0 +1,218 @@ +//go:build example +// +build example + +package example + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type annotationTestData struct { + Status AnnotationStatus `json:"status"` + Color AnnotationColor `json:"color"` + Number AnnotationNumber `json:"number"` +} + +func TestAnnotationStatus(t *testing.T) { + // Test prefix "My" was applied + assert.Equal(t, MyAnnotationStatusPending, AnnotationStatus("pending")) + assert.Equal(t, MyAnnotationStatusRunning, AnnotationStatus("running")) + assert.Equal(t, MyAnnotationStatusCompleted, AnnotationStatus("completed")) + assert.Equal(t, MyAnnotationStatusFailed, AnnotationStatus("failed")) + + // Test String() + assert.Equal(t, "pending", MyAnnotationStatusPending.String()) + assert.Equal(t, "running", MyAnnotationStatusRunning.String()) + + // Test IsValid() + assert.True(t, MyAnnotationStatusPending.IsValid()) + assert.True(t, MyAnnotationStatusRunning.IsValid()) + assert.False(t, AnnotationStatus("invalid").IsValid()) + + // Test Parse + parsed, err := ParseAnnotationStatus("pending") + assert.NoError(t, err) + assert.Equal(t, MyAnnotationStatusPending, parsed) + + _, err = ParseAnnotationStatus("invalid") + assert.Error(t, err) + assert.Equal(t, "invalid is not a valid AnnotationStatus", err.Error()) + + // Test Marshal/Unmarshal + jsonData := `{"status":"pending"}` + var data struct { + Status AnnotationStatus `json:"status"` + } + err = json.Unmarshal([]byte(jsonData), &data) + assert.NoError(t, err) + assert.Equal(t, MyAnnotationStatusPending, data.Status) + + marshaled, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, jsonData, string(marshaled)) + + // Test AppendText (method has pointer receiver) + status := MyAnnotationStatusPending + text, err := status.AppendText(nil) + assert.NoError(t, err) + assert.Equal(t, "pending", string(text)) +} + +func TestAnnotationColor(t *testing.T) { + // Test noprefix - no "AnnotationColor" prefix + assert.Equal(t, AnnotationRed, AnnotationColor("annotation_red")) + assert.Equal(t, AnnotationGreen, AnnotationColor("annotation_green")) + assert.Equal(t, AnnotationBlue, AnnotationColor("annotation_blue")) + + // Test nocase - case insensitive parsing + parsed, err := ParseAnnotationColor("ANNOTATION_RED") + assert.NoError(t, err) + assert.Equal(t, AnnotationRed, parsed) + + parsed, err = ParseAnnotationColor("annotation_red") + assert.NoError(t, err) + assert.Equal(t, AnnotationRed, parsed) + + parsed, err = ParseAnnotationColor("AnNoTaTiOn_ReD") + assert.NoError(t, err) + assert.Equal(t, AnnotationRed, parsed) + + // Test invalid + _, err = ParseAnnotationColor("invalid") + assert.Error(t, err) + assert.Equal(t, "invalid is not a valid AnnotationColor", err.Error()) + + // Test String() + assert.Equal(t, "annotation_red", AnnotationRed.String()) + + // Test IsValid() + assert.True(t, AnnotationRed.IsValid()) + assert.False(t, AnnotationColor("invalid").IsValid()) + + // Note: No marshal methods for AnnotationColor (not specified) +} + +func TestAnnotationNumber(t *testing.T) { + // Test constants + assert.Equal(t, AnnotationNumberOne, AnnotationNumber(0)) + assert.Equal(t, AnnotationNumberTwo, AnnotationNumber(1)) + assert.Equal(t, AnnotationNumberThree, AnnotationNumber(2)) + + // Test String() + assert.Equal(t, "one", AnnotationNumberOne.String()) + assert.Equal(t, "two", AnnotationNumberTwo.String()) + assert.Equal(t, "three", AnnotationNumberThree.String()) + + // Test IsValid() + assert.True(t, AnnotationNumberOne.IsValid()) + assert.False(t, AnnotationNumber(999).IsValid()) + + // Test Parse + parsed, err := ParseAnnotationNumber("one") + assert.NoError(t, err) + assert.Equal(t, AnnotationNumberOne, parsed) + + _, err = ParseAnnotationNumber("invalid") + assert.Error(t, err) + assert.Equal(t, "invalid is not a valid AnnotationNumber", err.Error()) + + // Test Marshal/Unmarshal + jsonData := `{"number":"one"}` + var data struct { + Number AnnotationNumber `json:"number"` + } + err = json.Unmarshal([]byte(jsonData), &data) + assert.NoError(t, err) + assert.Equal(t, AnnotationNumberOne, data.Number) + + marshaled, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, jsonData, string(marshaled)) + + // Test SQL Scan/Value (basic test) + var numScan AnnotationNumber + err = numScan.Scan("one") + assert.NoError(t, err) + assert.Equal(t, AnnotationNumberOne, numScan) + + val, err := AnnotationNumberOne.Value() + assert.NoError(t, err) + assert.Equal(t, "one", val) + + // Test AppendText (method has pointer receiver) + numAppend := AnnotationNumberOne + text, err := numAppend.AppendText(nil) + assert.NoError(t, err) + assert.Equal(t, "one", string(text)) +} + +func TestAnnotationSQL(t *testing.T) { + // Test AnnotationNumber SQL (enabled) + var num AnnotationNumber + + // Scan from string + err := num.Scan("two") + assert.NoError(t, err) + assert.Equal(t, AnnotationNumberTwo, num) + + // Scan from int + err = num.Scan(1) + assert.NoError(t, err) + assert.Equal(t, AnnotationNumberTwo, num) + + // Value returns string + val, err := num.Value() + assert.NoError(t, err) + assert.Equal(t, "two", val) + + // Test AnnotationStatus SQL (disabled - should not have Scan/Value methods) + // We can't test absence directly, but we can verify that the type doesn't implement + // driver.Valuer and sql.Scanner for AnnotationStatus (they're not generated) +} + +func TestAnnotationMarshalCombined(t *testing.T) { + // Test all three together + jsonData := `{"status":"completed","color":"annotation_green","number":"three"}` + var data annotationTestData + err := json.Unmarshal([]byte(jsonData), &data) + assert.NoError(t, err) + assert.Equal(t, MyAnnotationStatusCompleted, data.Status) + assert.Equal(t, AnnotationGreen, data.Color) + assert.Equal(t, AnnotationNumberThree, data.Number) + + marshaled, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, jsonData, string(marshaled)) +} + +func BenchmarkAnnotationParse(b *testing.B) { + knownItems := []string{ + "pending", + "annotation_red", + "one", + } + + var err error + for _, item := range knownItems { + b.Run(item, func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Try to parse with appropriate parser + switch { + case strings.Contains(item, "annotation_"): + _, err = ParseAnnotationColor(item) + case item == "one" || item == "two" || item == "three": + _, err = ParseAnnotationNumber(item) + default: + _, err = ParseAnnotationStatus(item) + } + assert.NoError(b, err) + } + }) + } +} diff --git a/generator/enum_config.go b/generator/enum_config.go new file mode 100644 index 0000000..dc2d288 --- /dev/null +++ b/generator/enum_config.go @@ -0,0 +1,188 @@ +package generator + +import ( + "fmt" + "strconv" + "strings" +) + +// EnumConfigValue holds a configuration value with its validity flag. +type EnumConfigValue struct { + Value interface{} + Valid bool +} + +// GetBool returns the boolean value if valid, otherwise returns the default value. +func (v *EnumConfigValue) GetBool(defaultValue bool) bool { + if v.Valid { + if b, ok := v.Value.(bool); ok { + return b + } + } + return defaultValue +} + +// GetString returns the string value if valid, otherwise returns the default value. +func (v *EnumConfigValue) GetString(defaultValue string) string { + if v.Valid { + if s, ok := v.Value.(string); ok { + return s + } + } + return defaultValue +} + + +// EnumConfig holds configuration options specific to a single enum. +// These options can be specified inline via annotations and override global GeneratorConfig. +type EnumConfig struct { + // Bool options + NoPrefix EnumConfigValue `json:"no_prefix"` + NoIota EnumConfigValue `json:"no_iota"` + LowercaseLookup EnumConfigValue `json:"lowercase_lookup"` + CaseInsensitive EnumConfigValue `json:"case_insensitive"` + Marshal EnumConfigValue `json:"marshal"` + SQL EnumConfigValue `json:"sql"` + SQLInt EnumConfigValue `json:"sql_int"` + Flag EnumConfigValue `json:"flag"` + Names EnumConfigValue `json:"names"` + Values EnumConfigValue `json:"values"` + LeaveSnakeCase EnumConfigValue `json:"leave_snake_case"` + Ptr EnumConfigValue `json:"ptr"` + SQLNullInt EnumConfigValue `json:"sql_null_int"` + SQLNullStr EnumConfigValue `json:"sql_null_str"` + MustParse EnumConfigValue `json:"must_parse"` + ForceLower EnumConfigValue `json:"force_lower"` + ForceUpper EnumConfigValue `json:"force_upper"` + NoComments EnumConfigValue `json:"no_comments"` + NoParse EnumConfigValue `json:"no_parse"` + + // String options + Prefix EnumConfigValue `json:"prefix"` + + // Slice/map options (not supported inline for simplicity) + // BuildTags []string + // ReplacementNames map[string]string + // TemplateFileNames []string +} + +// NewEnumConfig creates a new EnumConfig with default values. +func NewEnumConfig() *EnumConfig { + return &EnumConfig{} +} + +// ParseAnnotation parses a single annotation string (e.g., "@marshal", "@marshal:true", "@prefix=\"My\"") +// and updates the EnumConfig accordingly. +func (ec *EnumConfig) ParseAnnotation(annotation string) error { + annotation = strings.TrimSpace(annotation) + if annotation == "" { + return nil + } + + // Remove @ prefix + if !strings.HasPrefix(annotation, "@") { + return fmt.Errorf("annotation must start with @: %s", annotation) + } + annotation = annotation[1:] + + // Check for key:value format (e.g., @marshal:true, @marshal:false) + if strings.Contains(annotation, ":") { + parts := strings.SplitN(annotation, ":", 2) + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Parse boolean value + if value == "true" || value == "false" { + boolValue, _ := strconv.ParseBool(value) + return ec.setBoolOption(key, boolValue) + } + + // String value (could be quoted) + if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'')) { + value = value[1 : len(value)-1] + } + + return ec.setStringOption(key, value) + } + + // Check for key=value format (legacy style, e.g., @prefix="My") + if strings.Contains(annotation, "=") { + parts := strings.SplitN(annotation, "=", 2) + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Remove quotes if present + if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'')) { + value = value[1 : len(value)-1] + } + + return ec.setStringOption(key, value) + } + + // Boolean flag without explicit value (defaults to true) + return ec.setBoolOption(annotation, true) +} + +// setBoolOption sets a boolean option in the EnumConfig. +func (ec *EnumConfig) setBoolOption(key string, value bool) error { + switch key { + case "noprefix": + ec.NoPrefix = EnumConfigValue{Value: value, Valid: true} + case "noiota": + ec.NoIota = EnumConfigValue{Value: value, Valid: true} + case "lower": + ec.LowercaseLookup = EnumConfigValue{Value: value, Valid: true} + case "nocase": + ec.CaseInsensitive = EnumConfigValue{Value: value, Valid: true} + if value { + ec.LowercaseLookup = EnumConfigValue{Value: true, Valid: true} // nocase forces lower + } + case "marshal": + ec.Marshal = EnumConfigValue{Value: value, Valid: true} + case "sql": + ec.SQL = EnumConfigValue{Value: value, Valid: true} + case "sqlint": + ec.SQLInt = EnumConfigValue{Value: value, Valid: true} + case "flag": + ec.Flag = EnumConfigValue{Value: value, Valid: true} + case "names": + ec.Names = EnumConfigValue{Value: value, Valid: true} + case "values": + ec.Values = EnumConfigValue{Value: value, Valid: true} + case "nocamel": + ec.LeaveSnakeCase = EnumConfigValue{Value: value, Valid: true} + case "ptr": + ec.Ptr = EnumConfigValue{Value: value, Valid: true} + case "sqlnullint": + ec.SQLNullInt = EnumConfigValue{Value: value, Valid: true} + case "sqlnullstr": + ec.SQLNullStr = EnumConfigValue{Value: value, Valid: true} + case "mustparse": + ec.MustParse = EnumConfigValue{Value: value, Valid: true} + case "forcelower": + ec.ForceLower = EnumConfigValue{Value: value, Valid: true} + case "forceupper": + ec.ForceUpper = EnumConfigValue{Value: value, Valid: true} + case "nocomments": + ec.NoComments = EnumConfigValue{Value: value, Valid: true} + case "noparse": + ec.NoParse = EnumConfigValue{Value: value, Valid: true} + default: + return fmt.Errorf("unknown annotation: @%s", key) + } + + return nil +} + +// setStringOption sets a string option in the EnumConfig. +func (ec *EnumConfig) setStringOption(key, value string) error { + switch key { + case "prefix": + ec.Prefix = EnumConfigValue{Value: value, Valid: true} + default: + return fmt.Errorf("unknown annotation with value: @%s=%s", key, value) + } + return nil +} diff --git a/generator/generator.go b/generator/generator.go index cac315a..7999591 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -45,6 +45,7 @@ type Enum struct { Type string Values []EnumValue Comment string + Config *EnumConfig } // EnumValue holds the individual data for each enum value within the found enum. @@ -194,39 +195,45 @@ func (g *Generator) Generate(f *ast.File) ([]byte, error) { created++ + // Use enum-specific config if available, otherwise fall back to global config + config := enum.Config + // Determine parse method generation logic - parseNeeded := g.MustParse || g.Marshal || g.anySQLEnabled() || g.Flag - generateParse := !g.NoParse || parseNeeded - parseIsPublic := !g.NoParse + parseNeeded := config.MustParse.GetBool(g.MustParse) || config.Marshal.GetBool(g.Marshal) || + (config.SQL.GetBool(g.SQL) || config.SQLInt.GetBool(g.SQLInt) || + config.SQLNullStr.GetBool(g.SQLNullStr) || config.SQLNullInt.GetBool(g.SQLNullInt)) || + config.Flag.GetBool(g.Flag) + generateParse := !config.NoParse.GetBool(g.NoParse) || parseNeeded + parseIsPublic := !config.NoParse.GetBool(g.NoParse) parseName := "Parse" if !parseIsPublic && generateParse { parseName = "parse" } // Determine if error variable is needed - generateError := generateParse || (enum.Type == "string" && g.SQLInt) + generateError := generateParse || (enum.Type == "string" && config.SQLInt.GetBool(g.SQLInt)) data := map[string]any{ "enum": enum, "name": name, - "lowercase": g.LowercaseLookup, - "nocase": g.CaseInsensitive, - "nocomments": g.NoComments, - "noIota": g.NoIota, - "marshal": g.Marshal, - "sql": g.SQL, - "sqlint": g.SQLInt, - "flag": g.Flag, - "names": g.Names, - "ptr": g.Ptr, - "values": g.Values, - "anySQLEnabled": g.anySQLEnabled(), - "sqlnullint": g.SQLNullInt, - "sqlnullstr": g.SQLNullStr, - "mustparse": g.MustParse, - "forcelower": g.ForceLower, - "forceupper": g.ForceUpper, - "noparse": g.NoParse, + "lowercase": config.LowercaseLookup.GetBool(g.LowercaseLookup), + "nocase": config.CaseInsensitive.GetBool(g.CaseInsensitive), + "nocomments": config.NoComments.GetBool(g.NoComments), + "noIota": config.NoIota.GetBool(g.NoIota), + "marshal": config.Marshal.GetBool(g.Marshal), + "sql": config.SQL.GetBool(g.SQL), + "sqlint": config.SQLInt.GetBool(g.SQLInt), + "flag": config.Flag.GetBool(g.Flag), + "names": config.Names.GetBool(g.Names), + "ptr": config.Ptr.GetBool(g.Ptr), + "values": config.Values.GetBool(g.Values), + "anySQLEnabled": config.SQL.GetBool(g.SQL) || config.SQLInt.GetBool(g.SQLInt) || config.SQLNullStr.GetBool(g.SQLNullStr) || config.SQLNullInt.GetBool(g.SQLNullInt), + "sqlnullint": config.SQLNullInt.GetBool(g.SQLNullInt), + "sqlnullstr": config.SQLNullStr.GetBool(g.SQLNullStr), + "mustparse": config.MustParse.GetBool(g.MustParse), + "forcelower": config.ForceLower.GetBool(g.ForceLower), + "forceupper": config.ForceUpper.GetBool(g.ForceUpper), + "noparse": config.NoParse.GetBool(g.NoParse), // Computed values for cleaner templates "generateParse": generateParse, "parseIsPublic": parseIsPublic, @@ -284,21 +291,42 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { return nil, errors.New("no doc on enum") } - enum := &Enum{} + enum := &Enum{ + Config: NewEnumConfig(), + } enum.Name = ts.Name.Name enum.Type = fmt.Sprintf("%s", ts.Type) - if !g.NoPrefix { + + // Extract annotations and enum declaration + annotations, enumDecl := extractAnnotationsAndEnumDecl(ts.Doc.List) + + // Parse annotations + for _, annotation := range annotations { + if err := enum.Config.ParseAnnotation(annotation); err != nil { + fmt.Printf("Warning: failed to parse annotation %q: %v\n", annotation, err) + } + } + + // Determine prefix based on config (local overrides global) + noPrefix := enum.Config.NoPrefix.GetBool(g.NoPrefix) + if !noPrefix { enum.Prefix = ts.Name.Name } + + // Apply global prefix if set if g.Prefix != "" { enum.Prefix = g.Prefix + enum.Prefix } + + // Apply annotation prefix if set (overrides everything) + if prefix := enum.Config.Prefix.GetString(""); prefix != "" { + enum.Prefix = prefix + ts.Name.Name + } commentPreEnumDecl, _, _ := strings.Cut(ts.Doc.Text(), `ENUM(`) enum.Comment = strings.TrimSpace(commentPreEnumDecl) - enumDecl := getEnumDeclFromComments(ts.Doc.List) if enumDecl == "" { return nil, errors.New("failed parsing enum") } @@ -657,3 +685,46 @@ func isTypeSpecEnum(ts *ast.TypeSpec) bool { return isEnum } + +// extractAnnotationsAndEnumDecl extracts annotations (lines starting with @) and the ENUM declaration +// from the comment list. Returns the annotations and the enum declaration string. +func extractAnnotationsAndEnumDecl(comments []*ast.Comment) ([]string, string) { + var annotations []string + var enumDecl string + + for _, comment := range comments { + lines := breakCommentIntoLines(comment) + for _, line := range lines { + trimmedLine := strings.TrimSpace(line) + + // Skip empty lines + if trimmedLine == "" { + continue + } + + // Check if this line contains ENUM( + if strings.Contains(trimmedLine, "ENUM(") { + // Use the existing getEnumDeclFromComments function to get the full declaration + enumDecl = getEnumDeclFromComments(comments) + break + } + + // Check if this line contains annotations + if strings.Contains(trimmedLine, "@") { + // Split by whitespace to get individual annotations + // This handles cases like "@para1 @param2 @para3" + parts := strings.Fields(trimmedLine) + for _, part := range parts { + if strings.HasPrefix(part, "@") { + annotations = append(annotations, part) + } + } + } + } + if enumDecl != "" { + break + } + } + + return annotations, enumDecl +}