From f304faf7207783f11729e34b441ee67d92127c84 Mon Sep 17 00:00:00 2001 From: Patrick Hobusch Date: Mon, 18 Aug 2025 21:57:17 +0800 Subject: [PATCH] feat: allow env function for config.toml to have a default value --- pkg/config/config_test.go | 45 ++++++ pkg/config/decode_hooks.go | 12 +- pkg/config/decode_hooks_test.go | 245 ++++++++++++++++++++++++++++++++ 3 files changed, 299 insertions(+), 3 deletions(-) create mode 100644 pkg/config/decode_hooks_test.go diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index fca80ad2d..f750d82f3 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -72,6 +72,51 @@ func TestConfigParsing(t *testing.T) { // Run test assert.Error(t, config.Load("", fsys)) }) + + t.Run("config file with env defaults uses defaults when vars not set", func(t *testing.T) { + // Setup in-memory fs + fsys := fs.MapFS{ + "supabase/config.toml": &fs.MapFile{Data: []byte(` +[auth] +site_url = "env(SITE_URL, http://localhost:3000)" + +[auth.external.github] +client_id = "env(GITHUB_CLIENT_ID, default_client_id)" +`)}, + } + config := NewConfig() + // Run test + assert.NoError(t, config.Load("", fsys)) + // Check defaults are used + assert.Equal(t, "http://localhost:3000", config.Auth.SiteUrl) + github := config.Auth.External["github"] + assert.Equal(t, "default_client_id", github.ClientId) + }) + + t.Run("config file with env defaults uses env vars when set", func(t *testing.T) { + // Clear environment variables + os.Unsetenv("SITE_URL") + os.Unsetenv("GITHUB_CLIENT_ID") + // Setup in-memory fs + fsys := fs.MapFS{ + "supabase/config.toml": &fs.MapFile{Data: []byte(` +[auth] +site_url = "env(SITE_URL, http://localhost:3000)" + +[auth.external.github] +client_id = "env(GITHUB_CLIENT_ID, default_client_id)" +`)}, + } + config := NewConfig() + // Run test + t.Setenv("SITE_URL", "https://example.com") + t.Setenv("GITHUB_CLIENT_ID", "real_client_id") + assert.NoError(t, config.Load("", fsys)) + // Check env vars are used + assert.Equal(t, "https://example.com", config.Auth.SiteUrl) + github := config.Auth.External["github"] + assert.Equal(t, "real_client_id", github.ClientId) + }) } func TestRemoteOverride(t *testing.T) { diff --git a/pkg/config/decode_hooks.go b/pkg/config/decode_hooks.go index b97d9ba8d..5e0e37c86 100644 --- a/pkg/config/decode_hooks.go +++ b/pkg/config/decode_hooks.go @@ -4,23 +4,29 @@ import ( "os" "reflect" "regexp" + "strings" "github.com/go-errors/errors" ) -var envPattern = regexp.MustCompile(`^env\((.*)\)$`) +var envPattern = regexp.MustCompile(`^env\(\s*([^,\s]+)\s*(?:,\s*(.+?)\s*)?\)$`) // LoadEnvHook is a mapstructure decode hook that loads environment variables -// from strings formatted as env(VAR_NAME). +// from strings formatted as env(VAR_NAME) or env(VAR_NAME, default_value). func LoadEnvHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { if f != reflect.String { return data, nil } value := data.(string) if matches := envPattern.FindStringSubmatch(value); len(matches) > 1 { - if env := os.Getenv(matches[1]); len(env) > 0 { + varName := strings.TrimSpace(matches[1]) + if env := os.Getenv(varName); len(env) > 0 { value = env + } else if len(matches) > 2 && matches[2] != "" { + // Use default value if environment variable is not set or empty + value = strings.TrimSpace(matches[2]) } + // If no env var and no default, keep original value (current behavior) } return value, nil } diff --git a/pkg/config/decode_hooks_test.go b/pkg/config/decode_hooks_test.go new file mode 100644 index 000000000..d5aae463f --- /dev/null +++ b/pkg/config/decode_hooks_test.go @@ -0,0 +1,245 @@ +package config + +import ( + "os" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadEnvHook(t *testing.T) { + tests := []struct { + name string + input string + envVar string + envValue string + expected string + description string + }{ + { + name: "basic env var substitution", + input: "env(TEST_VAR)", + envVar: "TEST_VAR", + envValue: "test_value", + expected: "test_value", + description: "should replace env(VAR) with environment variable value", + }, + { + name: "env var with default - env var set", + input: "env(TEST_VAR, default_value)", + envVar: "TEST_VAR", + envValue: "env_value", + expected: "env_value", + description: "should use environment variable value when available, ignoring default", + }, + { + name: "env var with default - env var not set", + input: "env(MISSING_VAR, default_value)", + envVar: "", + envValue: "", + expected: "default_value", + description: "should use default value when environment variable is not set", + }, + { + name: "env var with default - env var empty", + input: "env(EMPTY_VAR, default_value)", + envVar: "EMPTY_VAR", + envValue: "", + expected: "default_value", + description: "should use default value when environment variable is empty", + }, + { + name: "env var with spaces in default", + input: "env(MISSING_VAR, my default value)", + envVar: "", + envValue: "", + expected: "my default value", + description: "should handle default values with spaces", + }, + { + name: "env var with extra spaces", + input: "env( TEST_VAR , default_value )", + envVar: "TEST_VAR", + envValue: "trimmed_value", + expected: "trimmed_value", + description: "should handle extra spaces around variable name and default", + }, + { + name: "env var with default containing commas", + input: "env(MISSING_VAR, value,with,commas)", + envVar: "", + envValue: "", + expected: "value,with,commas", + description: "should handle default values containing commas", + }, + { + name: "non-env string unchanged", + input: "regular_string", + envVar: "", + envValue: "", + expected: "regular_string", + description: "should leave non-env strings unchanged", + }, + { + name: "malformed env syntax unchanged", + input: "env(MISSING_VAR", + envVar: "", + envValue: "", + expected: "env(MISSING_VAR", + description: "should leave malformed env syntax unchanged", + }, + { + name: "env var without default - missing var", + input: "env(MISSING_VAR)", + envVar: "", + envValue: "", + expected: "env(MISSING_VAR)", + description: "should leave original string when env var missing and no default", + }, + { + name: "env var without default - empty var", + input: "env(EMPTY_VAR)", + envVar: "EMPTY_VAR", + envValue: "", + expected: "env(EMPTY_VAR)", + description: "should leave original string when env var empty and no default", + }, + { + name: "quoted default value", + input: `env(MISSING_VAR, "quoted default")`, + envVar: "", + envValue: "", + expected: `"quoted default"`, + description: "should preserve quotes in default values", + }, + { + name: "numeric default value", + input: "env(MISSING_VAR, 12345)", + envVar: "", + envValue: "", + expected: "12345", + description: "should handle numeric default values as strings", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup environment variable if specified + if tt.envVar != "" { + if tt.envValue != "" { + t.Setenv(tt.envVar, tt.envValue) + } else { + // Ensure the env var is not set + os.Unsetenv(tt.envVar) + } + } + + // Call the hook function + result, err := LoadEnvHook(reflect.String, reflect.String, tt.input) + + // Assertions + require.NoError(t, err, "LoadEnvHook should not return an error") + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} + +func TestLoadEnvHook_NonStringInput(t *testing.T) { + tests := []struct { + name string + fromKind reflect.Kind + toKind reflect.Kind + input interface{} + expected interface{} + }{ + { + name: "integer input", + fromKind: reflect.Int, + toKind: reflect.String, + input: 42, + expected: 42, + }, + { + name: "boolean input", + fromKind: reflect.Bool, + toKind: reflect.String, + input: true, + expected: true, + }, + { + name: "slice input", + fromKind: reflect.Slice, + toKind: reflect.String, + input: []string{"test"}, + expected: []string{"test"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := LoadEnvHook(tt.fromKind, tt.toKind, tt.input) + require.NoError(t, err) + assert.Equal(t, tt.expected, result, "non-string inputs should be returned unchanged") + }) + } +} + +func TestLoadEnvHook_RegressionTest(t *testing.T) { + // Test that existing functionality still works as expected + t.Run("existing env() patterns continue to work", func(t *testing.T) { + t.Setenv("EXISTING_VAR", "existing_value") + + result, err := LoadEnvHook(reflect.String, reflect.String, "env(EXISTING_VAR)") + require.NoError(t, err) + assert.Equal(t, "existing_value", result) + }) + + t.Run("missing env vars without defaults preserve original behavior", func(t *testing.T) { + os.Unsetenv("NONEXISTENT_VAR") + + result, err := LoadEnvHook(reflect.String, reflect.String, "env(NONEXISTENT_VAR)") + require.NoError(t, err) + assert.Equal(t, "env(NONEXISTENT_VAR)", result) + }) +} + +func TestEnvPattern_Regex(t *testing.T) { + tests := []struct { + input string + shouldMatch bool + varName string + defaultVal string + description string + }{ + {"env(VAR)", true, "VAR", "", "basic env var"}, + {"env(VAR, default)", true, "VAR", "default", "env var with default"}, + {"env( VAR , default )", true, "VAR", "default", "env var with spaces"}, + {"env(VAR,default)", true, "VAR", "default", "env var without spaces around comma"}, + {"env(VAR, default with spaces)", true, "VAR", "default with spaces", "default with spaces"}, + {"env(VAR, val,ue)", true, "VAR", "val,ue", "default with comma"}, + {"env()", false, "", "", "empty env"}, + {"env(VAR", false, "", "", "missing closing paren"}, + {"env VAR)", false, "", "", "missing opening paren"}, + {"notenv(VAR)", false, "", "", "wrong function name"}, + {"env(VAR, )", true, "VAR", "", "empty default"}, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + matches := envPattern.FindStringSubmatch(tt.input) + + if tt.shouldMatch { + require.True(t, len(matches) > 1, "should match pattern: %s", tt.input) + assert.Equal(t, tt.varName, strings.TrimSpace(matches[1]), "variable name should match") + if len(matches) > 2 { + assert.Equal(t, tt.defaultVal, strings.TrimSpace(matches[2]), "default value should match") + } + } else { + assert.True(t, len(matches) <= 1, "should not match pattern: %s", tt.input) + } + }) + } +}