diff --git a/providers/multi-provider/Makefile b/providers/multi-provider/Makefile index 941e647cb..fea2a850f 100644 --- a/providers/multi-provider/Makefile +++ b/providers/multi-provider/Makefile @@ -1,10 +1,14 @@ -.PHONY: generate test +.PHONY: generate test lint GOPATH_LOC = ${GOPATH} +OF_SDK_DIR := $(shell go list -f '{{.Dir}}' github.com/open-feature/go-sdk/openfeature) + generate: go generate ./... - go mod download - mockgen -source=${GOPATH}/pkg/mod/github.com/open-feature/go-sdk@v1.13.1/openfeature/provider.go -package=mocks -destination=./internal/mocks/provider_mock.go + +lint: + go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.5 + ${GOPATH}/bin/golangci-lint run ./... test: go test ./... \ No newline at end of file diff --git a/providers/multi-provider/README.md b/providers/multi-provider/README.md index b2bab36e1..4d9488e10 100644 --- a/providers/multi-provider/README.md +++ b/providers/multi-provider/README.md @@ -42,7 +42,14 @@ openfeature.SetProvider(provider) - `WithTimeout` - the duration is used for the total timeout across parallel operations. If none is set it will default to 5 seconds. This is not supported for `FirstMatch` yet, which executes sequentially - `WithFallbackProvider` - Used for setting a fallback provider for the `Comparison` strategy -- `WithLogger` - Provides slog support +- `WithLogger` - Provides slog support using the specified logger +- `WithLoggerDefault` - Default setting. Uses the slog default logger +- `WithoutLogging` - Disables internal logging of the multiprovider +- `WithCustomStrategy` - Allows for passing in an instance of a custom `Strategy` implementation. Must be used in +conjunction with the `StrategyCustom` `EvaluationStrategy` parameter. +- `WithGlobalHooks` - Sets any hooks that should be executed globally across all internal providers. For hooks targeting +specific providers they should either be attached directly to the provider or use `WithProviderHooks` +- `WithProviderHooks` - Sets any hooks that should be executed only for a specific named provider # Strategies @@ -77,8 +84,12 @@ returned. If a provider returns `FLAG_NOT_FOUND` that is not included in the com return not found then the default value is returned. Finally, if any provider returns an error other than `FLAG_NOT_FOUND` the evaluation immediately stops and that error result is returned. This strategy does NOT support `ObjectEvaluation` +## Custom + +Users can opt to write their own strategy by implementing the interface if they have needs that the three built-in +strategies cannot meet. When setting the `StrategyCustom` strategy make sure to pass in an instance of your `Strategy` +implementation using the `WithCustomStrategy` option. + # Not Yet Implemented -- Hooks support -- Event support - Full slog support \ No newline at end of file diff --git a/providers/multi-provider/internal/logger/logger.go b/providers/multi-provider/internal/logger/logger.go new file mode 100644 index 000000000..88c6a8888 --- /dev/null +++ b/providers/multi-provider/internal/logger/logger.go @@ -0,0 +1,62 @@ +package logger + +import ( + "context" + "log/slog" +) + +// ConditionalLogger Logger instance that may be empty so the caller does not need to worry about checking +// if logging is enabled or not. This type should be treated as immutable +type ConditionalLogger struct { + l *slog.Logger +} + +// NewConditionalLogger Creates a new ConditionalLogger. If a nil value is provided no logging will be performed and all +// methods will act as no-ops. The state of a ConditionalLogger should be treated as immutable +func NewConditionalLogger(l *slog.Logger) *ConditionalLogger { + return &ConditionalLogger{l} +} + +// enabled Checks to determine if logging should be performed. Also acts as an internal nil check +func (cl *ConditionalLogger) enabled() bool { + return cl.l != nil +} + +// LogError Log a message at the error level +func (cl *ConditionalLogger) LogError(ctx context.Context, msg string, attr ...slog.Attr) { + if cl.enabled() { + cl.l.LogAttrs(ctx, slog.LevelError, msg, attr...) + } +} + +// LogWarn Log a message at the warn level +func (cl *ConditionalLogger) LogWarn(ctx context.Context, msg string, attr ...slog.Attr) { + if cl.enabled() { + cl.l.LogAttrs(ctx, slog.LevelWarn, msg, attr...) + } +} + +// LogInfo Log a message at the info level (should be used sparingly) +func (cl *ConditionalLogger) LogInfo(ctx context.Context, msg string, attr ...slog.Attr) { + if cl.enabled() { + cl.l.LogAttrs(ctx, slog.LevelInfo, msg, attr...) + } +} + +// LogDebug Log a message at the debug level +func (cl *ConditionalLogger) LogDebug(ctx context.Context, msg string, attr ...slog.Attr) { + if cl.enabled() { + cl.l.LogAttrs(ctx, slog.LevelDebug, msg, attr...) + } +} + +// With Creates and returns a child logger with the provided attributes set. If the current logger is disabled by having +// the same disabled logger will be returned and this acts as a no-op. +func (cl *ConditionalLogger) With(attr ...any) *ConditionalLogger { + if cl.enabled() { + return &ConditionalLogger{l: cl.l.With(attr...)} + } + + // Don't bother creating a child logger since there's no difference + return cl +} diff --git a/providers/multi-provider/internal/mocks/mocks.go b/providers/multi-provider/internal/mocks/mocks.go new file mode 100644 index 000000000..5bbbec202 --- /dev/null +++ b/providers/multi-provider/internal/mocks/mocks.go @@ -0,0 +1,2 @@ +//go:generate go run go.uber.org/mock/mockgen -destination=../../internal/mocks/openfeature_mocks.go -package=mocks "github.com/open-feature/go-sdk/openfeature" FeatureProvider,Hook,StateHandler,EventHandler +package mocks diff --git a/providers/multi-provider/internal/mocks/provider_mock.go b/providers/multi-provider/internal/mocks/openfeature_mocks.go similarity index 73% rename from providers/multi-provider/internal/mocks/provider_mock.go rename to providers/multi-provider/internal/mocks/openfeature_mocks.go index 2f0675dd3..c4a65c8a7 100644 --- a/providers/multi-provider/internal/mocks/provider_mock.go +++ b/providers/multi-provider/internal/mocks/openfeature_mocks.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /Users/jordanblacker/go/pkg/mod/github.com/open-feature/go-sdk@v1.13.1/openfeature/provider.go +// Source: github.com/open-feature/go-sdk/openfeature (interfaces: FeatureProvider,Hook,StateHandler,EventHandler) // // Generated by this command: // -// mockgen -source=/Users/jordanblacker/go/pkg/mod/github.com/open-feature/go-sdk@v1.13.1/openfeature/provider.go -package=mocks -destination=./internal/mocks/provider_mock.go +// mockgen -destination=../../internal/mocks/openfeature_mocks.go -package=mocks github.com/open-feature/go-sdk/openfeature FeatureProvider,Hook,StateHandler,EventHandler // // Package mocks is a generated GoMock package. @@ -139,6 +139,83 @@ func (mr *MockFeatureProviderMockRecorder) StringEvaluation(ctx, flag, defaultVa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringEvaluation", reflect.TypeOf((*MockFeatureProvider)(nil).StringEvaluation), ctx, flag, defaultValue, evalCtx) } +// MockHook is a mock of Hook interface. +type MockHook struct { + ctrl *gomock.Controller + recorder *MockHookMockRecorder + isgomock struct{} +} + +// MockHookMockRecorder is the mock recorder for MockHook. +type MockHookMockRecorder struct { + mock *MockHook +} + +// NewMockHook creates a new mock instance. +func NewMockHook(ctrl *gomock.Controller) *MockHook { + mock := &MockHook{ctrl: ctrl} + mock.recorder = &MockHookMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHook) EXPECT() *MockHookMockRecorder { + return m.recorder +} + +// After mocks base method. +func (m *MockHook) After(ctx context.Context, hookContext openfeature.HookContext, flagEvaluationDetails openfeature.InterfaceEvaluationDetails, hookHints openfeature.HookHints) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "After", ctx, hookContext, flagEvaluationDetails, hookHints) + ret0, _ := ret[0].(error) + return ret0 +} + +// After indicates an expected call of After. +func (mr *MockHookMockRecorder) After(ctx, hookContext, flagEvaluationDetails, hookHints any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "After", reflect.TypeOf((*MockHook)(nil).After), ctx, hookContext, flagEvaluationDetails, hookHints) +} + +// Before mocks base method. +func (m *MockHook) Before(ctx context.Context, hookContext openfeature.HookContext, hookHints openfeature.HookHints) (*openfeature.EvaluationContext, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Before", ctx, hookContext, hookHints) + ret0, _ := ret[0].(*openfeature.EvaluationContext) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Before indicates an expected call of Before. +func (mr *MockHookMockRecorder) Before(ctx, hookContext, hookHints any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Before", reflect.TypeOf((*MockHook)(nil).Before), ctx, hookContext, hookHints) +} + +// Error mocks base method. +func (m *MockHook) Error(ctx context.Context, hookContext openfeature.HookContext, err error, hookHints openfeature.HookHints) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", ctx, hookContext, err, hookHints) +} + +// Error indicates an expected call of Error. +func (mr *MockHookMockRecorder) Error(ctx, hookContext, err, hookHints any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockHook)(nil).Error), ctx, hookContext, err, hookHints) +} + +// Finally mocks base method. +func (m *MockHook) Finally(ctx context.Context, hookContext openfeature.HookContext, hookHints openfeature.HookHints) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Finally", ctx, hookContext, hookHints) +} + +// Finally indicates an expected call of Finally. +func (mr *MockHookMockRecorder) Finally(ctx, hookContext, hookHints any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finally", reflect.TypeOf((*MockHook)(nil).Finally), ctx, hookContext, hookHints) +} + // MockStateHandler is a mock of StateHandler interface. type MockStateHandler struct { ctrl *gomock.Controller diff --git a/providers/multi-provider/internal/wrappers/hook_isolator.go b/providers/multi-provider/internal/wrappers/hook_isolator.go new file mode 100644 index 000000000..5ab83772d --- /dev/null +++ b/providers/multi-provider/internal/wrappers/hook_isolator.go @@ -0,0 +1,342 @@ +package wrappers + +import ( + "context" + "fmt" + of "github.com/open-feature/go-sdk/openfeature" + "slices" + "sync" +) + +type ( + // HookIsolator used as a wrapper around a provider that prevents context changes from leaking between providers + // during evaluation + HookIsolator struct { + mu sync.Mutex + of.FeatureProvider + hooks []of.Hook + capturedContext of.HookContext + capturedHints of.HookHints + } + + // EventHandlingHookIsolator is equivalent to HookIsolator, but also implements [of.EventHandler] + EventHandlingHookIsolator struct { + HookIsolator + } +) + +var ( + _ of.FeatureProvider = (*HookIsolator)(nil) + _ of.Hook = (*HookIsolator)(nil) + _ of.EventHandler = (*EventHandlingHookIsolator)(nil) +) + +func IsolateProvider(provider of.FeatureProvider, extraHooks []of.Hook) *HookIsolator { + return &HookIsolator{ + FeatureProvider: provider, + hooks: slices.Concat(provider.Hooks(), extraHooks), + } +} + +func IsolateProviderWithEvents(provider of.FeatureProvider, extraHooks []of.Hook) *EventHandlingHookIsolator { + return &EventHandlingHookIsolator{*IsolateProvider(provider, extraHooks)} +} + +func (h *EventHandlingHookIsolator) EventChannel() <-chan of.Event { + return h.FeatureProvider.(of.EventHandler).EventChannel() +} + +func (h *HookIsolator) Before(ctx context.Context, hookContext of.HookContext, hookHints of.HookHints) (*of.EvaluationContext, error) { + // Used for capturing the context and hints + h.mu.Lock() + defer h.mu.Unlock() + h.capturedContext = hookContext + h.capturedHints = hookHints + // Return copy of original evaluation context so any changes are isolated to each provider's hooks + evalCtx := h.capturedContext.EvaluationContext() + return &evalCtx, nil +} + +func (h *HookIsolator) After(ctx context.Context, hookContext of.HookContext, flagEvaluationDetails of.InterfaceEvaluationDetails, hookHints of.HookHints) error { + // Purposely left as a no-op + return nil +} + +func (h *HookIsolator) Error(ctx context.Context, hookContext of.HookContext, err error, hookHints of.HookHints) { + // Purposely left as a no-op +} + +func (h *HookIsolator) Finally(ctx context.Context, hookContext of.HookContext, hookHints of.HookHints) { + // Purposely left as a no-op +} + +func (h *HookIsolator) Metadata() of.Metadata { + return h.FeatureProvider.Metadata() +} + +func (h *HookIsolator) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, evalCtx of.FlattenedContext) of.BoolResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Boolean, defaultValue, evalCtx) + + return of.BoolResolutionDetail{ + Value: completeEval.Value.(bool), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *HookIsolator) StringEvaluation(ctx context.Context, flag string, defaultValue string, evalCtx of.FlattenedContext) of.StringResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.String, defaultValue, evalCtx) + + return of.StringResolutionDetail{ + Value: completeEval.Value.(string), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *HookIsolator) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, evalCtx of.FlattenedContext) of.FloatResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Float, defaultValue, evalCtx) + + return of.FloatResolutionDetail{ + Value: completeEval.Value.(float64), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *HookIsolator) IntEvaluation(ctx context.Context, flag string, defaultValue int64, evalCtx of.FlattenedContext) of.IntResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Int, defaultValue, evalCtx) + + return of.IntResolutionDetail{ + Value: completeEval.Value.(int64), + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func (h *HookIsolator) ObjectEvaluation(ctx context.Context, flag string, defaultValue interface{}, evalCtx of.FlattenedContext) of.InterfaceResolutionDetail { + completeEval := h.evaluate(ctx, flag, of.Object, defaultValue, evalCtx) + + return of.InterfaceResolutionDetail{ + Value: completeEval.Value, + ProviderResolutionDetail: toProviderResolutionDetail(completeEval), + } +} + +func toProviderResolutionDetail(evalDetails of.InterfaceEvaluationDetails) of.ProviderResolutionDetail { + var resolutionErr of.ResolutionError + var reason of.Reason + switch evalDetails.ErrorCode { + case of.GeneralCode: + resolutionErr = of.NewGeneralResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.FlagNotFoundCode: + resolutionErr = of.NewFlagNotFoundResolutionError(evalDetails.ErrorMessage) + reason = of.DefaultReason + case of.TargetingKeyMissingCode: + resolutionErr = of.NewTargetingKeyMissingResolutionError(evalDetails.ErrorMessage) + reason = of.TargetingMatchReason + case of.TypeMismatchCode: + resolutionErr = of.NewTypeMismatchResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.ParseErrorCode: + resolutionErr = of.NewParseErrorResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + case of.InvalidContextCode: + resolutionErr = of.NewInvalidContextResolutionError(evalDetails.ErrorMessage) + reason = of.ErrorReason + } + return of.ProviderResolutionDetail{ + ResolutionError: resolutionErr, + Reason: reason, + Variant: evalDetails.Variant, + FlagMetadata: evalDetails.FlagMetadata, + } +} + +func (h *HookIsolator) Hooks() []of.Hook { + // return self as hook to capture contexts + return []of.Hook{h} +} + +func (h *HookIsolator) evaluate(ctx context.Context, flag string, flagType of.Type, defaultValue interface{}, flatCtx of.FlattenedContext) of.InterfaceEvaluationDetails { + evalDetails := of.InterfaceEvaluationDetails{ + Value: defaultValue, + EvaluationDetails: of.EvaluationDetails{ + FlagKey: flag, + FlagType: flagType, + }, + } + + defer func() { + h.finallyHooks(ctx) + }() + + evalCtx, err := h.beforeHooks(ctx) + // Update hook context unconditionally + h.updateEvalContext(evalCtx) + if err != nil { + //h.logger.Error( + // err, "before hook", "flag", flag, "defaultValue", defaultValue, + // "evaluationContext", flatCtx, "evaluationOptions", options, "type", flagType.String(), + //) + err = fmt.Errorf("before hook: %w", err) + h.errorHooks(ctx, err) + evalDetails.ResolutionDetail = of.ResolutionDetail{ + Reason: of.ErrorReason, + ErrorCode: of.GeneralCode, + ErrorMessage: err.Error(), + FlagMetadata: nil, + } + return evalDetails + } + + // Merge together the passed in flat context and the captured evaluation context and transform back into a flattened + // context for evaluation + flatCtx = flattenContext(mergeContexts(h.capturedContext.EvaluationContext(), deepenContext(flatCtx))) + + var resolution of.InterfaceResolutionDetail + switch flagType { + case of.Object: + resolution = h.FeatureProvider.ObjectEvaluation(ctx, flag, defaultValue, flatCtx) + case of.Boolean: + defValue := defaultValue.(bool) + res := h.FeatureProvider.BooleanEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.String: + defValue := defaultValue.(string) + res := h.FeatureProvider.StringEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.Float: + defValue := defaultValue.(float64) + res := h.FeatureProvider.FloatEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + case of.Int: + defValue := defaultValue.(int64) + res := h.FeatureProvider.IntEvaluation(ctx, flag, defValue, flatCtx) + resolution.ProviderResolutionDetail = res.ProviderResolutionDetail + resolution.Value = res.Value + } + + err = resolution.Error() + if err != nil { + //h.logger.Error( + // err, "flag resolution", "flag", flag, "defaultValue", defaultValue, + // "evaluationContext", flatCtx, "evaluationOptions", options, "type", flagType.String(), "errorCode", err, + // "errMessage", resolution.ResolutionError.message, + //) + err = fmt.Errorf("error code: %w", err) + h.errorHooks(ctx, err) + evalDetails.ResolutionDetail = resolution.ResolutionDetail() + evalDetails.Reason = of.ErrorReason + return evalDetails + } + evalDetails.Value = resolution.Value + evalDetails.ResolutionDetail = resolution.ResolutionDetail() + + if err := h.afterHooks(ctx, evalDetails); err != nil { + //h.logger.Error( + // err, "after hook", "flag", flag, "defaultValue", defaultValue, + // "evaluationContext", flatCtx, "evaluationOptions", options, "type", flagType.String(), + //) + err = fmt.Errorf("after hook: %w", err) + h.errorHooks(ctx, err) + return evalDetails + } + + return evalDetails +} + +func (h *HookIsolator) beforeHooks(ctx context.Context) (of.EvaluationContext, error) { + contexts := []of.EvaluationContext{h.capturedContext.EvaluationContext()} + for _, hook := range h.hooks { + resultEvalCtx, err := hook.Before(ctx, h.capturedContext, h.capturedHints) + if resultEvalCtx != nil { + contexts = append(contexts, *resultEvalCtx) + } + if err != nil { + return mergeContexts(contexts...), err + } + } + + return mergeContexts(contexts...), nil +} + +func (h *HookIsolator) afterHooks(ctx context.Context, evalDetails of.InterfaceEvaluationDetails) error { + for _, hook := range h.hooks { + if err := hook.After(ctx, h.capturedContext, evalDetails, h.capturedHints); err != nil { + return err + } + } + + return nil +} + +func (h *HookIsolator) errorHooks(ctx context.Context, err error) { + for _, hook := range h.hooks { + hook.Error(ctx, h.capturedContext, err, h.capturedHints) + } +} + +func (h *HookIsolator) finallyHooks(ctx context.Context) { + for _, hook := range h.hooks { + hook.Finally(ctx, h.capturedContext, h.capturedHints) + } +} + +// updateEvalContext Returns a new [of.HookContext] with an updated [of.EvaluationContext] value. [of.HookContext] is +// immutable, and this returns a new [of.HookContext] with all other values copied +func (h *HookIsolator) updateEvalContext(evalCtx of.EvaluationContext) { + hookCtx := of.NewHookContext( + h.capturedContext.FlagKey(), + h.capturedContext.FlagType(), + h.capturedContext.DefaultValue(), + h.capturedContext.ClientMetadata(), + h.capturedContext.ProviderMetadata(), + evalCtx, + ) + h.mu.Lock() + defer h.mu.Unlock() + h.capturedContext = hookCtx +} + +func deepenContext(flatCtx of.FlattenedContext) of.EvaluationContext { + noTargetingKey := make(map[string]interface{}) + for k, v := range flatCtx { + if k != "targetingKey" { + noTargetingKey[k] = v + } + } + return of.NewEvaluationContext(flatCtx["targetingKey"].(string), noTargetingKey) +} + +func flattenContext(evalCtx of.EvaluationContext) of.FlattenedContext { + flatCtx := evalCtx.Attributes() + flatCtx["targetingKey"] = evalCtx.TargetingKey() + return flatCtx +} + +// merges attributes from the given EvaluationContexts with the nth EvaluationContext taking precedence in case +// of any conflicts with the (n+1)th EvaluationContext +func mergeContexts(evaluationContexts ...of.EvaluationContext) of.EvaluationContext { + if len(evaluationContexts) == 0 { + return of.EvaluationContext{} + } + // create initial values + attributes := evaluationContexts[0].Attributes() + targetingKey := evaluationContexts[0].TargetingKey() + + for i := 1; i < len(evaluationContexts); i++ { + if targetingKey == "" && evaluationContexts[i].TargetingKey() != "" { + targetingKey = evaluationContexts[i].TargetingKey() + } + + for k, v := range evaluationContexts[i].Attributes() { + _, ok := attributes[k] + if !ok { + attributes[k] = v + } + } + } + + return of.NewEvaluationContext(targetingKey, attributes) +} diff --git a/providers/multi-provider/internal/wrappers/hook_isolator_test.go b/providers/multi-provider/internal/wrappers/hook_isolator_test.go new file mode 100644 index 000000000..cad52374d --- /dev/null +++ b/providers/multi-provider/internal/wrappers/hook_isolator_test.go @@ -0,0 +1,104 @@ +package wrappers + +import ( + "context" + "errors" + "github.com/open-feature/go-sdk-contrib/providers/multi-provider/internal/mocks" + of "github.com/open-feature/go-sdk/openfeature" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "testing" +) + +func Test_HookIsolator_BeforeCapturesData(t *testing.T) { + hookCtx := of.NewHookContext( + "test-key", + of.Boolean, + false, + of.ClientMetadata{}, + of.Metadata{}, + of.NewEvaluationContext("target", map[string]interface{}{}), + ) + require.NotZero(t, hookCtx) + hookHints := of.NewHookHints(map[string]interface{}{"foo": "bar"}) + require.NotZero(t, hookHints) + ctrl := gomock.NewController(t) + provider := mocks.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + isolator := IsolateProvider(provider, []of.Hook{}) + assert.Zero(t, isolator.capturedContext) + assert.Zero(t, isolator.capturedHints) + evalCtx, err := isolator.Before(context.Background(), hookCtx, hookHints) + require.NoError(t, err) + assert.NotNil(t, evalCtx) + assert.Equal(t, hookCtx, isolator.capturedContext) + assert.Equal(t, hookHints, isolator.capturedHints) +} + +func Test_HookIsolator_Hooks_ReturnsSelf(t *testing.T) { + ctrl := gomock.NewController(t) + provider := mocks.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + isolator := IsolateProvider(provider, []of.Hook{}) + hooks := isolator.Hooks() + assert.NotEmpty(t, hooks) + assert.Same(t, isolator, hooks[0]) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_NoError(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := mocks.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + provider := mocks.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(of.BoolResolutionDetail{ + Value: true, + ProviderResolutionDetail: of.ProviderResolutionDetail{}, + }) + + isolator := IsolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.True(t, result.Value) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_BeforeErrorAbortsExecution(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := mocks.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test error")) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + + provider := mocks.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + + isolator := IsolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.False(t, result.Value) +} + +func Test_HookIsolator_ExecutesHooksDuringEvaluation_WithAfterError(t *testing.T) { + ctrl := gomock.NewController(t) + testHook := mocks.NewMockHook(ctrl) + testHook.EXPECT().Before(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + testHook.EXPECT().After(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("test error")) + testHook.EXPECT().Finally(gomock.Any(), gomock.Any(), gomock.Any()) + testHook.EXPECT().Error(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + + provider := mocks.NewMockFeatureProvider(ctrl) + provider.EXPECT().Hooks().Return([]of.Hook{testHook}) + provider.EXPECT().BooleanEvaluation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(of.BoolResolutionDetail{ + Value: false, + ProviderResolutionDetail: of.ProviderResolutionDetail{}, + }) + + isolator := IsolateProvider(provider, nil) + result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) + assert.False(t, result.Value) +} diff --git a/providers/multi-provider/pkg/options.go b/providers/multi-provider/pkg/options.go index bbc38154b..1017a529c 100644 --- a/providers/multi-provider/pkg/options.go +++ b/providers/multi-provider/pkg/options.go @@ -14,6 +14,21 @@ func WithLogger(l *slog.Logger) Option { } } +// WithLoggerDefault Uses the default [slog.Logger] (this is the default setting) +// use WithoutLogging to disable logging completely +func WithLoggerDefault() Option { + return func(conf *Configuration) { + conf.logger = slog.Default() + } +} + +// WithoutLogging Disables logging functionality +func WithoutLogging() Option { + return func(conf *Configuration) { + conf.logger = nil + } +} + // WithTimeout Set a timeout for the total runtime for evaluation of parallel strategies func WithTimeout(d time.Duration) Option { return func(conf *Configuration) { @@ -36,16 +51,21 @@ func WithCustomStrategy(s strategies.Strategy) Option { } } -// WithEventPublishing Enables event publishing (Not Yet Implemented) -func WithEventPublishing() Option { +// WithGlobalHooks sets the global hooks for the provider. These are hooks that affect ALL providers. For hooks that +// target specific providers make sure to attach them to that provider directly, or use the WithProviderHook Option if +// that provider does not provide its own hook functionality +func WithGlobalHooks(hooks ...of.Hook) Option { return func(conf *Configuration) { - conf.publishEvents = true + conf.hooks = hooks } } -// WithoutEventPublishing Disables event publishing (this is the default, but included for explicit usage) -func WithoutEventPublishing() Option { +// WithProviderHooks sets hooks that execute only for a specific provider. The providerName must match the unique provider +// name set during MultiProvider creation. This should only be used if you need hooks that execute around a specific +// provider, but that provider does not currently accept a way to set hooks. This option can be used multiple times using +// unique provider names. Using a provider name that is not known will cause an error. +func WithProviderHooks(providerName string, hooks ...of.Hook) Option { return func(conf *Configuration) { - conf.publishEvents = false + conf.providerHooks[providerName] = hooks } } diff --git a/providers/multi-provider/pkg/providers.go b/providers/multi-provider/pkg/providers.go index c40703981..63e94b030 100644 --- a/providers/multi-provider/pkg/providers.go +++ b/providers/multi-provider/pkg/providers.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "github.com/open-feature/go-sdk-contrib/providers/multi-provider/internal/logger" + "github.com/open-feature/go-sdk-contrib/providers/multi-provider/internal/wrappers" "github.com/open-feature/go-sdk-contrib/providers/multi-provider/pkg/strategies" "golang.org/x/sync/errgroup" "log/slog" @@ -20,13 +22,20 @@ import ( type ( // MultiProvider Provider used for combining multiple providers MultiProvider struct { - providers ProviderMap - metadata of.Metadata - events chan of.Event - status of.State - mu sync.RWMutex - strategy strategies.Strategy - logger *slog.Logger + providers ProviderMap + metadata of.Metadata + initialized bool + totalStatus of.State + totalStatusLock sync.RWMutex + providerStatus map[string]of.State + providerStatusLock sync.Mutex + strategy strategies.Strategy + logger *logger.ConditionalLogger + outboundEvents chan of.Event + inboundEvents chan namedEvent + workerGroup sync.WaitGroup + shutdownFunc context.CancelFunc + globalHooks []of.Hook } // Configuration MultiProvider's internal configuration @@ -35,10 +44,9 @@ type ( fallbackProvider of.FeatureProvider customStrategy strategies.Strategy logger *slog.Logger - publishEvents bool - metadata *of.Metadata timeout time.Duration - hooks []of.Hook // Not implemented yet + hooks []of.Hook + providerHooks map[string][]of.Hook } // EvaluationStrategy Defines a strategy to use for resolving the result from multiple providers @@ -47,6 +55,12 @@ type ( ProviderMap map[string]of.FeatureProvider // Option Function used for setting Configuration via the options pattern Option func(*Configuration) + + // Private Types + namedEvent struct { + of.Event + providerName string + } ) const ( @@ -62,9 +76,42 @@ const ( // StrategyCustom allows for using a custom Strategy implementation. If this is set you MUST use the WithCustomStrategy // option to set it StrategyCustom EvaluationStrategy = "strategy-custom" + + MetadataProviderName = "multiprovider-provider-name" + MetadataProviderType = "multiprovider-provider-type" + MetadataInternalError = "multiprovider-internal-error" +) + +var ( + _ of.FeatureProvider = (*MultiProvider)(nil) + _ of.EventHandler = (*MultiProvider)(nil) + _ of.StateHandler = (*MultiProvider)(nil) + stateValues map[of.State]int + stateTable [3]of.State + eventTypeToState map[of.EventType]of.State ) -var _ of.FeatureProvider = (*MultiProvider)(nil) +func init() { + // used for mapping provider event types & provider states to comparable values for evaluation + stateValues = map[of.State]int{ + "": -1, // Not a real state, but used for handling provider config changes + of.ErrorState: 0, + of.StaleState: 1, + of.ReadyState: 2, + } + // used for mapping + stateTable = [3]of.State{ + of.ReadyState, // 0 + of.StaleState, // 1 + of.ErrorState, // 2 + } + eventTypeToState = map[of.EventType]of.State{ + of.ProviderConfigChange: "", + of.ProviderReady: of.ReadyState, + of.ProviderStale: of.StaleState, + of.ProviderError: of.ErrorState, + } +} // AsNamedProviderSlice Converts the map into a slice of NamedProvider instances func (m ProviderMap) AsNamedProviderSlice() []*strategies.NamedProvider { @@ -103,8 +150,20 @@ func NewMultiProvider(providerMap ProviderMap, evaluationStrategy EvaluationStra if len(providerMap) == 0 { return nil, errors.New("providerMap cannot be nil or empty") } - // Validate Providers + + config := &Configuration{ + logger: slog.Default(), // Logging enabled by default using default slog logger + providerHooks: make(map[string][]of.Hook), + } + + for _, opt := range options { + opt(config) + } + + providers := providerMap + collectedHooks := make([]of.Hook, 0, len(providerMap)) for name, provider := range providerMap { + // Validate Providers if name == "" { return nil, errors.New("provider name cannot be the empty string") } @@ -112,31 +171,31 @@ func NewMultiProvider(providerMap ProviderMap, evaluationStrategy EvaluationStra if provider == nil { return nil, fmt.Errorf("provider %s cannot be nil", name) } - } - - config := &Configuration{ - logger: slog.Default(), - } - for _, opt := range options { - opt(config) - } + // Wrap any providers that include hooks + if (len(provider.Hooks()) + len(config.providerHooks[name])) == 0 { + continue + } - var eventChannel chan of.Event - if config.publishEvents { - eventChannel = make(chan of.Event) - } + var wrappedProvider of.FeatureProvider + if _, ok := provider.(of.EventHandler); ok { + wrappedProvider = wrappers.IsolateProviderWithEvents(provider, config.providerHooks[name]) + } else { + wrappedProvider = wrappers.IsolateProvider(provider, config.providerHooks[name]) + } - logger := config.logger - if logger == nil { - logger = slog.Default() + providers[name] = wrappedProvider + collectedHooks = append(collectedHooks, wrappedProvider.Hooks()...) } multiProvider := &MultiProvider{ - providers: providerMap, - events: eventChannel, - logger: logger, - metadata: providerMap.buildMetadata(), + providers: providers, + outboundEvents: make(chan of.Event), + logger: logger.NewConditionalLogger(config.logger), + metadata: providerMap.buildMetadata(), + totalStatus: of.NotReadyState, + providerStatus: make(map[string]of.State), + globalHooks: slices.Concat(config.hooks, collectedHooks), } var zeroDuration time.Duration @@ -220,47 +279,184 @@ func (mp *MultiProvider) ObjectEvaluation(ctx context.Context, flag string, defa // Init will run the initialize method for all of provides and aggregate the errors. func (mp *MultiProvider) Init(evalCtx of.EvaluationContext) error { var eg errgroup.Group - + // wrapper type used only for initialization of event listener workers + type namedEventHandler struct { + of.EventHandler + name string + } + mp.logger.LogDebug(context.Background(), "start initialization") + mp.inboundEvents = make(chan namedEvent, len(mp.providers)) + handlers := make(chan namedEventHandler) for name, provider := range mp.providers { + // Initialize each provider to not ready state. No locks required there are no workers running + mp.providerStatus[name] = of.NotReadyState + l := mp.logger.With(slog.String("multiprovider-provider-name", name)) + eg.Go(func() error { + l.LogDebug(context.Background(), "starting initialization") stateHandle, ok := provider.(of.StateHandler) if !ok { - return nil - } - if err := stateHandle.Init(evalCtx); err != nil { + l.LogDebug(context.Background(), "StateHandle not implemented, skipping initialization") + } else if err := stateHandle.Init(evalCtx); err != nil { + l.LogError(context.Background(), "initialization failed", slog.Any("error", err)) return &mperr.ProviderError{ Err: err, ProviderName: name, } } - + l.LogDebug(context.Background(), "initialization successful") + if eventer, ok := provider.(of.EventHandler); ok { + l.LogDebug(context.Background(), "detected EventHandler implementation") + handlers <- namedEventHandler{eventer, name} + } else { + // Do not yet update providers that need event handling + mp.providerStatusLock.Lock() + defer mp.providerStatusLock.Unlock() + mp.providerStatus[name] = of.ReadyState + } return nil }) } if err := eg.Wait(); err != nil { - mp.mu.Lock() - defer mp.mu.Unlock() - mp.status = of.ErrorState - + mp.setStatus(of.ErrorState) + var pErr *mperr.ProviderError + if errors.As(err, &pErr) { + // Update provider status to error, no event needs to be emitted. + // No locks needed as no workers are active at this point + mp.providerStatus[pErr.ProviderName] = of.ErrorState + } else { + pErr = &mperr.ProviderError{ + Err: err, + ProviderName: "unknown", + } + } + mp.outboundEvents <- of.Event{ + ProviderName: mp.Metadata().Name, + EventType: of.ProviderError, + ProviderEventDetails: of.ProviderEventDetails{ + Message: fmt.Sprintf("internal provider %s encountered an error during initialization: %+v", pErr.ProviderName, pErr.Err), + FlagChanges: nil, + EventMetadata: map[string]interface{}{ + MetadataProviderName: pErr.ProviderName, + MetadataInternalError: pErr.Error(), + }, + }, + } return err } - - mp.mu.Lock() - defer mp.mu.Unlock() - mp.status = of.ReadyState + close(handlers) + workerCtx, shutdownFunc := context.WithCancel(context.Background()) + for h := range handlers { + go mp.startListening(workerCtx, h.name, h.EventHandler, &mp.workerGroup) + } + mp.shutdownFunc = shutdownFunc + + go func() { + workerLogger := mp.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker")) + mp.workerGroup.Add(1) + defer mp.workerGroup.Done() + for e := range mp.inboundEvents { + l := workerLogger.With( + slog.String("multiprovider-provider-name", e.providerName), + slog.String("multiprovider-provider-type", e.ProviderName), + ) + l.LogDebug(context.Background(), fmt.Sprintf("received %s event from provider", e.EventType)) + state := mp.updateProviderStateAndEvaluateTotalState(e, l) + if state != mp.Status() { + mp.setStatus(state) + mp.outboundEvents <- e.Event + l.LogDebug(context.Background(), "forwarded state update event") + } else { + l.LogDebug(context.Background(), "total state not updated, inbound event will not be emitted") + } + } + }() + + mp.setStatus(of.ReadyState) + mp.outboundEvents <- of.Event{ + ProviderName: mp.Metadata().Name, + EventType: of.ProviderReady, + ProviderEventDetails: of.ProviderEventDetails{ + Message: "all internal providers initialized successfully", + FlagChanges: nil, + EventMetadata: map[string]interface{}{ + MetadataProviderName: "all", + }, + }, + } + mp.initialized = true return nil } -// Status the current status of the MultiProvider -func (mp *MultiProvider) Status() of.State { - mp.mu.RLock() - defer mp.mu.RUnlock() - return mp.status +// startListening is intended to be +func (mp *MultiProvider) startListening(ctx context.Context, name string, h of.EventHandler, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() + for { + select { + case e := <-h.EventChannel(): + e.EventMetadata[MetadataProviderName] = name + e.EventMetadata[MetadataProviderType] = h.(of.FeatureProvider).Metadata().Name + mp.inboundEvents <- namedEvent{ + Event: e, + providerName: name, + } + case <-ctx.Done(): + return + } + } +} + +func (mp *MultiProvider) updateProviderStateAndEvaluateTotalState(e namedEvent, l *logger.ConditionalLogger) of.State { + if e.EventType == of.ProviderConfigChange { + l.LogDebug(context.Background(), fmt.Sprintf("ProviderConfigChange event: %s", e.Message)) + return mp.Status() + } + mp.providerStatusLock.Lock() + defer mp.providerStatusLock.Unlock() + logProviderState(l, e, mp.providerStatus[e.providerName]) + mp.providerStatus[e.providerName] = eventTypeToState[e.EventType] + maxState := stateValues[of.ReadyState] // initialize to the lowest state value + for _, s := range mp.providerStatus { + if stateValues[s] > maxState { + // change in state due to higher priority + maxState = stateValues[s] + } + } + return stateTable[maxState] +} + +func logProviderState(l *logger.ConditionalLogger, e namedEvent, previousState of.State) { + switch eventTypeToState[e.EventType] { + case of.ReadyState: + if previousState != of.NotReadyState { + l.LogInfo(context.Background(), fmt.Sprintf("provider %s has returned to ready state from %s", e.providerName, previousState)) + return + } + l.LogDebug(context.Background(), fmt.Sprintf("provider %s is ready", e.providerName)) + case of.StaleState: + l.LogWarn(context.Background(), fmt.Sprintf("provider %s is stale: %s", e.providerName, e.Message)) + case of.ErrorState: + l.LogError(context.Background(), fmt.Sprintf("provider %s is in an error state: %s", e.providerName, e.Message)) + } } // Shutdown Shuts down all internal providers func (mp *MultiProvider) Shutdown() { + if !mp.initialized { + // Don't do anything if we were never initialized + return + } + // Stop all event listener workers, shutdown events should not affect overall state + mp.shutdownFunc() + // Stop forwarding worker + close(mp.inboundEvents) + mp.logger.LogDebug(context.Background(), "triggered worker shutdown") + // Wait for workers to stop + mp.workerGroup.Wait() + mp.logger.LogDebug(context.Background(), "worker shutdown completed") + mp.logger.LogDebug(context.Background(), "starting provider shutdown") var wg sync.WaitGroup for _, provider := range mp.providers { wg.Add(1) @@ -272,10 +468,31 @@ func (mp *MultiProvider) Shutdown() { }(provider) } + mp.logger.LogDebug(context.Background(), "waiting for provider shutdown completion") wg.Wait() + mp.logger.LogDebug(context.Background(), "provider shutdown completed") + mp.setStatus(of.NotReadyState) + close(mp.outboundEvents) + mp.outboundEvents = nil + mp.inboundEvents = nil + mp.initialized = false +} + +// Status the current state of the MultiProvider +func (mp *MultiProvider) Status() of.State { + mp.totalStatusLock.RLock() + defer mp.totalStatusLock.RUnlock() + return mp.totalStatus +} + +func (mp *MultiProvider) setStatus(state of.State) { + mp.totalStatusLock.Lock() + defer mp.totalStatusLock.Unlock() + mp.totalStatus = state + mp.logger.LogDebug(context.Background(), "state updated", slog.String("state", string(state))) } -// EventChannel the channel events are emitted on (Not Yet Implemented) +// EventChannel the channel events are emitted on func (mp *MultiProvider) EventChannel() <-chan of.Event { - return mp.events + return mp.outboundEvents } diff --git a/providers/multi-provider/pkg/providers_test.go b/providers/multi-provider/pkg/providers_test.go index d4efdf02a..d51b22e19 100644 --- a/providers/multi-provider/pkg/providers_test.go +++ b/providers/multi-provider/pkg/providers_test.go @@ -1,6 +1,7 @@ package multiprovider import ( + "context" "errors" "github.com/open-feature/go-sdk-contrib/providers/multi-provider/internal/mocks" "github.com/open-feature/go-sdk-contrib/providers/multi-provider/pkg/strategies" @@ -113,6 +114,7 @@ func TestMultiProvider_MetaData(t *testing.T) { testProvider2.EXPECT().Metadata().Return(of.Metadata{ Name: "MockProvider", }) + testProvider2.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) providers := make(ProviderMap) providers["provider1"] = testProvider1 @@ -131,9 +133,11 @@ func TestMultiProvider_Init(t *testing.T) { testProvider1 := mocks.NewMockFeatureProvider(ctrl) testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) testProvider3 := mocks.NewMockFeatureProvider(ctrl) testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) providers := make(ProviderMap) providers["provider1"] = testProvider1 @@ -147,16 +151,41 @@ func TestMultiProvider_Init(t *testing.T) { "foo": "bar", } evalCtx := openfeature.NewTargetlessEvaluationContext(attributes) - + eventChan := make(chan of.Event) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case e := <-mp.EventChannel(): + eventChan <- e + case <-ctx.Done(): + return + } + }() err = mp.Init(evalCtx) require.NoError(t, err) - assert.Equal(t, of.ReadyState, mp.status) + assert.Equal(t, of.ReadyState, mp.Status()) + cancel() + event := <-eventChan + assert.NotZero(t, event) + assert.Equal(t, mp.Metadata().Name, event.ProviderName) + assert.Equal(t, of.ProviderReady, event.EventType) + assert.Equal(t, of.ProviderEventDetails{ + Message: "all internal providers initialized successfully", + FlagChanges: nil, + EventMetadata: map[string]interface{}{ + MetadataProviderName: "all", + }, + }, event.ProviderEventDetails) + t.Cleanup(func() { + mp.Shutdown() + }) } func TestMultiProvider_InitErrorWithProvider(t *testing.T) { ctrl := gomock.NewController(t) errProvider := mocks.NewMockFeatureProvider(ctrl) errProvider.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + errProvider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) errHandler := mocks.NewMockStateHandler(ctrl) errHandler.EXPECT().Init(gomock.Any()).Return(errors.New("test error")) testProvider3 := struct { @@ -168,6 +197,7 @@ func TestMultiProvider_InitErrorWithProvider(t *testing.T) { } testProvider1 := mocks.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) @@ -183,20 +213,44 @@ func TestMultiProvider_InitErrorWithProvider(t *testing.T) { "foo": "bar", } evalCtx := openfeature.NewTargetlessEvaluationContext(attributes) - + eventChan := make(chan of.Event) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case e := <-mp.EventChannel(): + eventChan <- e + case <-ctx.Done(): + return + } + }() err = mp.Init(evalCtx) - require.Errorf(t, err, "Provider provider1: test error") - assert.Equal(t, of.ErrorState, mp.status) + require.Errorf(t, err, "Provider provider3: test error") + assert.Equal(t, of.ErrorState, mp.totalStatus) + cancel() + event := <-eventChan + assert.NotZero(t, event) + assert.Equal(t, mp.Metadata().Name, event.ProviderName) + assert.Equal(t, of.ProviderError, event.EventType) + assert.Equal(t, of.ProviderEventDetails{ + Message: "internal provider provider3 encountered an error during initialization: test error", + FlagChanges: nil, + EventMetadata: map[string]interface{}{ + MetadataProviderName: "provider3", + MetadataInternalError: "Provider provider3: test error", + }, + }, event.ProviderEventDetails) } -func TestMultiProvider_Shutdown(t *testing.T) { +func TestMultiProvider_Shutdown_WithoutInit(t *testing.T) { ctrl := gomock.NewController(t) testProvider1 := mocks.NewMockFeatureProvider(ctrl) testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) testProvider3 := mocks.NewMockFeatureProvider(ctrl) testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) providers := make(ProviderMap) providers["provider1"] = testProvider1 @@ -207,3 +261,51 @@ func TestMultiProvider_Shutdown(t *testing.T) { mp.Shutdown() } + +func TestMultiProvider_Shutdown_WithInit(t *testing.T) { + ctrl := gomock.NewController(t) + + testProvider1 := mocks.NewMockFeatureProvider(ctrl) + testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + testProvider1.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + handlingProvider := mocks.NewMockFeatureProvider(ctrl) + handlingProvider.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) + handlingProvider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) + handledHandler := mocks.NewMockStateHandler(ctrl) + handledHandler.EXPECT().Init(gomock.Any()).Return(nil) + handledHandler.EXPECT().Shutdown() + testProvider3 := struct { + of.FeatureProvider + of.StateHandler + }{ + handlingProvider, + handledHandler, + } + + providers := make(ProviderMap) + providers["provider1"] = testProvider1 + providers["provider2"] = testProvider2 + providers["provider3"] = testProvider3 + mp, err := NewMultiProvider(providers, strategies.StrategyFirstMatch) + require.NoError(t, err) + evalCtx := openfeature.NewTargetlessEvaluationContext(map[string]interface{}{ + "foo": "bar", + }) + eventChan := make(chan of.Event) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case e := <-mp.EventChannel(): + eventChan <- e + case <-ctx.Done(): + return + } + }() + err = mp.Init(evalCtx) + require.NoError(t, err) + assert.Equal(t, of.ReadyState, mp.Status()) + cancel() + mp.Shutdown() + assert.Equal(t, of.NotReadyState, mp.Status()) +} diff --git a/providers/multi-provider/pkg/strategies/strategies.go b/providers/multi-provider/pkg/strategies/strategies.go index 4670acfc6..f03aeaf0c 100644 --- a/providers/multi-provider/pkg/strategies/strategies.go +++ b/providers/multi-provider/pkg/strategies/strategies.go @@ -1,6 +1,6 @@ // Package strategies Resolution strategies are defined within this package // -//go:generate go run go.uber.org/mock/mockgen -source=strategies.go -destination=../../pkg/strategies/strategy_mock.go -package=strategies +//go:generate go run go.uber.org/mock/mockgen -destination=../../pkg/strategies/strategy_mock.go -package=strategies -write_source_comment=false . Strategy package strategies import ( diff --git a/providers/multi-provider/pkg/strategies/strategy_mock.go b/providers/multi-provider/pkg/strategies/strategy_mock.go index 2bf1ad1a5..2f32b5ed2 100644 --- a/providers/multi-provider/pkg/strategies/strategy_mock.go +++ b/providers/multi-provider/pkg/strategies/strategy_mock.go @@ -1,9 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: strategies.go // // Generated by this command: // -// mockgen -source=strategies.go -destination=../../pkg/strategies/strategy_mock.go -package=strategies +// mockgen -destination=../../pkg/strategies/strategy_mock.go -package=strategies -write_source_comment=false . Strategy // // Package strategies is a generated GoMock package.