diff --git a/README.md b/README.md index 420a71dc..b6581154 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,8 @@ openfeature.SetProviderAndWait(MyProvider{}) ``` In some situations, it may be beneficial to register multiple providers in the same application. -This is possible using [domains](#domains), which is covered in more details below. +This is possible using [domains](#domains), which is covered in more details below, or the included [multiprovider](#multi-provider-implementation) +implementation. ### Targeting @@ -331,6 +332,11 @@ tCtx := openfeature.MergeTransactionContext(ctx, openfeature.EvaluationContext{} client.BooleanValue(tCtx, ....) ``` +### Multi-Provider Implementation + +Included with this SDK is an _experimental_ multi-provider that can be used to query multiple feature flag providers simultaneously. +More information can be found in the [multi package's README](openfeature/multi/README.md). + ## Extending ### Develop a provider diff --git a/codecov.yml b/codecov.yml index 98d1eed5..ca2113eb 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,2 +1,2 @@ ignore: - - "**/*_mock.go" \ No newline at end of file + - "**/*_mock.go" diff --git a/go.mod b/go.mod index 948e65e6..0ebd999e 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,21 @@ go 1.24.0 require ( github.com/cucumber/godog v0.15.1 github.com/go-logr/logr v1.4.3 + github.com/stretchr/testify v1.11.1 go.uber.org/mock v0.6.0 + golang.org/x/sync v0.17.0 golang.org/x/text v0.30.0 ) require ( github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect github.com/cucumber/messages/go/v21 v21.0.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-memdb v1.3.4 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.7 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e0c1762e..15a079af 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cucumber/gherkin/go/v26 v26.2.0 h1:EgIjePLWiPeslwIWmNQ3XHcypPsWAHoMCz/YEBKP4GI= github.com/cucumber/gherkin/go/v26 v26.2.0/go.mod h1:t2GAPnB8maCT4lkHL99BDCVNzCh1d7dBhCLt150Nr/0= -github.com/cucumber/godog v0.15.0 h1:51AL8lBXF3f0cyA5CV4TnJFCTHpgiy+1x1Hb3TtZUmo= -github.com/cucumber/godog v0.15.0/go.mod h1:FX3rzIDybWABU4kuIXLZ/qtqEe1Ac5RdXmqvACJOces= github.com/cucumber/godog v0.15.1 h1:rb/6oHDdvVZKS66hrhpjFQFHjthFSrQBCOI1LwshNTI= github.com/cucumber/godog v0.15.1/go.mod h1:qju+SQDewOljHuq9NSM66s0xEhogx0q30flfxL4WUk8= github.com/cucumber/messages/go/v21 v21.0.1 h1:wzA0LxwjlWQYZd32VTlAVDTkW6inOFmSM+RuOwHZiMI= @@ -30,14 +28,15 @@ github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uG github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.7 h1:vN6T9TfwStFPFM5XzjsvmzZkLuaLX+HS+0SeFLRgU6M= github.com/spf13/pflag v1.0.7/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -48,21 +47,18 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= -go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/openfeature/multi/README.md b/openfeature/multi/README.md new file mode 100644 index 00000000..09deba12 --- /dev/null +++ b/openfeature/multi/README.md @@ -0,0 +1,165 @@ +OpenFeature Multi-Provider +------------ + +> [!WARNING] +> The multi package for the go-sdk is experimental. + + +The multi-provider allows you to use multiple underlying providers as sources of flag data for the OpenFeature server SDK. +The multi-provider acts as a wrapper providing a unified interface to interact with all of those providers at once. +When a flag is being evaluated, the Multi-Provider will consult each underlying provider it is managing in order to +determine the final result. Different evaluation strategies can be defined to control which providers get evaluated and +which result is used. + +The multi-provider is defined within [Appendix A: Included Utilities](https://openfeature.dev/specification/appendix-a#multi-provider) +of the openfeature spec. + +The multi-provider is a powerful tool for performing migrations between flag providers, or combining multiple providers +into a single feature flagging interface. For example: + +- **Migration**: When migrating between two providers, you can run both in parallel under a unified flagging interface. + As flags are added to the new provider, the multi-provider will automatically find and return them, falling back to the old provider + if the new provider does not have +- **Multiple Data Sources**: The multi-provider allows you to seamlessly combine many sources of flagging data, such as + environment variables, local files, database values and SaaS hosted feature management systems. + +# Usage + +```go +import ( + "github.com/open-feature/go-sdk/openfeature" + "github.com/open-feature/go-sdk/openfeature/multi" + "github.com/open-feature/go-sdk/openfeature/memprovider" +) + +providers := make(multi.ProviderMap) +providers["providerA"] = memprovider.NewInMemoryProvider(map[string]memprovider.InMemoryFlag{}) +providers["providerB"] = myCustomProvider +mprovider, err := multi.NewProvider(providers, multi.StrategyFirstMatch) +if err != nil { + return err +} + +openfeature.SetNamedProviderAndWait("multiprovider", mprovider) +``` + +# Strategies + +There are three strategies that are defined by the spec and are available within this multi-provider implementation. In +addition to the three provided strategies a custom strategy can be defined as well. + +The three provided strategies are: + +- _First Match_ +- _First Success_ +- _Comparison_ + +## First Match Strategy + +The first match strategy works by **sequentially** calling each provider until a valid result is returned. +The first provider that returns a result will be used. It will try calling the next provider whenever it encounters a `FLAG_NOT_FOUND` +error. However, if a provider returns an error other than `FLAG_NOT_FOUND` the provider will stop and return the default +value along with setting the error details if a detailed request is issued. + +## First Success Strategy + +The first success strategy also works by calling each provider **sequentially**. The first provider that returns a response +with no errors is used. This differs from the first match strategy in that any provider raising an error will not halt +calling the next provider if a successful result has not yet been encountered. If no provider provides a successful result +the default value will be returned to the caller. + +## Comparison Strategy + +The comparison strategy works by calling each provider in **parallel**. All results are collected from each provider and +then the resolved results are compared to each other. If they all agree then that value is returned. If not a fallback +provider can be specified to be executed instead or the default value will be returned. If a provider returns +`FLAG_NOT_FOUND` that result will not be included in the comparison. If all providers return not found then the default +value is returned. Finally, if any provider returns an error other than `FLAG_NOT_FOUND` the evaluation immediately stops +and that error result is returned with the default value. + +The fallback provider can be set using the `WithFallbackProvider` [`Option`](#options). + +Special care must be taken when this strategy is used with `ObjectEvaluation`. If the resulting value is not a +[`comparable`](https://go.dev/blog/comparable) type then the default result or fallback provider will always be used. In +order to evaluate non `comparable` types a `Comparator` function must be provided as an `Option` to the constructor. + +## Custom Strategies + +A custom strategy can be defined using the `WithCustomStrategy` `Option` along with the `StrategyCustom` constant. +A custom strategy is defined by the following generic function signature: + +```go +StrategyFn[T FlagTypes] func(ctx context.Context, flag string, defaultValue T, flatCtx openfeature.FlattenedContext) openfeature.GenericResolutionDetail[T] +``` + +However, this doesn't provide any way to retrieve the providers! Therefore, there's the type `StrategyConstructor` that +is called for you to close over the providers inside your `StratetegyFn` implementation. + +```go +type StrategyConstructor func(providers []*NamedProvider) StrategyFn[FlagTypes] +``` + +Build your strategy to wrap around the slice of providers +```go +option := multi.WithCustomStrategy(func(providers []*NamedProvider) StrategyFn[FlagTypes] { + return func[T FlagTypes](ctx context.Context, flag string, defaultValue T, flatCtx openfeature.FlattenedContext) openfeature.GenericResolutionDetail[T] { + // implementation + // ... + } +}) +``` + +It is highly recommended to use the provided exposed functions to build your custom strategy. Specifically, the functions +`BuildDefaultResult` & `Evaluate` are exposed for those implementing their own custom strategies. + +The `Evaluate` method should be used for evaluating the result of a single `NamedProvider`. It determines the evaluation +type via the type of the generic `defaultVal` parameter. + +The `BuildDefaultResult` method should be called when an error is encountered or the strategy "fails" and needs to return +the default result passed to one of the Evaluation methods of `openfeature.FeatureProvider`. + +# Options + +The `multi.NewProvider` constructor implements the optional pattern for setting additional configuration. + +## General Options + +### `WithLogger` + +Allows for providing a `slog.Logger` instance for internal logging of the multi-provider's evaluation processing for debugging +purposes. By default, are logs are discarded unless this option is used. + +### `WithCustomStrategy` + +Allows for setting a custom strategy function for the evaluation of providers. This must be used in conjunction with the +`StrategyCustom` `EvaluationStrategy` parameter. The option itself takes a `StrategyConstructor` function, which is +essentially a factory that allows the `StrategyFn` to wrap around a slice of `NamedProvider` instances. + +### `WithGlobalHooks` + +Allows for setting global hooks for the multi-provider. These are `openfeature.Hook` implementations that affect +**all** internal `FeatureProvider` instances. + +### `WithProviderHooks` + +Allows for setting `openfeature.Hook` implementations on a specific named `FeatureProvider` within the multi-provider. +This should only be used when hooks need to be attached to a `FeatureProvider` instance that does not implement that functionality. +Using a provider name that is not known will cause an error to be returned during the creation time. This option can be +used multiple times using unique provider names. + +## `StrategyComparision` specific options + +There are two options specifically for usage with the `StrategyComparision` `EvaluationStrategy`. If these options are +used with a different `EvaluationStrategy` they are ignored. + +### `WithFallbackProvider` + +When the results are not in agreement with each other the fallback provider will be called. The result of this provider +is what will be returned to the caller. If no fallback provider is set then the default value will be returned instead. + +### `WithCustomComparator` + +When using `ObjectEvaluation` there are cases where the results are not able to be compared to each other by default. +This happens if the returned type is not `comparable`. In that situation all the results are passed to the custom `Comparator` +to evaluate if they are in agreement or not. If not provided and the return type is not `comparable` then either the fallback +provider is used or the default value. diff --git a/openfeature/multi/comparison_strategy.go b/openfeature/multi/comparison_strategy.go new file mode 100644 index 00000000..4df93290 --- /dev/null +++ b/openfeature/multi/comparison_strategy.go @@ -0,0 +1,249 @@ +package multi + +import ( + "context" + "errors" + "reflect" + "slices" + "strings" + + of "github.com/open-feature/go-sdk/openfeature" + "golang.org/x/sync/errgroup" +) + +// ErrAggregationNotAllowed is an error returned if [of.FeatureProvider.ObjectEvaluation] is called using the [StrategyComparison] +// strategy without a custom [Comparator] function configured when response objects are not comparable. +var ErrAggregationNotAllowed = errors.New(errAggregationNotAllowedText) + +// Comparator is used to compare the results of [of.FeatureProvider.ObjectEvaluation]. +// This is required if returned results are not comparable. +type Comparator func(values []any) bool + +// newComparisonStrategy returns a [StrategyComparison] strategy function. The fallback provider specified is called when +// there is a comparison failure -- prior to returning a default result. The [Comparator] parameter is optional and nil +// can be passed as long as ObjectEvaluation is never called with objects that are not comparable. The custom [Comparator] +// will only be used for [of.FeatureProvider.ObjectEvaluation] if set. If [of.FeatureProvider.ObjectEvaluation] is +// called without setting a [Comparator], and the returned object(s) are not comparable, then a panic will occur. +func newComparisonStrategy(providers []*NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[FlagTypes] { + return evaluateComparison[FlagTypes](providers, fallbackProvider, comparator) +} + +func defaultComparator(values []any) bool { + if len(values) == 0 { + return false + } + current := values[0] + + switch current.(type) { + case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64, uint, uintptr, float32, float64, string, bool: + for i, v := range values { + if i == 0 { + continue + } + if v != current { + return false + } + } + return true + default: + if current == nil { + return false // nilable values are not comparable + } + t := reflect.TypeOf(current) + if t.Comparable() { + set := map[any]struct{}{} + for _, v := range values { + if v == nil { + return false // nil is not comparable + } + set[v] = struct{}{} + } + + return len(set) == 1 + } + return false + } +} + +func comparisonResolutionError(metadata of.FlagMetadata) of.ResolutionError { + if isDefault, err := metadata.GetBool(MetadataIsDefaultValue); err != nil || !isDefault { + return of.ResolutionError{} + } + + if notFound, err := metadata.GetString(MetadataSuccessfulProviderName); err == nil && notFound == "none" { + return of.NewFlagNotFoundResolutionError("not found in any providers") + } + + if evalErr, err := metadata.GetString(MetadataEvaluationError); err == nil && evalErr != "" { + return of.NewGeneralResolutionError(evalErr) + } + + return of.NewGeneralResolutionError("comparison failure") +} + +func evaluateComparison[T FlagTypes](providers []*NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[T] { + return func(ctx context.Context, flag string, defaultValue T, evalCtx of.FlattenedContext) of.GenericResolutionDetail[T] { + if comparator == nil { + comparator = defaultComparator + switch any(defaultValue).(type) { + case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64, uint, uintptr, float32, float64, string, bool: + break + default: + t := reflect.TypeOf(defaultValue) + if !t.Comparable() { + // Impossible to evaluate strategy with expected result type + defaultResult := BuildDefaultResult(StrategyComparison, defaultValue, ErrAggregationNotAllowed) + defaultResult.FlagMetadata[MetadataFallbackUsed] = false + defaultResult.FlagMetadata[MetadataIsDefaultValue] = true + return defaultResult + } + } + } + + // Short circuit if there's only one provider as no comparison nor workers are needed + if len(providers) == 1 { + result := Evaluate(ctx, providers[0], flag, defaultValue, evalCtx) + metadata := setFlagMetadata(StrategyComparison, providers[0].Name, make(of.FlagMetadata)) + metadata[MetadataFallbackUsed] = false + result.FlagMetadata = mergeFlagMeta(result.FlagMetadata, metadata) + return result + } + + type namedResult struct { + name string + res *of.GenericResolutionDetail[T] + } + + resultChan := make(chan *namedResult, len(providers)) + notFoundChan := make(chan any) + errGrp, grpCtx := errgroup.WithContext(ctx) + for _, provider := range providers { + closedProvider := provider + errGrp.Go(func() error { + result := Evaluate(grpCtx, closedProvider, flag, defaultValue, evalCtx) + notFound := result.ResolutionDetail().ErrorCode == of.FlagNotFoundCode + if !notFound && result.Error() != nil { + return &ProviderError{ + ProviderName: closedProvider.Name, + Err: result.Error(), + } + } + if !notFound { + resultChan <- &namedResult{ + name: closedProvider.Name, + res: &result, + } + } else { + notFoundChan <- struct{}{} + } + return nil + }) + } + + results := make([]namedResult, 0, len(providers)) + resultValues := make([]T, 0, len(providers)) + notFoundCount := 0 + + ListenerLoop: + for { + select { + case <-grpCtx.Done(): + // Error occurred + result := BuildDefaultResult(StrategyComparison, defaultValue, grpCtx.Err()) + result.FlagMetadata[MetadataFallbackUsed] = false + result.FlagMetadata[MetadataIsDefaultValue] = true + result.FlagMetadata[MetadataEvaluationError] = grpCtx.Err().Error() + result.ResolutionError = comparisonResolutionError(result.FlagMetadata) + return result + case r := <-resultChan: + results = append(results, *r) + resultValues = append(resultValues, r.res.Value) + if (len(results) + notFoundCount) == len(providers) { + // All results accounted for + break ListenerLoop + } + case <-notFoundChan: + notFoundCount += 1 + if notFoundCount == len(providers) { + result := BuildDefaultResult(StrategyComparison, defaultValue, nil) + result.FlagMetadata[MetadataFallbackUsed] = false + result.FlagMetadata[MetadataIsDefaultValue] = true + result.ResolutionError = comparisonResolutionError(result.FlagMetadata) + return result + } + if (len(results) + notFoundCount) == len(providers) { + // All results accounted for + break ListenerLoop + } + } + } + // Evaluate Results Are Equal + metadata := make(of.FlagMetadata) + metadata[MetadataStrategyUsed] = StrategyComparison + // Build Aggregate metadata key'd by their names of all Providers + for _, r := range results { + metadata[r.name] = r.res.FlagMetadata + } + resultsForComparison := make([]any, 0, len(resultValues)) + for _, r := range resultValues { + resultsForComparison = append(resultsForComparison, r) + } + if comparator(resultsForComparison) { + metadata[MetadataFallbackUsed] = false + metadata[MetadataIsDefaultValue] = false + metadata[MetadataComparisonDisagreeingProviders] = []string{} + success := make([]string, 0, len(providers)) + variants := make([]string, 0, len(providers)) + // Gather metadata from provider results + for _, r := range results { + metadata[r.name] = r.res.FlagMetadata + success = append(success, r.name) + variants = append(variants, r.res.Variant) + } + // maintain stable order of metadata results + slices.Sort(success) + metadata[MetadataSuccessfulProviderNames] = strings.Join(success, ", ") + // Unique values only + slices.Sort(variants) + variants = slices.Compact(variants) + var variantResults string + if len(variants) == 1 { + variantResults = variants[0] + } else { + variantResults = strings.Join(variants, ", ") + } + return of.GenericResolutionDetail[T]{ + Value: resultValues[0], // All values should be equal + ProviderResolutionDetail: of.ProviderResolutionDetail{ + Reason: ReasonAggregated, + Variant: variantResults, + FlagMetadata: metadata, + }, + } + } + + if fallbackProvider != nil { + fallbackResult := Evaluate( + ctx, + &NamedProvider{Name: "fallback", FeatureProvider: fallbackProvider}, + flag, + defaultValue, + evalCtx, + ) + fallbackResult.FlagMetadata = mergeFlagMeta(fallbackResult.FlagMetadata, metadata) + fallbackResult.FlagMetadata[MetadataFallbackUsed] = true + fallbackResult.FlagMetadata[MetadataIsDefaultValue] = false + fallbackResult.FlagMetadata[MetadataSuccessfulProviderName] = "fallback" + fallbackResult.FlagMetadata[MetadataStrategyUsed] = StrategyComparison + fallbackResult.Reason = ReasonAggregatedFallback + return fallbackResult + } + + defaultResult := BuildDefaultResult(StrategyComparison, defaultValue, errors.New("no fallback provider configured")) + mergeFlagMeta(defaultResult.FlagMetadata, metadata) + defaultResult.FlagMetadata[MetadataFallbackUsed] = false + defaultResult.FlagMetadata[MetadataIsDefaultValue] = true + + return defaultResult + } +} diff --git a/openfeature/multi/comparison_strategy_test.go b/openfeature/multi/comparison_strategy_test.go new file mode 100644 index 00000000..f6a76c8a --- /dev/null +++ b/openfeature/multi/comparison_strategy_test.go @@ -0,0 +1,742 @@ +package multi + +import ( + "context" + "fmt" + "testing" + + of "github.com/open-feature/go-sdk/openfeature" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func configureComparisonProvider[R any](provider *of.MockFeatureProvider, resultVal R, state bool, error int, forceObj bool) { + var rErr of.ResolutionError + var variant string + var reason of.Reason + switch error { + case TestErrorError: + rErr = of.NewGeneralResolutionError("test error") + reason = of.ErrorReason + case TestErrorNotFound: + rErr = of.NewFlagNotFoundResolutionError("not found") + reason = of.DefaultReason + } + if state { + variant = "on" + } else { + variant = "off" + } + details := of.ProviderResolutionDetail{ + ResolutionError: rErr, + Reason: reason, + Variant: variant, + FlagMetadata: make(of.FlagMetadata), + } + provider.EXPECT().Metadata().Return(of.Metadata{Name: "mock provider"}).MaxTimes(1) + objFunc := func(p *of.MockFeatureProvider) { + p.EXPECT().ObjectEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal any, evalCtx of.FlattenedContext) of.InterfaceResolutionDetail { + return of.InterfaceResolutionDetail{ + Value: resultVal, + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + } + + if forceObj { + objFunc(provider) + return + } + + switch any(resultVal).(type) { + case bool: + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal bool, evalCtx of.FlattenedContext) of.BoolResolutionDetail { + return of.BoolResolutionDetail{ + Value: any(resultVal).(bool), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case string: + provider.EXPECT().StringEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal string, evalCtx of.FlattenedContext) of.StringResolutionDetail { + return of.StringResolutionDetail{ + Value: any(resultVal).(string), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case int64: + provider.EXPECT().IntEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal int64, evalCtx of.FlattenedContext) of.IntResolutionDetail { + return of.IntResolutionDetail{ + Value: any(resultVal).(int64), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case float64: + provider.EXPECT().FloatEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal float64, evalCtx of.FlattenedContext) of.FloatResolutionDetail { + return of.FloatResolutionDetail{ + Value: any(resultVal).(float64), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + default: + objFunc(provider) + } +} + +func Test_ComparisonStrategy_Evaluation(t *testing.T) { + tests := []struct { + kind of.Type + successVal FlagTypes + defaultVal FlagTypes + }{ + {of.Boolean, true, false}, + {of.String, "success", "default"}, + {of.Int, int64(1234), int64(0)}, + {of.Float, float64(12.34), float64(0)}, + {of.Object, struct{ Field string }{Field: "test"}, struct{}{}}, + } + for _, tt := range tests { + t.Run(tt.kind.String(), func(t *testing.T) { + successVal := tt.successVal + defaultVal := tt.defaultVal + t.Run("single success", func(t *testing.T) { + ctrl := gomock.NewController(t) + provider := of.NewMockFeatureProvider(ctrl) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider", + FeatureProvider: provider, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "test-provider", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("two success", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider1, test-provider2", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("multiple success", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + fallback.EXPECT().IntEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorNone, false) + provider3 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + { + Name: "test-provider3", + FeatureProvider: provider3, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider1, test-provider2, test-provider3", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("multiple not found with single success", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNotFound, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNotFound, false) + provider3 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + { + Name: "test-provider3", + FeatureProvider: provider3, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider3", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.Contains(t, result.FlagMetadata, MetadataFallbackUsed) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("multiple not found with multiple success", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + fallback.EXPECT().IntEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNotFound, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNotFound, false) + provider3 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) + provider4 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider4, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + { + Name: "test-provider3", + FeatureProvider: provider3, + }, + { + Name: "test-provider4", + FeatureProvider: provider4, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider3, test-provider4", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.Contains(t, result.FlagMetadata, MetadataFallbackUsed) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("comparison failure uses fallback", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, successVal, true, TestErrorNone, false) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNone, false) + provider3 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + { + Name: "test-provider3", + FeatureProvider: provider3, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.NotContains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "fallback", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.True(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("not found all providers", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + fallback.EXPECT().FloatEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNotFound, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNotFound, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, defaultVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.NotContains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Contains(t, result.FlagMetadata, MetadataFallbackUsed) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("comparison failure with not found", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, successVal, true, TestErrorNone, false) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNotFound, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNotFound, false) + provider3 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) + provider4 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider4, defaultVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + { + Name: "test-provider3", + FeatureProvider: provider3, + }, + { + Name: "test-provider4", + FeatureProvider: provider4, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.NotContains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "fallback", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Contains(t, result.FlagMetadata, MetadataFallbackUsed) + assert.True(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("non FLAG_NOT_FOUND error causes default", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorError, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorError, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, defaultVal, result.Value) + assert.Equal(t, of.ErrorReason, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.NotContains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Contains(t, result.FlagMetadata, MetadataEvaluationError) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + }) + } +} + +func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { + successVal := struct{ Name string }{Name: "test"} + defaultVal := struct{}{} + + type orderableTestCase struct { + typeName string + successValue any + defaultValue any + } + + orderableTests := []orderableTestCase{ + { + typeName: "int8", + successValue: int8(5), + defaultValue: int8(1), + }, + { + typeName: "int16", + successValue: int16(5), + defaultValue: int16(1), + }, + { + typeName: "int32", + successValue: int32(5), + defaultValue: int32(1), + }, + // { + // typeName: "int64", + // successValue: int64(5), + // defaultValue: int64(1), + // }, + { + typeName: "int", + successValue: 5, + defaultValue: 1, + }, + { + typeName: "uint8", + successValue: uint8(5), + defaultValue: uint8(1), + }, + { + typeName: "uint16", + successValue: uint16(5), + defaultValue: uint16(1), + }, + { + typeName: "uint32", + successValue: uint32(5), + defaultValue: uint32(1), + }, + { + typeName: "uint64", + successValue: uint64(5), + defaultValue: uint64(1), + }, + { + typeName: "uint", + successValue: uint(5), + defaultValue: uint(1), + }, + { + typeName: "uintptr", + successValue: uintptr(5), + defaultValue: uintptr(1), + }, + { + typeName: "float32", + successValue: float32(5.5), + defaultValue: float32(1.1), + }, + // { + // typeName: "float64", + // successValue: 5.5, + // defaultValue: 1.1, + // }, + // { + // typeName: "string", + // successValue: "success", + // defaultValue: "default", + // }, + } + + for _, testCase := range orderableTests { + tc := testCase + t.Run(fmt.Sprintf("with orderable type %s success", tc.typeName), func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, testCase.successValue, true, TestErrorNone, true) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, testCase.successValue, true, TestErrorNone, true) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, tc.defaultValue, of.FlattenedContext{}) + assert.Equal(t, tc.successValue, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregated, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider1, test-provider2", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run(fmt.Sprintf("with orderable type %s no match fallback", tc.typeName), func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, tc.successValue, true, TestErrorNone, true) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, tc.successValue, true, TestErrorNone, true) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, tc.defaultValue, true, TestErrorNone, true) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + result := strategy(context.Background(), testFlag, tc.defaultValue, of.FlattenedContext{}) + assert.Equal(t, tc.successValue, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregatedFallback, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "fallback", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.True(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + } + + t.Run("with comparable custom type success", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, true) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorNone, true) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregated, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider1, test-provider2", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("with comparable custom type no match fallback", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, successVal, true, TestErrorNone, true) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, true) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNone, true) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregatedFallback, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "fallback", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.True(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("with comparable custom type force custom comparator", func(t *testing.T) { + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, defaultVal, true, TestErrorNone, true) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, true) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorNone, true) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, func(val []any) bool { + return true + }) + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregated, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("with non comparable types using custom comparator success", func(t *testing.T) { + successVal := []string{"test1", "test2"} + defaultVal := []string{"test3"} + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, func(val []any) bool { + return true + }) + + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregated, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Equal(t, "test-provider1, test-provider2", result.FlagMetadata[MetadataSuccessfulProviderNames]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("with non comparable types using custom comparator no match fallback", func(t *testing.T) { + successVal := []string{"test1", "test2"} + defaultVal := []string{"test3"} + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(fallback, successVal, true, TestErrorNone, false) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, defaultVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, defaultVal, true, TestErrorNone, false) + + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, func(val []any) bool { + return false + }) + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, successVal, result.Value) + assert.NoError(t, result.Error()) + assert.Equal(t, ReasonAggregatedFallback, result.Reason) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "fallback", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.True(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) + + t.Run("any provider error bypasses comparison", func(t *testing.T) { + successVal := []string{"test1", "test2"} + defaultVal := []string{"test3"} + ctrl := gomock.NewController(t) + fallback := of.NewMockFeatureProvider(ctrl) + provider1 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider1, successVal, true, TestErrorNone, false) + provider2 := of.NewMockFeatureProvider(ctrl) + configureComparisonProvider(provider2, successVal, true, TestErrorError, false) + strategy := newComparisonStrategy([]*NamedProvider{ + { + Name: "test-provider1", + FeatureProvider: provider1, + }, + { + Name: "test-provider2", + FeatureProvider: provider2, + }, + }, fallback, nil) + result := strategy(context.Background(), testFlag, defaultVal, of.FlattenedContext{}) + assert.Equal(t, defaultVal, result.Value) + assert.Equal(t, of.ErrorReason, result.Reason) + assert.Equal(t, of.NewGeneralResolutionError(ErrAggregationNotAllowed.Error()), result.ResolutionError) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyComparison, result.FlagMetadata[MetadataStrategyUsed]) + assert.NotContains(t, result.FlagMetadata, MetadataSuccessfulProviderNames) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.False(t, result.FlagMetadata[MetadataFallbackUsed].(bool)) + }) +} diff --git a/openfeature/multi/errors.go b/openfeature/multi/errors.go new file mode 100644 index 00000000..b3150c86 --- /dev/null +++ b/openfeature/multi/errors.go @@ -0,0 +1,46 @@ +package multi + +import ( + "errors" + "fmt" +) + +type ( + // ProviderError is an error wrapper that includes the provider name. + ProviderError struct { + // Err is the original error that was returned from a provider + Err error + // ProviderName is the name of the provider that returned the included error + ProviderName string + } + + // AggregateError is a map that contains up to one error per provider within the multiprovider. + AggregateError []ProviderError +) + +// Compile-time interface compliance checks +var ( + _ error = (*ProviderError)(nil) + _ error = (AggregateError)(nil) +) + +func (e *ProviderError) Error() string { + return fmt.Sprintf("Provider %s: %s", e.ProviderName, e.Err.Error()) +} + +// NewAggregateError creates a new AggregateError from a slice of [ProviderError] instances +func NewAggregateError(providerErrors []ProviderError) AggregateError { + return providerErrors +} + +func (ae AggregateError) Error() string { + if len(ae) == 0 { + return "" + } + + errs := make([]error, 0, len(ae)) + for i := range ae { + errs = append(errs, &ae[i]) + } + return errors.Join(errs...).Error() +} diff --git a/openfeature/multi/errors_test.go b/openfeature/multi/errors_test.go new file mode 100644 index 00000000..2bc75933 --- /dev/null +++ b/openfeature/multi/errors_test.go @@ -0,0 +1,41 @@ +package multi + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_AggregateError_Error(t *testing.T) { + t.Run("empty error", func(t *testing.T) { + err := NewAggregateError([]ProviderError{}) + assert.Empty(t, err.Error()) + }) + + t.Run("single error", func(t *testing.T) { + err := NewAggregateError([]ProviderError{ + { + Err: fmt.Errorf("test error"), + ProviderName: "test-provider", + }, + }) + + assert.Equal(t, "Provider test-provider: test error", err.Error()) + }) + + t.Run("multiple errors", func(t *testing.T) { + err := NewAggregateError([]ProviderError{ + { + Err: fmt.Errorf("test error"), + ProviderName: "test-provider1", + }, + { + Err: fmt.Errorf("test error"), + ProviderName: "test-provider2", + }, + }) + + assert.Equal(t, "Provider test-provider1: test error\nProvider test-provider2: test error", err.Error()) + }) +} diff --git a/openfeature/multi/first_match_strategy.go b/openfeature/multi/first_match_strategy.go new file mode 100644 index 00000000..88546a5e --- /dev/null +++ b/openfeature/multi/first_match_strategy.go @@ -0,0 +1,39 @@ +package multi + +import ( + "context" + + of "github.com/open-feature/go-sdk/openfeature" +) + +// newFirstMatchStrategy returns a [StrategyFn] that returns the result of the first [of.FeatureProvider] whose response is +// not [of.FlagNotFoundCode]. This is executed sequentially, and not in parallel. +func newFirstMatchStrategy(providers []*NamedProvider) StrategyFn[FlagTypes] { + return firstMatchStrategyFn[FlagTypes](providers) +} + +func firstMatchStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[T] { + return func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { + for _, provider := range providers { + resolution := Evaluate(ctx, provider, flag, defaultValue, flatCtx) + if resolution.Error() != nil && resolution.ResolutionDetail().ErrorCode == of.FlagNotFoundCode { + continue + } + + if resolution.Error() != nil { + resolution.FlagMetadata = mergeFlagMeta(resolution.FlagMetadata, of.FlagMetadata{ + MetadataSuccessfulProviderName: "none", + MetadataStrategyUsed: StrategyFirstMatch, + }) + // Stop evaluation if an error occurs + return resolution + } + + // success! + resolution.FlagMetadata = setFlagMetadata(StrategyFirstMatch, provider.Name, resolution.FlagMetadata) + return resolution + } + + return BuildDefaultResult(StrategyFirstMatch, defaultValue, nil) + } +} diff --git a/openfeature/multi/first_match_strategy_test.go b/openfeature/multi/first_match_strategy_test.go new file mode 100644 index 00000000..69c86828 --- /dev/null +++ b/openfeature/multi/first_match_strategy_test.go @@ -0,0 +1,153 @@ +package multi + +import ( + "context" + "strconv" + "testing" + + of "github.com/open-feature/go-sdk/openfeature" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func Test_FirstMatchStrategy_Evaluation(t *testing.T) { + tests := []struct { + kind of.Type + successVal FlagTypes + defaultVal FlagTypes + }{ + {kind: of.Boolean, successVal: true, defaultVal: false}, + {kind: of.Int, successVal: int64(123), defaultVal: int64(0)}, + {kind: of.String, successVal: "stringValue", defaultVal: ""}, + {kind: of.Float, successVal: float64(123.45), defaultVal: float64(0.0)}, + {kind: of.Object, successVal: struct{ Field int }{Field: 123}, defaultVal: struct{}{}}, + } + for _, tt := range tests { + t.Run(tt.kind.String(), func(t *testing.T) { + ctrl := gomock.NewController(t) + + t.Run("Single Provider Match", func(t *testing.T) { + mocks := createMockProviders(ctrl, 1) + configureFirstMatchProviderMock(mocks[0], tt.successVal, TestErrorNone, "mock provider") + providers := make([]*NamedProvider, 0, 5) + for i, m := range mocks { + providers = append(providers, &NamedProvider{ + Name: strconv.Itoa(i), + FeatureProvider: m, + }) + } + strategy := newFirstMatchStrategy(providers) + result := strategy(context.Background(), "test-string", tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, providers[0].Name, result.FlagMetadata[MetadataSuccessfulProviderName]) + }) + + t.Run("Default Resolution", func(t *testing.T) { + mocks := createMockProviders(ctrl, 1) + configureFirstMatchProviderMock(mocks[0], tt.defaultVal, TestErrorNotFound, "mock provider") + providers := make([]*NamedProvider, 0, 5) + for i, m := range mocks { + providers = append(providers, &NamedProvider{ + Name: strconv.Itoa(i), + FeatureProvider: m, + }) + } + strategy := newFirstMatchStrategy(providers) + result := strategy(context.Background(), "test-string", tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.defaultVal, result.Value) + assert.Equal(t, of.DefaultReason, result.Reason) + assert.Equal(t, of.NewFlagNotFoundResolutionError("not found in any provider").Error(), result.ResolutionError.Error()) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Equal(t, StrategyFirstMatch, result.FlagMetadata[MetadataStrategyUsed]) + }) + + t.Run("Evaluation stops after match", func(t *testing.T) { + mocks := createMockProviders(ctrl, 5) + configureFirstMatchProviderMock(mocks[0], tt.defaultVal, TestErrorNotFound, "mock provider 1") + configureFirstMatchProviderMock(mocks[1], tt.successVal, TestErrorNone, "mock provider 2") + providers := make([]*NamedProvider, 0, 5) + for i, m := range mocks { + providers = append(providers, &NamedProvider{ + Name: strconv.Itoa(i), + FeatureProvider: m, + }) + } + + strategy := newFirstMatchStrategy(providers) + result := strategy(context.Background(), "test-flag", tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, providers[1].Name, result.FlagMetadata[MetadataSuccessfulProviderName]) + }) + + t.Run("Evaluation stops after first error that is not a FLAG_NOT_FOUND error", func(t *testing.T) { + mocks := createMockProviders(ctrl, 5) + expectedErr := of.NewGeneralResolutionError("test error") + providers := make([]*NamedProvider, 0, 5) + for i, m := range mocks { + providers = append(providers, &NamedProvider{ + Name: strconv.Itoa(i), + FeatureProvider: m, + }) + switch { + case i < 3: + configureFirstMatchProviderMock(mocks[i], tt.successVal, TestErrorNotFound, "mock provider fail") + case i == 3: + configureFirstMatchProviderMock(mocks[i], tt.successVal, TestErrorError, "mock provider") + } + + } + strategy := newFirstMatchStrategy(providers) + result := strategy(context.Background(), "test-string", tt.successVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Equal(t, of.ErrorReason, result.Reason) + assert.Equal(t, expectedErr.Error(), result.ResolutionError.Error()) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Equal(t, StrategyFirstMatch, result.FlagMetadata[MetadataStrategyUsed]) + }) + }) + } +} + +func configureFirstMatchProviderMock[R FlagTypes](mock *of.MockFeatureProvider, value R, error int, providerName string) { + var rErr of.ResolutionError + var reason of.Reason + switch error { + case TestErrorError: + rErr = of.NewGeneralResolutionError("test error") + reason = of.ErrorReason + case TestErrorNotFound: + rErr = of.NewFlagNotFoundResolutionError("test not found") + reason = of.DefaultReason + } + + details := of.ProviderResolutionDetail{ + ResolutionError: rErr, + Reason: reason, + FlagMetadata: make(of.FlagMetadata), + } + switch v := any(value).(type) { + case bool: + mock.EXPECT(). + BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(of.BoolResolutionDetail{Value: v, ProviderResolutionDetail: details}) + case string: + mock.EXPECT(). + StringEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(of.StringResolutionDetail{Value: v, ProviderResolutionDetail: details}) + case int64: + mock.EXPECT(). + IntEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(of.IntResolutionDetail{Value: v, ProviderResolutionDetail: details}) + case float64: + mock.EXPECT(). + FloatEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(of.FloatResolutionDetail{Value: v, ProviderResolutionDetail: details}) + default: + mock.EXPECT(). + ObjectEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(of.InterfaceResolutionDetail{Value: v, ProviderResolutionDetail: details}) + } + mock.EXPECT().Metadata().Return(of.Metadata{Name: providerName}) +} diff --git a/openfeature/multi/first_success_strategy.go b/openfeature/multi/first_success_strategy.go new file mode 100644 index 00000000..d30d73bb --- /dev/null +++ b/openfeature/multi/first_success_strategy.go @@ -0,0 +1,30 @@ +package multi + +import ( + "context" + "errors" + + of "github.com/open-feature/go-sdk/openfeature" +) + +// newFirstSuccessStrategy returns a [StrategyFn] that returns the result of the First [of.FeatureProvider] whose response +// is not an error. This executed sequentially. +func newFirstSuccessStrategy(providers []*NamedProvider) StrategyFn[FlagTypes] { + return firstSuccessStrategyFn[FlagTypes](providers) +} + +func firstSuccessStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[T] { + return func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { + resolutionErrors := make([]error, 0, len(providers)) + for _, provider := range providers { + resolution := Evaluate(ctx, provider, flag, defaultValue, flatCtx) + if resolution.Error() != nil { + resolutionErrors = append(resolutionErrors, resolution.Error()) + continue + } + resolution.FlagMetadata = setFlagMetadata(StrategyFirstSuccess, provider.Name, resolution.FlagMetadata) + return resolution + } + return BuildDefaultResult(StrategyFirstSuccess, defaultValue, errors.Join(resolutionErrors...)) + } +} diff --git a/openfeature/multi/first_success_strategy_test.go b/openfeature/multi/first_success_strategy_test.go new file mode 100644 index 00000000..0c178895 --- /dev/null +++ b/openfeature/multi/first_success_strategy_test.go @@ -0,0 +1,194 @@ +package multi + +import ( + "context" + "testing" + + of "github.com/open-feature/go-sdk/openfeature" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func configureFirstSuccessProvider[R any](provider *of.MockFeatureProvider, resultVal R, state bool, error int) { + var rErr of.ResolutionError + var variant string + var reason of.Reason + switch error { + case TestErrorError: + rErr = of.NewGeneralResolutionError("test error") + reason = of.ErrorReason + case TestErrorNotFound: + rErr = of.NewFlagNotFoundResolutionError("test not found") + reason = of.DefaultReason + } + + if state { + variant = "on" + } else { + variant = "off" + } + details := of.ProviderResolutionDetail{ + ResolutionError: rErr, + Reason: reason, + Variant: variant, + FlagMetadata: make(of.FlagMetadata), + } + + provider.EXPECT().Metadata().Return(of.Metadata{Name: "mock provider"}).MaxTimes(1) + switch any(resultVal).(type) { + case bool: + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal bool, evalCtx of.FlattenedContext) of.BoolResolutionDetail { + return of.BoolResolutionDetail{ + Value: any(resultVal).(bool), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case string: + provider.EXPECT().StringEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal string, evalCtx of.FlattenedContext) of.StringResolutionDetail { + return of.StringResolutionDetail{ + Value: any(resultVal).(string), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case int64: + provider.EXPECT().IntEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal int64, evalCtx of.FlattenedContext) of.IntResolutionDetail { + return of.IntResolutionDetail{ + Value: any(resultVal).(int64), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + case float64: + provider.EXPECT().FloatEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal float64, evalCtx of.FlattenedContext) of.FloatResolutionDetail { + return of.FloatResolutionDetail{ + Value: any(resultVal).(float64), + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + default: + provider.EXPECT().ObjectEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(c context.Context, flag string, defaultVal any, evalCtx of.FlattenedContext) of.InterfaceResolutionDetail { + return of.InterfaceResolutionDetail{ + Value: resultVal, + ProviderResolutionDetail: details, + } + }).MaxTimes(1) + } +} + +func Test_FirstSuccessStrategyEvaluation(t *testing.T) { + tests := []struct { + kind of.Type + successVal FlagTypes + defaultVal FlagTypes + }{ + {kind: of.Boolean, successVal: true, defaultVal: false}, + {kind: of.String, successVal: "success", defaultVal: "default"}, + {kind: of.Int, successVal: int64(150), defaultVal: int64(0)}, + {kind: of.Float, successVal: float64(15.5), defaultVal: float64(0)}, + {kind: of.Object, successVal: struct{ Field string }{Field: "test"}, defaultVal: struct{}{}}, + } + for _, tt := range tests { + t.Run(tt.kind.String(), func(t *testing.T) { + t.Run("single success", func(t *testing.T) { + ctrl := gomock.NewController(t) + provider := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider, tt.successVal, true, TestErrorNone) + + strategy := newFirstSuccessStrategy([]*NamedProvider{ + { + Name: "test-provider", + FeatureProvider: provider, + }, + }) + result := strategy(context.Background(), testFlag, tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Contains(t, result.FlagMetadata, MetadataStrategyUsed) + assert.Equal(t, StrategyFirstSuccess, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "test-provider", result.FlagMetadata[MetadataSuccessfulProviderName]) + }) + + t.Run("first success", func(t *testing.T) { + ctrl := gomock.NewController(t) + provider1 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider1, tt.successVal, true, TestErrorNone) + provider2 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider2, tt.defaultVal, false, TestErrorError) + + strategy := newFirstSuccessStrategy([]*NamedProvider{ + { + Name: "success-provider", + FeatureProvider: provider1, + }, + { + Name: "failure-provider", + FeatureProvider: provider2, + }, + }) + + result := strategy(context.Background(), testFlag, tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Equal(t, StrategyFirstSuccess, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "success-provider", result.FlagMetadata[MetadataSuccessfulProviderName]) + }) + + t.Run("second success", func(t *testing.T) { + ctrl := gomock.NewController(t) + provider1 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider1, tt.successVal, true, TestErrorNone) + provider2 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider2, tt.defaultVal, false, TestErrorError) + + strategy := newFirstSuccessStrategy([]*NamedProvider{ + { + Name: "success-provider", + FeatureProvider: provider1, + }, + { + Name: "failure-provider", + FeatureProvider: provider2, + }, + }) + + result := strategy(context.Background(), testFlag, tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.successVal, result.Value) + assert.Equal(t, StrategyFirstSuccess, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "success-provider", result.FlagMetadata[MetadataSuccessfulProviderName]) + }) + + t.Run("all errors", func(t *testing.T) { + ctrl := gomock.NewController(t) + provider1 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider1, tt.defaultVal, false, TestErrorError) + provider2 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider2, tt.defaultVal, false, TestErrorNotFound) + provider3 := of.NewMockFeatureProvider(ctrl) + configureFirstSuccessProvider(provider3, tt.defaultVal, false, TestErrorError) + + strategy := newFirstSuccessStrategy([]*NamedProvider{ + { + Name: "provider1", + FeatureProvider: provider1, + }, + { + Name: "provider2", + FeatureProvider: provider2, + }, + { + Name: "provider3", + FeatureProvider: provider3, + }, + }) + + result := strategy(context.Background(), testFlag, tt.defaultVal, of.FlattenedContext{}) + assert.Equal(t, tt.defaultVal, result.Value) + assert.Equal(t, StrategyFirstSuccess, result.FlagMetadata[MetadataStrategyUsed]) + assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) + assert.Equal(t, "none", result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Equal(t, of.ErrorReason, result.Reason) + assert.NotNil(t, result.Error()) + }) + }) + } +} diff --git a/openfeature/multi/isolation.go b/openfeature/multi/isolation.go new file mode 100644 index 00000000..268af842 --- /dev/null +++ b/openfeature/multi/isolation.go @@ -0,0 +1,334 @@ +package multi + +import ( + "context" + "fmt" + "sync" + + of "github.com/open-feature/go-sdk/openfeature" +) + +type ( + // hookIsolator is used as a wrapper around a provider that prevents context changes from leaking between providers + // during evaluation + hookIsolator struct { + of.UnimplementedHook + mu sync.Mutex + of.FeatureProvider + hooks []of.Hook + capturedContext of.HookContext + capturedHints of.HookHints + } + + // eventHandlingHookIsolator is equivalent to hookIsolator, but also implements [of.EventHandler] + eventHandlingHookIsolator struct { + hookIsolator + } +) + +// Compile-time interface compliance checks +var ( + _ of.FeatureProvider = (*hookIsolator)(nil) + _ of.Hook = (*hookIsolator)(nil) + _ of.EventHandler = (*eventHandlingHookIsolator)(nil) +) + +// isolateProvider wraps a [of.FeatureProvider] to execute its hooks along with any additional ones. +func isolateProvider(provider of.FeatureProvider, extraHooks []of.Hook) *hookIsolator { + return &hookIsolator{ + FeatureProvider: provider, + hooks: append(provider.Hooks(), extraHooks...), + } +} + +// isolateProviderWithEvents wraps a [of.FeatureProvider] to execute its hooks along with any additional ones. This is +// identical to [isolateProvider], but also this will also implement [of.EventHandler]. +func isolateProviderWithEvents(provider of.FeatureProvider, extraHooks []of.Hook) *eventHandlingHookIsolator { + return &eventHandlingHookIsolator{*isolateProvider(provider, extraHooks)} +} + +func (h *eventHandlingHookIsolator) EventChannel() <-chan of.Event { + return h.FeatureProvider.(of.EventHandler).EventChannel() +} + +func (h *hookIsolator) Before(_ context.Context, hookContext of.HookContext, hookHints of.HookHints) (*of.EvaluationContext, error) { + // Used for capturing the context and hints + h.mu.Lock() + defer h.mu.Unlock() + h.capturedContext = hookContext + h.capturedHints = hookHints + // Return copy of original evaluation context so any changes are isolated to each provider's hooks + evalCtx := h.capturedContext.EvaluationContext() + return &evalCtx, nil +} + +func (h *hookIsolator) Metadata() of.Metadata { + return h.FeatureProvider.Metadata() +} + +func (h *hookIsolator) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, flatCtx of.FlattenedContext) of.BoolResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Boolean, defaultValue, flatCtx) + + return of.BoolResolutionDetail{ + Value: completeEval.Value.(bool), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *hookIsolator) StringEvaluation(ctx context.Context, flag string, defaultValue string, flatCtx of.FlattenedContext) of.StringResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.String, defaultValue, flatCtx) + + return of.StringResolutionDetail{ + Value: completeEval.Value.(string), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *hookIsolator) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, flatCtx of.FlattenedContext) of.FloatResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Float, defaultValue, flatCtx) + + return of.FloatResolutionDetail{ + Value: completeEval.Value.(float64), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *hookIsolator) IntEvaluation(ctx context.Context, flag string, defaultValue int64, flatCtx of.FlattenedContext) of.IntResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Int, defaultValue, flatCtx) + + return of.IntResolutionDetail{ + Value: completeEval.Value.(int64), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *hookIsolator) ObjectEvaluation(ctx context.Context, flag string, defaultValue any, flatCtx of.FlattenedContext) of.InterfaceResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Object, defaultValue, flatCtx) + + return of.InterfaceResolutionDetail{ + Value: completeEval.Value, + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +// toProviderResolutionDetail Converts a [of.InterfaceEvaluationDetails] to a [of.ProviderResolutionDetail]. +func toProviderResolutionDetail(evalDetails of.InterfaceEvaluationDetails) of.ProviderResolutionDetail { + var resolutionErr of.ResolutionError + var reason of.Reason + switch evalDetails.ErrorCode { + case of.GeneralCode: + resolutionErr = of.NewGeneralResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.FlagNotFoundCode: + resolutionErr = of.NewFlagNotFoundResolutionError(evalDetails.ErrorMessage) + reason = of.DefaultReason + case of.TargetingKeyMissingCode: + resolutionErr = of.NewTargetingKeyMissingResolutionError(evalDetails.ErrorMessage) + reason = of.TargetingMatchReason + case of.TypeMismatchCode: + resolutionErr = of.NewTypeMismatchResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.ParseErrorCode: + resolutionErr = of.NewParseErrorResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.InvalidContextCode: + resolutionErr = of.NewInvalidContextResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + } + return of.ProviderResolutionDetail{ + ResolutionError: resolutionErr, + Reason: reason, + Variant: evalDetails.Variant, + FlagMetadata: evalDetails.FlagMetadata, + } +} + +func (h *hookIsolator) Hooks() []of.Hook { + // return self as hook to capture contexts + return []of.Hook{h} +} + +// evaluate Executes evaluation of the flag wrapped by executing hooks. +func (h *hookIsolator) evaluate(ctx context.Context, flag string, flagType of.Type, defaultValue any, flatCtx of.FlattenedContext) of.InterfaceEvaluationDetails { + evalDetails := of.InterfaceEvaluationDetails{ + Value: defaultValue, + EvaluationDetails: of.EvaluationDetails{ + FlagKey: flag, + FlagType: flagType, + }, + } + + defer func() { + h.finallyHooks(ctx, evalDetails) + }() + + evalCtx, err := h.beforeHooks(ctx) + // Update hook context unconditionally + h.updateEvalContext(evalCtx) + if err != nil { + err = fmt.Errorf("before hook: %w", err) + h.errorHooks(ctx, err) + evalDetails.ResolutionDetail = of.ResolutionDetail{ + Reason: of.ErrorReason, + ErrorCode: of.GeneralCode, + ErrorMessage: err.Error(), + FlagMetadata: nil, + } + return evalDetails + } + + // Merge together the passed in flat context and the captured evaluation context and transform back into a flattened + // context for evaluation + flatCtx = flattenContext(mergeContexts(h.capturedContext.EvaluationContext(), deepenContext(flatCtx))) + + var resolution of.InterfaceResolutionDetail + switch flagType { + case of.Object: + resolution = h.FeatureProvider.ObjectEvaluation(ctx, flag, defaultValue, flatCtx) + case of.Boolean: + defValue := defaultValue.(bool) + res := h.FeatureProvider.BooleanEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.String: + defValue := defaultValue.(string) + res := h.FeatureProvider.StringEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.Float: + defValue := defaultValue.(float64) + res := h.FeatureProvider.FloatEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.Int: + defValue := defaultValue.(int64) + res := h.FeatureProvider.IntEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + } + + err = resolution.Error() + if err != nil { + err = fmt.Errorf("error code: %w", err) + h.errorHooks(ctx, err) + evalDetails.ResolutionDetail = resolution.ResolutionDetail() + evalDetails.Reason = of.ErrorReason + return evalDetails + } + evalDetails.Value = resolution.Value + evalDetails.ResolutionDetail = resolution.ResolutionDetail() + + if err := h.afterHooks(ctx, evalDetails); err != nil { + err = fmt.Errorf("after hook: %w", err) + h.errorHooks(ctx, err) + return evalDetails + } + + return evalDetails +} + +// beforeHooks Executes all before hooks together, merging the changes to the [of.EvaluationContext] as it goes. The +// return of this function is a merged version of the evaluation context +func (h *hookIsolator) beforeHooks(ctx context.Context) (of.EvaluationContext, error) { + contexts := []of.EvaluationContext{h.capturedContext.EvaluationContext()} + for _, hook := range h.hooks { + resultEvalCtx, err := hook.Before(ctx, h.capturedContext, h.capturedHints) + if resultEvalCtx != nil { + contexts = append(contexts, *resultEvalCtx) + } + if err != nil { + return mergeContexts(contexts...), err + } + } + + return mergeContexts(contexts...), nil +} + +// afterHooks executes all after [of.Hook] instances together. +func (h *hookIsolator) afterHooks(ctx context.Context, evalDetails of.InterfaceEvaluationDetails) error { + for _, hook := range h.hooks { + if err := hook.After(ctx, h.capturedContext, evalDetails, h.capturedHints); err != nil { + return err + } + } + + return nil +} + +// errorHooks executes all error [of.Hook] instances together. +func (h *hookIsolator) errorHooks(ctx context.Context, err error) { + for _, hook := range h.hooks { + hook.Error(ctx, h.capturedContext, err, h.capturedHints) + } +} + +// finallyHooks execute all finally [of.Hook] instances together. +func (h *hookIsolator) finallyHooks(ctx context.Context, details of.InterfaceEvaluationDetails) { + for _, hook := range h.hooks { + hook.Finally(ctx, h.capturedContext, details, h.capturedHints) + } +} + +// updateEvalContext returns a new [of.HookContext] with an updated [of.EvaluationContext] value. [of.HookContext] is +// immutable, and this returns a new [of.HookContext] with all other values copied +func (h *hookIsolator) updateEvalContext(evalCtx of.EvaluationContext) { + hookCtx := of.NewHookContext( + h.capturedContext.FlagKey(), + h.capturedContext.FlagType(), + h.capturedContext.DefaultValue(), + h.capturedContext.ClientMetadata(), + h.capturedContext.ProviderMetadata(), + evalCtx, + ) + h.mu.Lock() + defer h.mu.Unlock() + h.capturedContext = hookCtx +} + +// deepenContext converts a [of.FlattenedContext] to a [of.EvaluationContext]. +func deepenContext(flatCtx of.FlattenedContext) of.EvaluationContext { + noTargetingKey := make(map[string]any) + for k, v := range flatCtx { + if k != of.TargetingKey { + noTargetingKey[k] = v + } + } + var targetingKey string + if tk, ok := flatCtx[of.TargetingKey]; ok { + targetingKey, _ = tk.(string) + } + return of.NewEvaluationContext(targetingKey, noTargetingKey) +} + +// flattenContext converts a [of.EvaluationContext] to a [of.FlattenedContext] +func flattenContext(evalCtx of.EvaluationContext) of.FlattenedContext { + flatCtx := evalCtx.Attributes() + flatCtx[of.TargetingKey] = evalCtx.TargetingKey() + return flatCtx +} + +// mergeContexts merges attributes from the given EvaluationContexts with the nth [of.EvaluationContext] taking precedence +// in case of any conflicts with the (n+1)th [of.EvaluationContext]. +func mergeContexts(evaluationContexts ...of.EvaluationContext) of.EvaluationContext { + if len(evaluationContexts) == 0 { + return of.EvaluationContext{} + } + // create initial values + attributes := evaluationContexts[0].Attributes() + targetingKey := evaluationContexts[0].TargetingKey() + + for i := 1; i < len(evaluationContexts); i++ { + if targetingKey == "" && evaluationContexts[i].TargetingKey() != "" { + targetingKey = evaluationContexts[i].TargetingKey() + } + + for k, v := range evaluationContexts[i].Attributes() { + _, ok := attributes[k] + if !ok { + attributes[k] = v + } + } + } + + return of.NewEvaluationContext(targetingKey, attributes) +} diff --git a/openfeature/multi/isolation_test.go b/openfeature/multi/isolation_test.go new file mode 100644 index 00000000..b940e972 --- /dev/null +++ b/openfeature/multi/isolation_test.go @@ -0,0 +1,101 @@ +package multi + +import ( + "context" + "errors" + "testing" + + of "github.com/open-feature/go-sdk/openfeature" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func Test_HookIsolator_BeforeCapturesData(t *testing.T) { + hookCtx := of.NewHookContext( + "test-key", + of.Boolean, + false, + of.ClientMetadata{}, + of.Metadata{}, + of.NewEvaluationContext("target", map[string]any{}), + ) + hookHints := of.NewHookHints(map[string]any{"foo": "bar"}) + ctrl := gomock.NewController(t) + provider := of.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + isolator := isolateProvider(provider, []of.Hook{}) + assert.Zero(t, isolator.capturedContext) + assert.Zero(t, isolator.capturedHints) + evalCtx, err := isolator.Before(context.Background(), hookCtx, hookHints) + require.NoError(t, err) + assert.NotNil(t, evalCtx) + assert.Equal(t, hookCtx, isolator.capturedContext) + assert.Equal(t, hookHints, isolator.capturedHints) +} + +func Test_HookIsolator_Hooks_ReturnsSelf(t *testing.T) { + ctrl := gomock.NewController(t) + provider := of.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + isolator := isolateProvider(provider, []of.Hook{}) + hooks := isolator.Hooks() + assert.NotEmpty(t, hooks) + assert.Same(t, isolator, hooks[0]) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_NoError(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := of.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + provider := of.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(of.BoolResolutionDetail{ + Value: true, + ProviderResolutionDetail: of.ProviderResolutionDetail{}, + }) + + isolator := isolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.True(t, result.Value) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_BeforeErrorAbortsExecution(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := of.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test error")) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + + provider := of.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + + isolator := isolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.False(t, result.Value) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_WithAfterError(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := of.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("test error")) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + + provider := of.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(of.BoolResolutionDetail{ + Value: false, + ProviderResolutionDetail: of.ProviderResolutionDetail{}, + }) + + isolator := isolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.False(t, result.Value) +} diff --git a/openfeature/multi/multiprovider.go b/openfeature/multi/multiprovider.go new file mode 100644 index 00000000..5a6cd926 --- /dev/null +++ b/openfeature/multi/multiprovider.go @@ -0,0 +1,567 @@ +// Package multi is an experimental implementation of a [of.FeatureProvider] that supports evaluating multiple feature flag +// providers together. +package multi + +import ( + "context" + "errors" + "fmt" + "log/slog" + "slices" + "strings" + "sync" + + of "github.com/open-feature/go-sdk/openfeature" + "golang.org/x/sync/errgroup" +) + +// Metadata Keys +const ( + MetadataProviderName = "multiprovider-provider-name" + MetadataProviderType = "multiprovider-provider-type" + MetadataSuccessfulProviderName = "multiprovider-successful-provider-name" + MetadataSuccessfulProviderNames = MetadataSuccessfulProviderName + "s" + MetadataStrategyUsed = "multiprovider-strategy-used" + MetadataFallbackUsed = "multiprovider-fallback-used" + MetadataIsDefaultValue = "multiprovider-is-result-default-value" + MetadataEvaluationError = "multiprovider-evaluation-error" + MetadataComparisonDisagreeingProviders = "multiprovider-comparison-disagreeing-providers" +) + +type ( + // ProviderMap is an alias for a map containing unique names for each included [of.FeatureProvider] + ProviderMap = map[string]of.FeatureProvider + + // Provider is an implementation of [of.FeatureProvider] that can execute multiple providers using various + // strategies. + Provider struct { + providers ProviderMap + metadata of.Metadata + initialized bool + overallStatus of.State + overallStatusLock sync.RWMutex + providerStatus map[string]of.State + providerStatusLock sync.Mutex + strategyName EvaluationStrategy // the name of the strategy used for evaluation + strategyFunc StrategyFn[FlagTypes] // used for evaluating strategies + logger *slog.Logger + outboundEvents chan of.Event + inboundEvents chan namedEvent + workerGroup sync.WaitGroup + shutdownFunc context.CancelFunc + globalHooks []of.Hook + } + + // NamedProvider allows for a unique name to be assigned to a provider during a multi-provider set up. + // The name will be used when reporting errors & results to specify the provider associated with them. + NamedProvider struct { + Name string + of.FeatureProvider + } + + // Option function used for setting configuration via the options pattern + Option func(*configuration) + + // Private Types + namedEvent struct { + of.Event + providerName string + } + + // configuration is the internal configuration of a [multi.Provider] + configuration struct { + useFallback bool + fallbackProvider of.FeatureProvider + customStrategy StrategyConstructor + logger *slog.Logger + hooks []of.Hook + providerHooks map[string][]of.Hook + customComparator Comparator + } +) + +var ( + stateValues map[of.State]int + stateTable [3]of.State + eventTypeToState map[of.EventType]of.State + + // Compile-time interface compliance checks + _ of.FeatureProvider = (*Provider)(nil) + _ of.EventHandler = (*Provider)(nil) + _ of.StateHandler = (*Provider)(nil) +) + +// init Initialize "constants" used for event handling priorities and filtering. +func init() { + // used for mapping provider event types & provider states to comparable values for evaluation + stateValues = map[of.State]int{ + "": -1, // Not a real state, but used for handling provider config changes + of.ReadyState: 0, + of.StaleState: 1, + of.ErrorState: 2, + } + // used for mapping + stateTable = [3]of.State{ + of.ReadyState, // 0 + of.StaleState, // 1 + of.ErrorState, // 2 + } + eventTypeToState = map[of.EventType]of.State{ + of.ProviderConfigChange: "", + of.ProviderReady: of.ReadyState, + of.ProviderStale: of.StaleState, + of.ProviderError: of.ErrorState, + } +} + +// Configuration Options + +// WithLogger sets a logger to be used with slog for internal logging. By default, all logs are discarded unless this is set. +func WithLogger(l *slog.Logger) Option { + return func(conf *configuration) { + conf.logger = l + } +} + +// WithFallbackProvider sets a fallback provider when using the [StrategyComparison] setting. The fallback provider is +// called when providers are not in agreement. If a fallback provider is not set and providers are not agreement, then +// the default result will be returned along with an error value. +func WithFallbackProvider(p of.FeatureProvider) Option { + return func(conf *configuration) { + conf.fallbackProvider = p + conf.useFallback = true + } +} + +// WithCustomComparator sets a custom [Comparator] to use when using [StrategyComparison] when [of.FeatureProvider.ObjectEvaluation] +// is performed. This is required if the returned objects are not comparable, otherwise an error occur.. +func WithCustomComparator(comparator Comparator) Option { + return func(conf *configuration) { + conf.customComparator = comparator + } +} + +// WithCustomStrategy sets a custom strategy function by defining a "constructor" that acts as closure over a slice of +// [NamedProvider] instances with your returned custom strategy function. This must be used in conjunction with [StrategyCustom] +func WithCustomStrategy(s StrategyConstructor) Option { + return func(conf *configuration) { + conf.customStrategy = s + } +} + +// WithGlobalHooks sets the global hooks for the provider. These are [of.Hook] instances that affect ALL [of.FeatureProvider] +// instances. For hooks that target specific providers make sure to attach them to that provider directly, or use the +// [WithProviderHooks] [Option] if that provider does not provide its own hook functionality. +func WithGlobalHooks(hooks ...of.Hook) Option { + return func(conf *configuration) { + conf.hooks = hooks + } +} + +// WithProviderHooks sets [of.Hook] instances that execute only for a specific [of.FeatureProvider]. The providerName +// must match the unique provider name set during [multi.Provider] creation. This should only be used if you need hooks +// that execute around a specific provider, but that provider does not currently accept a way to set hooks. This [Option] +// can be used multiple times using unique provider names. Using a provider name that is not known will cause an error. +func WithProviderHooks(providerName string, hooks ...of.Hook) Option { + return func(conf *configuration) { + conf.providerHooks[providerName] = hooks + } +} + +// Multiprovider Implementation + +// toNamedProviderSlice converts the provided [ProviderMap] into a slice of [NamedProvider] instances +func toNamedProviderSlice(m ProviderMap) []*NamedProvider { + s := make([]*NamedProvider, 0, len(m)) + for name, provider := range m { + s = append(s, &NamedProvider{Name: name, FeatureProvider: provider}) + } + + return s +} + +func buildMetadata(m ProviderMap) of.Metadata { + var separator string + var metaName strings.Builder + metaName.WriteString("MultiProvider {") + names := make([]string, 0, len(m)) + for n := range m { + names = append(names, n) + } + slices.Sort(names) + for _, name := range names { + metaName.WriteString(fmt.Sprintf("%s%s: %s", separator, name, m[name].Metadata().Name)) + if separator == "" { + separator = ", " + } + } + + metaName.WriteRune('}') + return of.Metadata{ + Name: metaName.String(), + } +} + +// NewProvider returns a new [multi.Provider] that acts as a unified interface of multiple providers for interaction. +func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy, options ...Option) (*Provider, error) { + if len(providerMap) == 0 { + return nil, errors.New("providerMap cannot be nil or empty") + } + + config := &configuration{ + logger: slog.New(slog.DiscardHandler), + providerHooks: make(map[string][]of.Hook), + } + + for _, opt := range options { + opt(config) + } + + providers := providerMap + collectedHooks := make([]of.Hook, 0, len(providerMap)) + for name, provider := range providerMap { + // Validate Providers + if name == "" { + return nil, errors.New("provider name cannot be the empty string") + } + + if provider == nil { + return nil, fmt.Errorf("provider %s cannot be nil", name) + } + + // Wrap any providers that include hooks + if (len(provider.Hooks()) + len(config.providerHooks[name])) == 0 { + continue + } + + var wrappedProvider of.FeatureProvider + if _, ok := provider.(of.EventHandler); ok { + wrappedProvider = isolateProviderWithEvents(provider, config.providerHooks[name]) + } else { + wrappedProvider = isolateProvider(provider, config.providerHooks[name]) + } + + providers[name] = wrappedProvider + collectedHooks = append(collectedHooks, wrappedProvider.Hooks()...) + } + + multiProvider := &Provider{ + providers: providers, + outboundEvents: make(chan of.Event), + logger: config.logger, + metadata: buildMetadata(providerMap), + overallStatus: of.NotReadyState, + providerStatus: make(map[string]of.State, len(providers)), + globalHooks: append(config.hooks, collectedHooks...), + } + + var strategy StrategyFn[FlagTypes] + switch evaluationStrategy { + case StrategyFirstMatch: + strategy = newFirstMatchStrategy(multiProvider.Providers()) + case StrategyFirstSuccess: + strategy = newFirstSuccessStrategy(multiProvider.Providers()) + case StrategyComparison: + strategy = newComparisonStrategy(multiProvider.Providers(), config.fallbackProvider, config.customComparator) + default: + if config.customStrategy == nil { + return nil, fmt.Errorf("%s is an unknown evaluation strategy", evaluationStrategy) + } + strategy = config.customStrategy(multiProvider.Providers()) + } + multiProvider.strategyFunc = strategy + multiProvider.strategyName = evaluationStrategy + + return multiProvider, nil +} + +// Providers returns slice of providers wrapped in [NamedProvider] structs. +func (p *Provider) Providers() []*NamedProvider { + return toNamedProviderSlice(p.providers) +} + +// ProvidersByName Returns the internal [ProviderMap]. +func (p *Provider) ProvidersByName() ProviderMap { + return p.providers +} + +// EvaluationStrategy The name of the currently set [EvaluationStrategy]. +func (p *Provider) EvaluationStrategy() string { + return p.strategyName +} + +// Metadata provides the name "multiprovider" along with the names and types of each internal [of.FeatureProvider]. +func (p *Provider) Metadata() of.Metadata { + return p.metadata +} + +// Hooks returns a collection [of.Hook] instances configured to the provider using [WithGlobalHooks]. +func (p *Provider) Hooks() []of.Hook { + return p.globalHooks +} + +// BooleanEvaluation evaluates the flag and returns a [of.BoolResolutionDetail]. +func (p *Provider) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, flatCtx of.FlattenedContext) of.BoolResolutionDetail { + res := p.strategyFunc(ctx, flag, defaultValue, flatCtx) + return of.BoolResolutionDetail{ + Value: res.Value.(bool), + ProviderResolutionDetail: res.ProviderResolutionDetail, + } +} + +// StringEvaluation evaluates the flag and returns a [of.StringResolutionDetail]. +func (p *Provider) StringEvaluation(ctx context.Context, flag string, defaultValue string, flatCtx of.FlattenedContext) of.StringResolutionDetail { + res := p.strategyFunc(ctx, flag, defaultValue, flatCtx) + return of.StringResolutionDetail{ + Value: res.Value.(string), + ProviderResolutionDetail: res.ProviderResolutionDetail, + } +} + +// FloatEvaluation evaluates the flag and returns a [of.FloatResolutionDetail]. +func (p *Provider) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, flatCtx of.FlattenedContext) of.FloatResolutionDetail { + res := p.strategyFunc(ctx, flag, defaultValue, flatCtx) + return of.FloatResolutionDetail{ + Value: res.Value.(float64), + ProviderResolutionDetail: res.ProviderResolutionDetail, + } +} + +// IntEvaluation evaluates the flag and returns an [of.IntResolutionDetail]. +func (p *Provider) IntEvaluation(ctx context.Context, flag string, defaultValue int64, flatCtx of.FlattenedContext) of.IntResolutionDetail { + res := p.strategyFunc(ctx, flag, defaultValue, flatCtx) + return of.IntResolutionDetail{ + Value: res.Value.(int64), + ProviderResolutionDetail: res.ProviderResolutionDetail, + } +} + +// ObjectEvaluation evaluates the flag and returns an [of.InterfaceResolutionDetail]. For the purposes of evaluation +// within strategies, the type of the default value is used as the assumed type of the returned responses from each provider. +// This is especially important when using the [StrategyComparison] configuration as an internal error will occur if this +// is not a comparable type unless the [WithCustomComparator] [Option] is configured. +func (p *Provider) ObjectEvaluation(ctx context.Context, flag string, defaultValue any, flatCtx of.FlattenedContext) of.InterfaceResolutionDetail { + res := p.strategyFunc(ctx, flag, defaultValue, flatCtx) + return of.InterfaceResolutionDetail{ + Value: res.Value, + ProviderResolutionDetail: res.ProviderResolutionDetail, + } +} + +// Init will run the initialize method for all internal [of.FeatureProvider] instances and aggregate any errors. +func (p *Provider) Init(evalCtx of.EvaluationContext) error { + var eg errgroup.Group + // wrapper type used only for initialization of event listener workers + type namedEventHandler struct { + of.EventHandler + name string + } + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "start initialization") + p.inboundEvents = make(chan namedEvent, len(p.providers)) + handlers := make(chan namedEventHandler, len(p.providers)) + for name, provider := range p.providers { + // Initialize each provider to not ready state. No locks required there are no workers running + p.updateProviderState(name, of.NotReadyState) + l := p.logger.With(slog.String(MetadataProviderName, name)) + prov := provider + eg.Go(func() error { + l.LogAttrs(context.Background(), slog.LevelDebug, "starting initialization") + stateHandle, ok := prov.(of.StateHandler) + if !ok { + l.LogAttrs(context.Background(), slog.LevelDebug, "StateHandle not implemented, skipping initialization") + } else if err := stateHandle.Init(evalCtx); err != nil { + l.LogAttrs(context.Background(), slog.LevelError, "initialization failed", slog.Any("error", err)) + return &ProviderError{ + Err: err, + ProviderName: name, + } + } + l.LogAttrs(context.Background(), slog.LevelDebug, "initialization successful") + if eventer, ok := provider.(of.EventHandler); ok { + l.LogAttrs(context.Background(), slog.LevelDebug, "detected EventHandler implementation") + handlers <- namedEventHandler{eventer, name} + } else { + // Do not yet update providers that need event handling + p.updateProviderState(name, of.ReadyState) + } + return nil + }) + } + + if err := eg.Wait(); err != nil { + var pErr *ProviderError + if errors.As(err, &pErr) { + // Update provider status to error, no event needs to be emitted yet + p.updateProviderState(pErr.ProviderName, of.ErrorState) + } else { + pErr = &ProviderError{ + Err: err, + ProviderName: "unknown", + } + p.setStatus(of.ErrorState) + } + + return err + } + close(handlers) + workerCtx, shutdownFunc := context.WithCancel(context.Background()) + for h := range handlers { + go p.startListening(workerCtx, h.name, h.EventHandler, &p.workerGroup) + } + p.shutdownFunc = shutdownFunc + + p.workerGroup.Add(1) + go func() { + workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker")) + defer p.workerGroup.Done() + for e := range p.inboundEvents { + l := workerLogger.With( + slog.String(MetadataProviderName, e.providerName), + slog.String(MetadataProviderType, e.ProviderName), + ) + l.LogAttrs(context.Background(), slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType))) + if p.updateProviderStateFromEvent(e) { + p.outboundEvents <- e.Event + l.LogAttrs(context.Background(), slog.LevelDebug, "forwarded state update event") + } else { + l.LogAttrs(context.Background(), slog.LevelDebug, "total state not updated, inbound event will not be emitted") + } + } + }() + + p.setStatus(of.ReadyState) + p.initialized = true + return nil +} + +// startListening is intended to be called on a per-provider basis as a goroutine to listen to events from a provider +// implementing [of.EventHandler]. +func (p *Provider) startListening(ctx context.Context, name string, h of.EventHandler, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() + for { + select { + case e := <-h.EventChannel(): + e.EventMetadata[MetadataProviderName] = name + e.EventMetadata[MetadataProviderType] = h.(of.FeatureProvider).Metadata().Name + p.inboundEvents <- namedEvent{ + Event: e, + providerName: name, + } + case <-ctx.Done(): + return + } + } +} + +// updateProviderState Updates the state of an internal provider and then re-evaluates the overall state of the +// multiprovider. If this method returns true the overall state changed. +func (p *Provider) updateProviderState(name string, state of.State) bool { + p.providerStatusLock.Lock() + defer p.providerStatusLock.Unlock() + p.providerStatus[name] = state + evalState := p.evaluateState() + if evalState != p.Status() { + p.setStatus(evalState) + return true + } + + return false +} + +// updateProviderStateFromEvent updates the state of an internal provider from an event emitted from it, and then +// re-evaluates the overall state of the multiprovider. If this method returns true the overall state changed. +func (p *Provider) updateProviderStateFromEvent(e namedEvent) bool { + if e.EventType == of.ProviderConfigChange { + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "ProviderConfigChange event", slog.String("event-message", e.Message)) + } + logProviderState(p.logger, e, p.providerStatus[e.providerName]) + return p.updateProviderState(e.ProviderName, eventTypeToState[e.EventType]) +} + +// evaluateState Determines the overall state of the provider using the weights specified in Appendix A of the +// OpenFeature spec. This method should only be called if the provider state mutex is locked +func (p *Provider) evaluateState() of.State { + maxState := stateValues[of.ReadyState] // initialize to the lowest state value + for _, s := range p.providerStatus { + if stateValues[s] > maxState { + // change in state due to higher priority + maxState = stateValues[s] + } + } + return stateTable[maxState] +} + +func logProviderState(l *slog.Logger, e namedEvent, previousState of.State) { + switch eventTypeToState[e.EventType] { + case of.ReadyState: + if previousState != of.NotReadyState { + l.LogAttrs(context.Background(), slog.LevelInfo, "provider has returned to ready state", + slog.String(MetadataProviderName, e.providerName), slog.String("previous-state", string(previousState))) + return + } + l.LogAttrs(context.Background(), slog.LevelDebug, "provider is ready", slog.String(MetadataProviderName, e.providerName)) + case of.StaleState: + l.LogAttrs(context.Background(), slog.LevelWarn, "provider is stale", + slog.String(MetadataProviderName, e.providerName), slog.String("event-message", e.Message)) + case of.ErrorState: + l.LogAttrs(context.Background(), slog.LevelError, "provider is in an error state", + slog.String(MetadataProviderName, e.providerName), slog.String("event-message", e.Message)) + } +} + +// Shutdown Shuts down all internal [of.FeatureProvider] instances and internal event listeners +func (p *Provider) Shutdown() { + if !p.initialized { + // Don't do anything if we were never initialized + return + } + // Stop all event listener workers, shutdown events should not affect overall state + p.shutdownFunc() + // Stop forwarding worker + close(p.inboundEvents) + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "triggered worker shutdown") + // Wait for workers to stop + p.workerGroup.Wait() + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "worker shutdown completed") + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "starting provider shutdown") + var wg sync.WaitGroup + for _, provider := range p.providers { + wg.Add(1) + + go func(p of.FeatureProvider) { + defer wg.Done() + if stateHandle, ok := p.(of.StateHandler); ok { + stateHandle.Shutdown() + } + }(provider) + } + + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "waiting for provider shutdown completion") + wg.Wait() + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider shutdown completed") + p.setStatus(of.NotReadyState) + close(p.outboundEvents) + p.outboundEvents = nil + p.inboundEvents = nil + p.initialized = false +} + +// Status provides the current state of the [multi.Provider]. +func (p *Provider) Status() of.State { + p.overallStatusLock.RLock() + defer p.overallStatusLock.RUnlock() + return p.overallStatus +} + +func (p *Provider) setStatus(state of.State) { + p.overallStatusLock.Lock() + defer p.overallStatusLock.Unlock() + p.overallStatus = state + p.logger.LogAttrs(context.Background(), slog.LevelDebug, "state updated", slog.String("state", string(state))) +} + +// EventChannel is the channel that all events are emitted on. +func (p *Provider) EventChannel() <-chan of.Event { + return p.outboundEvents +} diff --git a/openfeature/multi/multiprovider_test.go b/openfeature/multi/multiprovider_test.go new file mode 100644 index 00000000..1f129cd5 --- /dev/null +++ b/openfeature/multi/multiprovider_test.go @@ -0,0 +1,346 @@ +package multi + +import ( + "context" + "errors" + "regexp" + "testing" + + of "github.com/open-feature/go-sdk/openfeature" + imp "github.com/open-feature/go-sdk/openfeature/memprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestMultiProvider_ProvidersMethod(t *testing.T) { + testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + + mp, err := NewProvider(providers, StrategyFirstSuccess) + require.NoError(t, err) + + p := mp.Providers() + assert.Len(t, p, 2) + assert.Regexp(t, regexp.MustCompile("provider[1-2]"), p[0].Name) + assert.NotNil(t, p[0].FeatureProvider) + assert.Regexp(t, regexp.MustCompile("provider[1-2]"), p[1].Name) + assert.NotNil(t, p[1].FeatureProvider) +} + +func TestMultiProvider_NewMultiProvider(t *testing.T) { + t.Run("nil providerMap returns an error", func(t *testing.T) { + _, err := NewProvider(nil, StrategyFirstMatch) + require.Errorf(t, err, "providerMap cannot be nil or empty") + }) + + t.Run("naming a provider the empty string returns an error", func(t *testing.T) { + providers := make(ProviderMap) + providers[""] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + _, err := NewProvider(providers, StrategyFirstMatch) + require.Errorf(t, err, "provider name cannot be the empty string") + }) + + t.Run("nil provider within map returns an error", func(t *testing.T) { + providers := make(ProviderMap) + providers["provider1"] = nil + _, err := NewProvider(providers, StrategyFirstMatch) + require.Errorf(t, err, "provider provider1 cannot be nil") + }) + + t.Run("unknown evaluation strategyFunc returns an error", func(t *testing.T) { + providers := make(ProviderMap) + providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + _, err := NewProvider(providers, "unknown") + require.Errorf(t, err, "unknown is an unknown evaluation strategyFunc") + }) + + t.Run("setting custom strategyFunc without custom strategyFunc option returns error", func(t *testing.T) { + providers := make(ProviderMap) + providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + _, err := NewProvider(providers, StrategyCustom) + require.Errorf(t, err, "A custom strategyFunc must be set via an option if StrategyCustom is set") + }) + + t.Run("success", func(t *testing.T) { + providers := make(ProviderMap) + providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + mp, err := NewProvider(providers, StrategyComparison) + require.NoError(t, err) + assert.NotZero(t, mp) + }) + + t.Run("success with custom provider", func(t *testing.T) { + providers := make(ProviderMap) + providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + mp, err := NewProvider(providers, StrategyCustom, WithCustomStrategy(func(providers []*NamedProvider) StrategyFn[FlagTypes] { + return func(ctx context.Context, flag string, defaultValue FlagTypes, evalCtx of.FlattenedContext) of.GenericResolutionDetail[FlagTypes] { + return of.GenericResolutionDetail[FlagTypes]{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{Reason: of.UnknownReason}, + } + } + })) + require.NoError(t, err) + assert.NotZero(t, mp) + }) +} + +func TestMultiProvider_ProvidersByNamesMethod(t *testing.T) { + testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + + mp, err := NewProvider(providers, StrategyFirstMatch) + require.NoError(t, err) + + p := mp.ProvidersByName() + + assert.Len(t, p, 2) + require.Contains(t, p, "provider1") + assert.Equal(t, p["provider1"], testProvider1) + require.Contains(t, p, "provider2") + assert.Equal(t, p["provider2"], testProvider2) +} + +func TestMultiProvider_MetaData(t *testing.T) { + t.Run("two providers", func(t *testing.T) { + testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + ctrl := gomock.NewController(t) + testProvider2 := of.NewMockFeatureProvider(ctrl) + testProvider2.EXPECT().Metadata().Return(of.Metadata{ + Name: "MockProvider", + }) + testProvider2.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + + mp, err := NewProvider(providers, StrategyFirstSuccess) + require.NoError(t, err) + + metadata := mp.Metadata() + require.NotZero(t, metadata) + assert.Equal(t, "MultiProvider {provider1: InMemoryProvider, provider2: MockProvider}", metadata.Name) + }) + + t.Run("three providers", func(t *testing.T) { + testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + ctrl := gomock.NewController(t) + testProvider2 := of.NewMockFeatureProvider(ctrl) + testProvider2.EXPECT().Metadata().Return(of.Metadata{ + Name: "MockProvider", + }) + testProvider2.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + testProvider3 := of.NewMockFeatureProvider(ctrl) + testProvider3.EXPECT().Metadata().Return(of.Metadata{ + Name: "MockProvider", + }) + testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + + mp, err := NewProvider(providers, StrategyFirstSuccess) + require.NoError(t, err) + + metadata := mp.Metadata() + require.NotZero(t, metadata) + assert.Equal(t, "MultiProvider {provider1: InMemoryProvider, provider2: MockProvider, provider3: MockProvider}", metadata.Name) + }) +} + +func TestMultiProvider_Init(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + ctrl := gomock.NewController(t) + + testProvider1 := of.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + initProvider := of.NewMockFeatureProvider(ctrl) + initProvider.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + initProvider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + initHandler := of.NewMockStateHandler(ctrl) + initHandler.EXPECT().Init(gomock.Any()).Return(nil) + initHandler.EXPECT().Shutdown().MaxTimes(1) + testProvider2 := struct { + of.FeatureProvider + of.StateHandler + }{ + initProvider, + initHandler, + } + testProvider3 := of.NewMockFeatureProvider(ctrl) + testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + + mp, err := NewProvider(providers, StrategyFirstMatch) + require.NoError(t, err) + + t.Cleanup(func() { + mp.Shutdown() + }) + + attributes := map[string]any{ + "foo": "bar", + } + evalCtx := of.NewTargetlessEvaluationContext(attributes) + err = mp.Init(evalCtx) + require.NoError(t, err) + assert.Equal(t, of.ReadyState, mp.Status()) +} + +func TestMultiProvider_InitErrorWithProvider(t *testing.T) { + ctrl := gomock.NewController(t) + errProvider := of.NewMockFeatureProvider(ctrl) + errProvider.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + errProvider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + errHandler := of.NewMockStateHandler(ctrl) + errHandler.EXPECT().Init(gomock.Any()).Return(errors.New("test error")) + testProvider3 := struct { + of.FeatureProvider + of.StateHandler + }{ + errProvider, + errHandler, + } + + testProvider1 := of.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + + mp, err := NewProvider(providers, StrategyFirstMatch) + require.NoError(t, err) + + attributes := map[string]any{ + "foo": "bar", + } + evalCtx := of.NewTargetlessEvaluationContext(attributes) + err = mp.Init(evalCtx) + require.Errorf(t, err, "Provider provider3: test error") + assert.Equal(t, of.ErrorState, mp.overallStatus) +} + +func TestMultiProvider_Shutdown_WithoutInit(t *testing.T) { + ctrl := gomock.NewController(t) + + testProvider1 := of.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + testProvider3 := of.NewMockFeatureProvider(ctrl) + testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + mp, err := NewProvider(providers, StrategyFirstMatch) + require.NoError(t, err) + + mp.Shutdown() +} + +func TestMultiProvider_Shutdown_WithInit(t *testing.T) { + ctrl := gomock.NewController(t) + + testProvider1 := of.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + handlingProvider := of.NewMockFeatureProvider(ctrl) + handlingProvider.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + handlingProvider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + handledHandler := of.NewMockStateHandler(ctrl) + handledHandler.EXPECT().Init(gomock.Any()).Return(nil) + handledHandler.EXPECT().Shutdown() + testProvider3 := struct { + of.FeatureProvider + of.StateHandler + }{ + handlingProvider, + handledHandler, + } + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + mp, err := NewProvider(providers, StrategyFirstMatch) + require.NoError(t, err) + evalCtx := of.NewTargetlessEvaluationContext(map[string]any{ + "foo": "bar", + }) + eventChan := make(chan of.Event) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case e := <-mp.EventChannel(): + eventChan <- e + case <-ctx.Done(): + return + } + }() + err = mp.Init(evalCtx) + require.NoError(t, err) + assert.Equal(t, of.ReadyState, mp.Status()) + cancel() + mp.Shutdown() + assert.Equal(t, of.NotReadyState, mp.Status()) +} + +func TestMultiProvider_statusEvaluation(t *testing.T) { + multiProvider := &Provider{ + overallStatus: of.NotReadyState, + providerStatus: make(map[string]of.State), + } + + t.Run("empty state is ready", func(t *testing.T) { + assert.Equal(t, of.ReadyState, multiProvider.evaluateState()) + }) + + t.Run("all states ready is ready", func(t *testing.T) { + multiProvider.providerStatus["provider1"] = of.ReadyState + multiProvider.providerStatus["provider2"] = of.ReadyState + multiProvider.providerStatus["provider3"] = of.ReadyState + assert.Equal(t, of.ReadyState, multiProvider.evaluateState()) + }) + + t.Run("one state stale is stale", func(t *testing.T) { + multiProvider.providerStatus["provider1"] = of.ReadyState + multiProvider.providerStatus["provider2"] = of.ReadyState + multiProvider.providerStatus["provider3"] = of.StaleState + assert.Equal(t, of.StaleState, multiProvider.evaluateState()) + }) + + t.Run("one state error is error", func(t *testing.T) { + multiProvider.providerStatus["provider1"] = of.ReadyState + multiProvider.providerStatus["provider2"] = of.StaleState + multiProvider.providerStatus["provider3"] = of.ErrorState + assert.Equal(t, of.ErrorState, multiProvider.evaluateState()) + }) +} diff --git a/openfeature/multi/strategies.go b/openfeature/multi/strategies.go new file mode 100644 index 00000000..ee00b15b --- /dev/null +++ b/openfeature/multi/strategies.go @@ -0,0 +1,172 @@ +package multi + +import ( + "context" + "maps" + "regexp" + "strings" + + of "github.com/open-feature/go-sdk/openfeature" +) + +// EvaluationStrategy options +const ( + // StrategyFirstMatch returns the result of the first [of.FeatureProvider] whose response is not [of.FlagNotFoundCode]. + // This is executed sequentially, and not in parallel. Any returned errors from a provider other than flag not found + // will result in a default response with a set error. + StrategyFirstMatch EvaluationStrategy = "strategy-first-match" + // StrategyFirstSuccess returns the result of the First [of.FeatureProvider] whose response that is not an error. + // This is very similar to [StrategyFirstMatch], but does not raise errors. This executed sequentially. + StrategyFirstSuccess EvaluationStrategy = "strategy-first-success" + // StrategyComparison returns a response of all [of.FeatureProvider] instances in agreement. All providers are + // called in parallel and then the results of each non-error result are compared to each other. If all responses + // agree, then that value will be returned. Otherwise, the value from the designated fallback [of.FeatureProvider] + // instance's response will be returned. The fallback provider will be assigned to the first provider registered if + // the [WithFallbackProvider] Option is not explicitly set. + StrategyComparison EvaluationStrategy = "strategy-comparison" + // StrategyCustom allows for using a custom [StrategyFn] implementation. If this is set you MUST use the WithCustomStrategy + // Option to set it + StrategyCustom EvaluationStrategy = "strategy-custom" +) + +// Additional [of.Reason] options +const ( + // ReasonAggregated - the resolved value was the agreement of all providers in the multi.Provider using the + // [StrategyComparison] strategy + ReasonAggregated of.Reason = "AGGREGATED" + // ReasonAggregatedFallback ReasonAggregated - the resolved value was result of the fallback provider because the + // providers in multi.Provider were not in agreement using the [StrategyComparison] strategy. + ReasonAggregatedFallback of.Reason = "AGGREGATED_FALLBACK" +) + +// errAggregationNotAllowedText sentinel returned if [of.FeatureProvider.ObjectEvaluation] is called without a set custom +// strategy when response objects are not comparable. +const errAggregationNotAllowedText = "object evaluation not allowed for non-comparable types without custom comparable func" + +type ( + // EvaluationStrategy Defines a strategy to use for resolving the result from multiple providers. + EvaluationStrategy = string + + // FlagTypes defines the types that can be used for flag values. + FlagTypes interface { + int64 | float64 | string | bool | any + } + // StrategyFn defines the signature for a strategy function. + StrategyFn[T FlagTypes] func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] + // StrategyConstructor defines the signature for the function that will be called to retrieve the closure that acts + // as the custom strategy implementation. This function should return a [StrategyFn] + StrategyConstructor func(providers []*NamedProvider) StrategyFn[FlagTypes] +) + +// Common Components + +// setFlagMetadata sets common metadata for evaluations. +func setFlagMetadata(strategyUsed EvaluationStrategy, successProviderName string, metadata of.FlagMetadata) of.FlagMetadata { + if metadata == nil { + metadata = make(of.FlagMetadata) + } + metadata[MetadataSuccessfulProviderName] = successProviderName + metadata[MetadataStrategyUsed] = strategyUsed + return metadata +} + +// cleanErrorMessage removes prefixes from error messages. +func cleanErrorMessage(msg string) string { + codeRegex := strings.Join([]string{ + string(of.ProviderNotReadyCode), + string(of.ProviderFatalCode), + string(of.FlagNotFoundCode), + string(of.ParseErrorCode), + string(of.TypeMismatchCode), + string(of.TargetingKeyMissingCode), + string(of.GeneralCode), + }, "|") + re := regexp.MustCompile("(?:" + codeRegex + "): (.*)") + matches := re.FindSubmatch([]byte(msg)) + matchCount := len(matches) + switch matchCount { + case 0, 1: + return msg + default: + return strings.TrimSpace(string(matches[1])) + } +} + +// mergeFlagMeta merges flag metadata together into a single [of.FlagMetadata] instance by performing a shallow merge. +func mergeFlagMeta(tags ...of.FlagMetadata) of.FlagMetadata { + size := len(tags) + switch size { + case 0: + return make(of.FlagMetadata) + case 1: + return tags[0] + default: + merged := make(of.FlagMetadata) + for _, t := range tags { + maps.Copy(merged, t) + } + return merged + } +} + +// BuildDefaultResult should be called when a [StrategyFn] is in a failure state and needs to return a default value. +// This method will build a resolution detail with the internal provided error set. This method is exported for those +// writing their own custom [StrategyFn]. +func BuildDefaultResult[R FlagTypes](strategy EvaluationStrategy, defaultValue R, err error) of.GenericResolutionDetail[R] { + var rErr of.ResolutionError + var reason of.Reason + if err != nil { + rErr = of.NewGeneralResolutionError(cleanErrorMessage(err.Error())) + reason = of.ErrorReason + } else { + rErr = of.NewFlagNotFoundResolutionError("not found in any provider") + reason = of.DefaultReason + } + + return of.GenericResolutionDetail[R]{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: rErr, + Reason: reason, + FlagMetadata: of.FlagMetadata{MetadataSuccessfulProviderName: "none", MetadataStrategyUsed: strategy}, + }, + } +} + +// Evaluate is a generic method used to resolve a flag from a single [NamedProvider] without losing type information. +// This method is exported for those writing their own custom [StrategyFn]. Since any is an allowed [FlagTypes] this can +// be set to any type, but this should be done with care outside the specified primitive [FlagTypes] +func Evaluate[T FlagTypes](ctx context.Context, provider *NamedProvider, flag string, defaultVal T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { + var resolution of.GenericResolutionDetail[T] + switch v := any(defaultVal).(type) { + case bool: + res := provider.BooleanEvaluation(ctx, flag, v, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = any(res.Value).(T) + case string: + res := provider.StringEvaluation(ctx, flag, v, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = any(res.Value).(T) + case float64: + res := provider.FloatEvaluation(ctx, flag, v, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = any(res.Value).(T) + case int64: + res := provider.IntEvaluation(ctx, flag, v, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = any(res.Value).(T) + default: + res := provider.ObjectEvaluation(ctx, flag, defaultVal, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value.(T) + } + + if resolution.FlagMetadata == nil { + resolution.FlagMetadata = make(of.FlagMetadata, 2) + } + + resolution.FlagMetadata[MetadataProviderName] = provider.Name + resolution.FlagMetadata[MetadataProviderType] = provider.Metadata().Name + + return resolution +} diff --git a/openfeature/multi/strategies_test.go b/openfeature/multi/strategies_test.go new file mode 100644 index 00000000..0b31c843 --- /dev/null +++ b/openfeature/multi/strategies_test.go @@ -0,0 +1,24 @@ +package multi + +import ( + of "github.com/open-feature/go-sdk/openfeature" + "go.uber.org/mock/gomock" +) + +func createMockProviders(ctrl *gomock.Controller, count int) []*of.MockFeatureProvider { + providerMocks := make([]*of.MockFeatureProvider, 0, count) + for range count { + provider := of.NewMockFeatureProvider(ctrl) + providerMocks = append(providerMocks, provider) + } + + return providerMocks +} + +const testFlag = "test-flag" + +const ( + TestErrorNone = 0 + TestErrorNotFound = 1 + TestErrorError = 2 +) diff --git a/openfeature/openfeature_test.go b/openfeature/openfeature_test.go index 65e7ba5d..868063c2 100644 --- a/openfeature/openfeature_test.go +++ b/openfeature/openfeature_test.go @@ -3,7 +3,6 @@ package openfeature import ( "context" "errors" - "fmt" "reflect" "testing" "time" @@ -811,7 +810,6 @@ func TestLateBindingOfDefaultProvider(t *testing.T) { if strResult != expectedResultFromLateDefaultProvider { t.Errorf("expected %s, but got %s", expectedResultFromLateDefaultProvider, strResult) } - fmt.Println(strResult) } // Nil providers are not accepted for default and named providers