From 61e173f252337b531b4515a2e2d9c5b08e318823 Mon Sep 17 00:00:00 2001 From: Dan Pasecinic Date: Sat, 21 Feb 2026 21:48:39 +0100 Subject: [PATCH 1/4] chore: upgrade go version to 1.26.0 (#16) --- .github/workflows/ci.yml | 10 +++++----- README.md | 21 ++++++--------------- autowire.go | 4 ++-- benchmark/go.mod | 2 +- errors.go | 7 +++---- go.mod | 2 +- internal/reflect/types.go | 16 ++++++++-------- internal/reflect/types_test.go | 4 ++-- module.go | 2 +- needle_test.go | 6 ++---- 10 files changed, 31 insertions(+), 43 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1b06a8c..092d27d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,16 +2,16 @@ name: CI on: push: - branches: [ main, development ] + branches: [main, development] pull_request: - branches: [ main, development ] + branches: [main, development] jobs: build: runs-on: ubuntu-latest strategy: matrix: - go-version: [ '1.25' ] + go-version: ["1.26"] steps: - uses: actions/checkout@v4 @@ -32,7 +32,7 @@ jobs: - name: Upload coverage uses: codecov/codecov-action@v4 - if: matrix.go-version == '1.25' + if: matrix.go-version == '1.26' with: files: coverage.out fail_ci_if_error: false @@ -45,7 +45,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.25' + go-version: "1.26" - name: Install golangci-lint run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest diff --git a/README.md b/README.md index 941cefe..e1e2615 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,17 @@ # Needle -A modern, type-safe dependency injection framework for Go 1.25+. +A modern, type-safe dependency injection framework for Go. [![Go Reference](https://pkg.go.dev/badge/github.com/danpasecinic/needle.svg)](https://pkg.go.dev/github.com/danpasecinic/needle) [![Go Report Card](https://goreportcard.com/badge/github.com/danpasecinic/needle)](https://goreportcard.com/report/github.com/danpasecinic/needle) ## Features -- **Type-safe generics** - Compile-time type checking with `Provide[T]` and `Invoke[T]` -- **Auto-wiring** - Constructor injection and struct tag injection -- **Hot reload** - Replace services at runtime without restart -- **Zero dependencies** - Only Go standard library -- **Cycle detection** - Automatically detects circular dependencies -- **Multiple scopes** - Singleton, Transient, Request, Pooled -- **Lifecycle management** - OnStart/OnStop hooks with ordering -- **Lazy providers** - Defer instantiation until first use -- **Parallel startup** - Start independent services concurrently -- **Modules** - Group related providers -- **Interface binding** - Bind interfaces to implementations -- **Decorators** - Wrap services with cross-cutting concerns -- **Health checks** - Liveness and readiness probes -- **Optional dependencies** - Type-safe optional resolution +Needle uses Go generics for compile-time type safety (`Provide[T]`, `Invoke[T]`) and has zero external dependencies. + +It supports constructor auto-wiring, struct tag injection, multiple scopes (singleton, transient, request, pooled), and lifecycle hooks that run in dependency order. Services can start in parallel, be lazily initialized, or be replaced at runtime without restarting the container. + +You can group providers into modules, bind interfaces to implementations, wrap services with decorators, and resolve optional dependencies with a built-in `Optional[T]` type. Health and readiness checks are supported out of the box. ## Installation diff --git a/autowire.go b/autowire.go index bc8bacb..bd0d043 100644 --- a/autowire.go +++ b/autowire.go @@ -18,7 +18,7 @@ func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) { var zero T t := reflectPkg.TypeOf(zero) - isPtr := t.Kind() == reflectPkg.Ptr + isPtr := t.Kind() == reflectPkg.Pointer if isPtr { t = t.Elem() } @@ -100,7 +100,7 @@ func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) e fnVal := reflectPkg.ValueOf(constructor) fnType := fnVal.Type() - hasError := fnType.NumOut() == 2 && fnType.Out(1).Implements(reflectPkg.TypeOf((*error)(nil)).Elem()) + hasError := fnType.NumOut() == 2 && fnType.Out(1).Implements(reflectPkg.TypeFor[error]()) deps := make([]string, len(params)) for i, p := range params { diff --git a/benchmark/go.mod b/benchmark/go.mod index e5d891c..fd1c87b 100644 --- a/benchmark/go.mod +++ b/benchmark/go.mod @@ -1,6 +1,6 @@ module benchmark -go 1.25.4 +go 1.26.0 replace github.com/danpasecinic/needle => ../ diff --git a/errors.go b/errors.go index 2c53c8d..ef93c6b 100644 --- a/errors.go +++ b/errors.go @@ -65,10 +65,10 @@ type Error struct { func (e *Error) Error() string { var b strings.Builder - b.WriteString(fmt.Sprintf("[%s]", e.Code)) + fmt.Fprintf(&b, "[%s]", e.Code) if e.Service != "" { - b.WriteString(fmt.Sprintf(" service=%q:", e.Service)) + fmt.Fprintf(&b, " service=%q:", e.Service) } b.WriteString(" ") @@ -87,8 +87,7 @@ func (e *Error) Unwrap() error { } func (e *Error) Is(target error) bool { - var t *Error - if errors.As(target, &t) { + if t, ok := errors.AsType[*Error](target); ok { return e.Code == t.Code } return false diff --git a/go.mod b/go.mod index 09a173a..79917c0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/danpasecinic/needle -go 1.25.4 +go 1.26.0 diff --git a/internal/reflect/types.go b/internal/reflect/types.go index 8411de3..a16ac03 100644 --- a/internal/reflect/types.go +++ b/internal/reflect/types.go @@ -12,7 +12,7 @@ func TypeKey[T any]() string { var zero T t := reflect.TypeOf(zero) if t == nil { - t = reflect.TypeOf((*T)(nil)).Elem() + t = reflect.TypeFor[T]() } return typeKeyFromReflect(t) } @@ -76,7 +76,7 @@ func TypeKeyNamed[T any](name string) string { var zero T t := reflect.TypeOf(zero) if t == nil { - t = reflect.TypeOf((*T)(nil)).Elem() + t = reflect.TypeFor[T]() } key := namedKey{t: t, name: name} @@ -100,7 +100,7 @@ func IsNil(v any) bool { rv := reflect.ValueOf(v) switch rv.Kind() { - case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func: + case reflect.Pointer, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func: return rv.IsNil() default: return false @@ -111,13 +111,13 @@ func TypeName[T any]() string { var zero T t := reflect.TypeOf(zero) if t == nil { - t = reflect.TypeOf((*T)(nil)).Elem() + t = reflect.TypeFor[T]() } return t.String() } func IsInterface[T any]() bool { - t := reflect.TypeOf((*T)(nil)).Elem() + t := reflect.TypeFor[T]() return t.Kind() == reflect.Interface } @@ -125,7 +125,7 @@ func Implements[T any](v any) bool { if v == nil { return false } - t := reflect.TypeOf((*T)(nil)).Elem() + t := reflect.TypeFor[T]() return reflect.TypeOf(v).Implements(t) } @@ -138,8 +138,8 @@ type FieldInfo struct { } func StructFields[T any](tagKey string) ([]FieldInfo, error) { - t := reflect.TypeOf((*T)(nil)).Elem() - if t.Kind() == reflect.Ptr { + t := reflect.TypeFor[T]() + if t.Kind() == reflect.Pointer { t = t.Elem() } if t.Kind() != reflect.Struct { diff --git a/internal/reflect/types_test.go b/internal/reflect/types_test.go index cf186cf..1e3afac 100644 --- a/internal/reflect/types_test.go +++ b/internal/reflect/types_test.go @@ -212,14 +212,14 @@ func TestImplements(t *testing.T) { func BenchmarkTypeKey(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = TypeKey[*testStruct]() } } func BenchmarkTypeKeyNamed(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = TypeKeyNamed[*testStruct]("primary") } } diff --git a/module.go b/module.go index a1fdfc5..6e44976 100644 --- a/module.go +++ b/module.go @@ -178,7 +178,7 @@ func errModuleApplyFailed(moduleName string, cause error) *Error { ) } -func errModuleInvalidProvider(provider any) *Error { +func errModuleInvalidProvider(_ any) *Error { return newError( ErrCodeModuleInvalidProvider, "invalid provider type in module", diff --git a/needle_test.go b/needle_test.go index 3aa9a13..7429bd2 100644 --- a/needle_test.go +++ b/needle_test.go @@ -369,9 +369,8 @@ func BenchmarkProvideAndInvoke(b *testing.B) { c := needle.New() _ = needle.ProvideValue(c, &Config{Port: 8080}) - b.ResetTimer() b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _, _ = needle.Invoke[*Config](c) } } @@ -380,9 +379,8 @@ func BenchmarkMustInvoke(b *testing.B) { c := needle.New() _ = needle.ProvideValue(c, &Config{Port: 8080}) - b.ResetTimer() b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = needle.MustInvoke[*Config](c) } } From 8b752b731881f704e8d2e942864503143707a556 Mon Sep 17 00:00:00 2001 From: Dan Pasecinic Date: Sat, 21 Feb 2026 22:11:40 +0100 Subject: [PATCH 2/4] fix: enhance & address minor issues (#17) --- autowire.go | 42 +++++--- container.go | 12 +-- errors.go | 7 +- errors_test.go | 53 +++++++++++ internal/container/container.go | 6 +- internal/container/container_test.go | 62 ++++++++++++ internal/container/lifecycle.go | 19 ++-- internal/container/lifecycle_test.go | 137 +++++++++++++++++++++++++++ internal/container/registry.go | 30 ++++++ internal/container/replace.go | 30 +++--- internal/container/resolve.go | 75 +++++++++------ internal/graph/cycle.go | 31 ++++-- internal/graph/graph_test.go | 56 +++++++++++ internal/reflect/types.go | 5 +- internal/reflect/types_test.go | 19 ++++ observability.go | 65 ++++--------- replace.go | 63 +----------- 17 files changed, 513 insertions(+), 199 deletions(-) create mode 100644 errors_test.go create mode 100644 internal/container/lifecycle_test.go diff --git a/autowire.go b/autowire.go index bd0d043..5a43039 100644 --- a/autowire.go +++ b/autowire.go @@ -82,19 +82,19 @@ func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) { return structVal.Interface().(T), nil } -func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) error { +func buildFuncProvider[T any](c *Container, constructor any) (Provider[T], []ProviderOption, error) { params, returnType, err := reflect.FuncParams(constructor) if err != nil { - return err + return nil, nil, err } if returnType == nil { - return fmt.Errorf("constructor must return at least one value") + return nil, nil, fmt.Errorf("constructor must return at least one value") } expectedType := reflectPkg.TypeOf((*T)(nil)).Elem() if !returnType.AssignableTo(expectedType) { - return fmt.Errorf("constructor returns %s, expected %s", returnType, expectedType) + return nil, nil, fmt.Errorf("constructor returns %s, expected %s", returnType, expectedType) } fnVal := reflectPkg.ValueOf(constructor) @@ -128,17 +128,10 @@ func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) e return results[0].Interface().(T), nil } - opts = append([]ProviderOption{WithDependencies(deps...)}, opts...) - return Provide(c, provider, opts...) -} - -func MustProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) { - if err := ProvideFunc[T](c, constructor, opts...); err != nil { - panic(err) - } + return provider, []ProviderOption{WithDependencies(deps...)}, nil } -func ProvideStruct[T any](c *Container, opts ...ProviderOption) error { +func buildStructProvider[T any](c *Container) (Provider[T], []ProviderOption) { provider := func(ctx context.Context, r Resolver) (T, error) { return InvokeStructCtx[T](ctx, c) } @@ -155,7 +148,28 @@ func ProvideStruct[T any](c *Container, opts ...ProviderOption) error { } } - opts = append([]ProviderOption{WithDependencies(deps...)}, opts...) + return provider, []ProviderOption{WithDependencies(deps...)} +} + +func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) error { + provider, depOpts, err := buildFuncProvider[T](c, constructor) + if err != nil { + return err + } + + opts = append(depOpts, opts...) + return Provide(c, provider, opts...) +} + +func MustProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) { + if err := ProvideFunc[T](c, constructor, opts...); err != nil { + panic(err) + } +} + +func ProvideStruct[T any](c *Container, opts ...ProviderOption) error { + provider, depOpts := buildStructProvider[T](c) + opts = append(depOpts, opts...) return Provide(c, provider, opts...) } diff --git a/container.go b/container.go index 504db81..8902d98 100644 --- a/container.go +++ b/container.go @@ -42,20 +42,16 @@ func newContainer(opts ...Option) *Container { } for _, h := range cfg.onResolve { - hook := h - internalCfg.OnResolve = append(internalCfg.OnResolve, container.ResolveHook(hook)) + internalCfg.OnResolve = append(internalCfg.OnResolve, container.ResolveHook(h)) } for _, h := range cfg.onProvide { - hook := h - internalCfg.OnProvide = append(internalCfg.OnProvide, container.ProvideHook(hook)) + internalCfg.OnProvide = append(internalCfg.OnProvide, container.ProvideHook(h)) } for _, h := range cfg.onStart { - hook := h - internalCfg.OnStart = append(internalCfg.OnStart, container.StartHook(hook)) + internalCfg.OnStart = append(internalCfg.OnStart, container.StartHook(h)) } for _, h := range cfg.onStop { - hook := h - internalCfg.OnStop = append(internalCfg.OnStop, container.StopHook(hook)) + internalCfg.OnStop = append(internalCfg.OnStop, container.StopHook(h)) } c := &Container{ diff --git a/errors.go b/errors.go index ef93c6b..903cc61 100644 --- a/errors.go +++ b/errors.go @@ -87,10 +87,11 @@ func (e *Error) Unwrap() error { } func (e *Error) Is(target error) bool { - if t, ok := errors.AsType[*Error](target); ok { - return e.Code == t.Code + t, ok := target.(*Error) + if !ok { + return false } - return false + return e.Code == t.Code } func (e *Error) WithService(service string) *Error { diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..c19cce1 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,53 @@ +package needle + +import ( + "errors" + "fmt" + "testing" +) + +func TestError_Is_SameCode(t *testing.T) { + t.Parallel() + + err1 := newError(ErrCodeServiceNotFound, "service A not found", nil) + err2 := newError(ErrCodeServiceNotFound, "service B not found", nil) + + if !errors.Is(err1, err2) { + t.Error("errors with same code should match via Is") + } +} + +func TestError_Is_DifferentCode(t *testing.T) { + t.Parallel() + + err1 := newError(ErrCodeServiceNotFound, "not found", nil) + err2 := newError(ErrCodeCircularDependency, "cycle", nil) + + if errors.Is(err1, err2) { + t.Error("errors with different codes should not match via Is") + } +} + +func TestError_Is_DoesNotTraverseTargetChain(t *testing.T) { + t.Parallel() + + inner := newError(ErrCodeServiceNotFound, "inner", nil) + wrapper := fmt.Errorf("wrapped: %w", inner) + check := newError(ErrCodeServiceNotFound, "check", nil) + + if errors.Is(check, wrapper) { + t.Error("Is should not traverse target's chain, only direct type assertion on target") + } +} + +func TestError_Is_WrappedSource(t *testing.T) { + t.Parallel() + + inner := newError(ErrCodeServiceNotFound, "inner", nil) + wrapper := fmt.Errorf("wrapped: %w", inner) + target := newError(ErrCodeServiceNotFound, "target", nil) + + if !errors.Is(wrapper, target) { + t.Error("errors.Is should find inner *Error via Unwrap chain of source") + } +} diff --git a/internal/container/container.go b/internal/container/container.go index 1fd8726..665f56b 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -30,9 +30,6 @@ type Container struct { logger *slog.Logger state State - resolving map[string]bool - resolvingMu sync.Mutex - decorators map[string][]DecoratorFunc decoratorsMu sync.RWMutex @@ -68,7 +65,6 @@ func New(cfg *Config) *Container { registry: NewRegistry(), graph: graph.New(), logger: logger, - resolving: make(map[string]bool), decorators: make(map[string][]DecoratorFunc), onResolve: cfg.OnResolve, onProvide: cfg.OnProvide, @@ -89,7 +85,7 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s c.registry.RegisterUnsafe(key, provider, dependencies) c.graph.AddNodeUnsafe(key, dependencies) - if len(dependencies) > 0 && c.graph.HasCycle() { + if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { c.registry.RemoveUnsafe(key) c.graph.RemoveNodeUnsafe(key) c.mu.Unlock() diff --git a/internal/container/container_test.go b/internal/container/container_test.go index d1b8333..06d2de5 100644 --- a/internal/container/container_test.go +++ b/internal/container/container_test.go @@ -3,6 +3,8 @@ package container import ( "context" "errors" + "sync" + "sync/atomic" "testing" ) @@ -297,6 +299,66 @@ func TestContainer_ContextCancellation(t *testing.T) { } } +func TestContainer_ConcurrentResolve_NoFalseCycle(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + _ = c.RegisterValue("dep", "dependency") + _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + _, _ = r.Resolve(ctx, "dep") + return "service", nil + }, []string{"dep"}) + + var wg sync.WaitGroup + errs := make(chan error, 50) + + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := c.Resolve(context.Background(), "svc") + if err != nil { + errs <- err + } + }() + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("unexpected error during concurrent resolve: %v", err) + } +} + +func TestContainer_SingletonCalledOnce(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + var callCount atomic.Int64 + _ = c.Register("singleton", func(ctx context.Context, r Resolver) (any, error) { + callCount.Add(1) + return "instance", nil + }, nil) + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = c.Resolve(context.Background(), "singleton") + }() + } + + wg.Wait() + + if count := callCount.Load(); count != 1 { + t.Errorf("singleton provider called %d times, expected 1", count) + } +} + func BenchmarkContainer_Resolve(b *testing.B) { c := New(&Config{}) diff --git a/internal/container/lifecycle.go b/internal/container/lifecycle.go index 1dfb74c..4b8714b 100644 --- a/internal/container/lifecycle.go +++ b/internal/container/lifecycle.go @@ -2,6 +2,7 @@ package container import ( "context" + "errors" "fmt" "sync" "time" @@ -110,13 +111,9 @@ func (c *Container) startService(ctx context.Context, key string) error { return fmt.Errorf("failed to resolve %s during startup: %w", key, err) } - entry, exists := c.registry.GetEntry(key) - if !exists { - return nil - } - var startErr error - for _, hook := range entry.OnStart { + hooks := c.registry.GetOnStartHooks(key) + for _, hook := range hooks { c.logger.Debug("running OnStart hook", "service", key) if err := hook(ctx); err != nil { startErr = fmt.Errorf("OnStart hook failed for %s: %w", key, err) @@ -240,15 +237,17 @@ func (c *Container) stopService(ctx context.Context, key string) error { } start := time.Now() - var stopErr error + var errs []error - for i := len(entry.OnStop) - 1; i >= 0; i-- { + hooks := c.registry.GetOnStopHooks(key) + for i := len(hooks) - 1; i >= 0; i-- { c.logger.Debug("running OnStop hook", "service", key) - if err := entry.OnStop[i](ctx); err != nil { - stopErr = fmt.Errorf("OnStop hook failed for %s: %w", key, err) + if err := hooks[i](ctx); err != nil { + errs = append(errs, fmt.Errorf("OnStop hook failed for %s: %w", key, err)) } } + stopErr := errors.Join(errs...) c.callStopHooks(key, time.Since(start), stopErr) return stopErr } diff --git a/internal/container/lifecycle_test.go b/internal/container/lifecycle_test.go new file mode 100644 index 0000000..4322980 --- /dev/null +++ b/internal/container/lifecycle_test.go @@ -0,0 +1,137 @@ +package container + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" +) + +func TestStopService_CollectsAllErrors(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + err1 := errors.New("hook1 failed") + err2 := errors.New("hook2 failed") + + _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, nil) + + c.registry.AddOnStop("svc", func(ctx context.Context) error { + return err1 + }) + c.registry.AddOnStop("svc", func(ctx context.Context) error { + return err2 + }) + + ctx := context.Background() + _, _ = c.Resolve(ctx, "svc") + + stopErr := c.stopService(ctx, "svc") + if stopErr == nil { + t.Fatal("expected error from stopService") + } + + msg := stopErr.Error() + if !strings.Contains(msg, "hook1 failed") { + t.Errorf("expected error to contain 'hook1 failed', got: %s", msg) + } + if !strings.Contains(msg, "hook2 failed") { + t.Errorf("expected error to contain 'hook2 failed', got: %s", msg) + } +} + +func TestStopService_NoErrorWhenHooksSucceed(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, nil) + + c.registry.AddOnStop("svc", func(ctx context.Context) error { + return nil + }) + + ctx := context.Background() + _, _ = c.Resolve(ctx, "svc") + + stopErr := c.stopService(ctx, "svc") + if stopErr != nil { + t.Errorf("expected no error, got: %v", stopErr) + } +} + +func TestStartAndStop_Integration(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + var order []string + + _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, nil) + + c.registry.AddOnStart("svc", func(ctx context.Context) error { + order = append(order, "started") + return nil + }) + c.registry.AddOnStop("svc", func(ctx context.Context) error { + order = append(order, "stopped") + return nil + }) + + ctx := context.Background() + if err := c.Start(ctx); err != nil { + t.Fatalf("Start failed: %v", err) + } + + if len(order) != 1 || order[0] != "started" { + t.Errorf("expected [started], got %v", order) + } + + if err := c.Stop(ctx); err != nil { + t.Fatalf("Stop failed: %v", err) + } + + if len(order) != 2 || order[1] != "stopped" { + t.Errorf("expected [started, stopped], got %v", order) + } +} + +func TestStopService_MultipleFailingHooks_BothPresent(t *testing.T) { + t.Parallel() + + c := New(&Config{}) + + _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, nil) + + c.registry.AddOnStop("svc", func(ctx context.Context) error { + return fmt.Errorf("first error") + }) + c.registry.AddOnStop("svc", func(ctx context.Context) error { + return fmt.Errorf("second error") + }) + + ctx := context.Background() + _, _ = c.Resolve(ctx, "svc") + + stopErr := c.stopService(ctx, "svc") + if stopErr == nil { + t.Fatal("expected combined error") + } + + if !strings.Contains(stopErr.Error(), "first error") { + t.Error("missing first error") + } + if !strings.Contains(stopErr.Error(), "second error") { + t.Error("missing second error") + } +} diff --git a/internal/container/registry.go b/internal/container/registry.go index 50f05c8..5ad86e5 100644 --- a/internal/container/registry.go +++ b/internal/container/registry.go @@ -29,6 +29,8 @@ type ServiceEntry struct { pool chan any Lazy bool StartRan bool + once sync.Once + initErr error } type Registry struct { @@ -310,6 +312,34 @@ func (r *Registry) IsLazy(key string) bool { return false } +func (r *Registry) GetOnStartHooks(key string) []Hook { + r.mu.RLock() + defer r.mu.RUnlock() + + entry, exists := r.services[key] + if !exists { + return nil + } + + hooks := make([]Hook, len(entry.OnStart)) + copy(hooks, entry.OnStart) + return hooks +} + +func (r *Registry) GetOnStopHooks(key string) []Hook { + r.mu.RLock() + defer r.mu.RUnlock() + + entry, exists := r.services[key] + if !exists { + return nil + } + + hooks := make([]Hook, len(entry.OnStop)) + copy(hooks, entry.OnStop) + return hooks +} + func (r *Registry) SetStartRan(key string) { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/container/replace.go b/internal/container/replace.go index b32de28..7c0573f 100644 --- a/internal/container/replace.go +++ b/internal/container/replace.go @@ -6,19 +6,16 @@ func (c *Container) Replace(key string, provider ProviderFunc, dependencies []st c.mu.Lock() defer c.mu.Unlock() - c.registry.Remove(key) - c.graph.RemoveNode(key) + c.registry.RemoveUnsafe(key) + c.graph.RemoveNodeUnsafe(key) - if err := c.registry.Register(key, provider, dependencies); err != nil { - return err - } - - c.graph.AddNode(key, dependencies) + c.registry.RegisterUnsafe(key, provider, dependencies) + c.graph.AddNodeUnsafe(key, dependencies) - if c.graph.HasCycle() { - c.registry.Remove(key) - c.graph.RemoveNode(key) - cyclePath := c.graph.FindCyclePath(key) + if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { + c.registry.RemoveUnsafe(key) + c.graph.RemoveNodeUnsafe(key) + cyclePath := c.graph.FindCyclePathUnsafe(key) return fmt.Errorf("circular dependency detected: %v", cyclePath) } @@ -29,13 +26,10 @@ func (c *Container) ReplaceValue(key string, value any) error { c.mu.Lock() defer c.mu.Unlock() - c.registry.Remove(key) - c.graph.RemoveNode(key) - - if err := c.registry.RegisterValue(key, value); err != nil { - return err - } + c.registry.RemoveUnsafe(key) + c.graph.RemoveNodeUnsafe(key) - c.graph.AddNode(key, nil) + c.registry.RegisterValueUnsafe(key, value) + c.graph.AddNodeUnsafe(key, nil) return nil } diff --git a/internal/container/resolve.go b/internal/container/resolve.go index 17d0fda..a790238 100644 --- a/internal/container/resolve.go +++ b/internal/container/resolve.go @@ -9,6 +9,24 @@ import ( "github.com/danpasecinic/needle/internal/scope" ) +type resolvingKey struct{} + +func withResolving(ctx context.Context, key string) (context.Context, bool) { + set, _ := ctx.Value(resolvingKey{}).(map[string]bool) + if set != nil && set[key] { + return ctx, false + } + if set == nil { + set = make(map[string]bool) + } + next := make(map[string]bool, len(set)+1) + for k := range set { + next[k] = true + } + next[key] = true + return context.WithValue(ctx, resolvingKey{}, next), true +} + func (c *Container) Resolve(ctx context.Context, key string) (any, error) { if len(c.onResolve) == 0 { if instance, ok := c.registry.GetInstanceFast(key); ok { @@ -22,21 +40,12 @@ func (c *Container) Resolve(ctx context.Context, key string) (any, error) { func (c *Container) resolveSlow(ctx context.Context, key string) (any, error) { start := time.Now() - c.resolvingMu.Lock() - if c.resolving[key] { - c.resolvingMu.Unlock() + ctx, ok := withResolving(ctx, key) + if !ok { err := fmt.Errorf("circular resolution detected for: %s", key) c.callResolveHooks(key, time.Since(start), err) return nil, err } - c.resolving[key] = true - c.resolvingMu.Unlock() - - defer func() { - c.resolvingMu.Lock() - delete(c.resolving, key) - c.resolvingMu.Unlock() - }() c.mu.RLock() entry, exists := c.registry.Get(key) @@ -75,27 +84,36 @@ func (c *Container) resolveWithScope(ctx context.Context, key string, entry *Ser } func (c *Container) resolveSingleton(ctx context.Context, key string, entry *ServiceEntry) (any, error) { - if entry.Instantiated { + if entry.Provider == nil { return entry.Instance, nil } - for _, dep := range entry.Dependencies { - if _, err := c.Resolve(ctx, dep); err != nil { - return nil, fmt.Errorf("failed to resolve dependency %s for %s: %w", dep, key, err) + entry.once.Do(func() { + for _, dep := range entry.Dependencies { + if _, err := c.Resolve(ctx, dep); err != nil { + entry.initErr = fmt.Errorf("failed to resolve dependency %s for %s: %w", dep, key, err) + return + } } - } - instance, err := entry.Provider(ctx, c) - if err != nil { - return nil, fmt.Errorf("provider failed for %s: %w", key, err) - } + inst, err := entry.Provider(ctx, c) + if err != nil { + entry.initErr = fmt.Errorf("provider failed for %s: %w", key, err) + return + } - instance, err = c.applyDecorators(ctx, key, instance) - if err != nil { - return nil, err - } + inst, err = c.applyDecorators(ctx, key, inst) + if err != nil { + entry.initErr = err + return + } - c.registry.SetInstance(key, instance) + c.registry.SetInstance(key, inst) + }) + + if entry.initErr != nil { + return nil, entry.initErr + } if entry.Lazy && !entry.StartRan && c.state == StateRunning { if err := c.runLazyStart(ctx, key, entry); err != nil { @@ -103,14 +121,15 @@ func (c *Container) resolveSingleton(ctx context.Context, key string, entry *Ser } } - return instance, nil + return entry.Instance, nil } -func (c *Container) runLazyStart(ctx context.Context, key string, entry *ServiceEntry) error { +func (c *Container) runLazyStart(ctx context.Context, key string, _ *ServiceEntry) error { start := time.Now() var startErr error - for _, hook := range entry.OnStart { + hooks := c.registry.GetOnStartHooks(key) + for _, hook := range hooks { c.logger.Debug("running lazy OnStart hook", "service", key) if err := hook(ctx); err != nil { startErr = fmt.Errorf("OnStart hook failed for %s: %w", key, err) diff --git a/internal/graph/cycle.go b/internal/graph/cycle.go index b662469..470cd24 100644 --- a/internal/graph/cycle.go +++ b/internal/graph/cycle.go @@ -14,6 +14,10 @@ func (g *Graph) DetectCycles() [][]string { g.mu.RLock() defer g.mu.RUnlock() + return g.detectCyclesUnsafe() +} + +func (g *Graph) detectCyclesUnsafe() [][]string { detector := &CycleDetector{ graph: g, index: 0, @@ -105,6 +109,15 @@ func (g *Graph) HasCycle() bool { return g.hasCycle } +func (g *Graph) HasCycleUnsafe() bool { + if g.cycleValid { + return g.hasCycle + } + g.hasCycle = g.hasCycleUnsafe() + g.cycleValid = true + return g.hasCycle +} + func (g *Graph) hasCycleUnsafe() bool { white := make(map[string]bool, len(g.nodes)) gray := make(map[string]bool, len(g.nodes)) @@ -113,7 +126,6 @@ func (g *Graph) hasCycleUnsafe() bool { white[id] = true } - var hasCycle bool var dfs func(id string) bool dfs = func(id string) bool { white[id] = false @@ -138,19 +150,26 @@ func (g *Graph) hasCycleUnsafe() bool { for id := range g.nodes { if white[id] { if dfs(id) { - hasCycle = true - break + return true } } } - return hasCycle + return false } func (g *Graph) FindCyclePath(start string) []string { g.mu.RLock() defer g.mu.RUnlock() + return g.findCyclePathUnsafe(start) +} + +func (g *Graph) FindCyclePathUnsafe(start string) []string { + return g.findCyclePathUnsafe(start) +} + +func (g *Graph) findCyclePathUnsafe(start string) []string { visited := make(map[string]bool) path := make([]string, 0) inPath := make(map[string]bool) @@ -201,7 +220,7 @@ func (g *Graph) GetAllCyclePaths() [][]string { g.mu.RLock() defer g.mu.RUnlock() - cycles := g.DetectCycles() + cycles := g.detectCyclesUnsafe() if len(cycles) == 0 { return nil } @@ -209,7 +228,7 @@ func (g *Graph) GetAllCyclePaths() [][]string { var allPaths [][]string for _, scc := range cycles { if len(scc) > 0 { - path := g.FindCyclePath(scc[0]) + path := g.findCyclePathUnsafe(scc[0]) if path != nil { allPaths = append(allPaths, path) } diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go index 6b23ce7..71f03d9 100644 --- a/internal/graph/graph_test.go +++ b/internal/graph/graph_test.go @@ -337,6 +337,62 @@ func TestGraph_ParallelStartupGroups(t *testing.T) { } } +func TestGraph_GetAllCyclePaths(t *testing.T) { + t.Parallel() + + g := New() + g.AddNode("A", []string{"B"}) + g.AddNode("B", []string{"C"}) + g.AddNode("C", []string{"A"}) + + paths := g.GetAllCyclePaths() + if len(paths) == 0 { + t.Fatal("expected at least one cycle path") + } + + path := paths[0] + if path[0] != path[len(path)-1] { + t.Error("cycle path should start and end with same node") + } +} + +func TestGraph_GetAllCyclePaths_NoCycle(t *testing.T) { + t.Parallel() + + g := New() + g.AddNode("A", []string{"B"}) + g.AddNode("B", nil) + + paths := g.GetAllCyclePaths() + if paths != nil { + t.Errorf("expected nil, got %v", paths) + } +} + +func TestGraph_GetAllCyclePaths_ConcurrentNoDeadlock(t *testing.T) { + t.Parallel() + + g := New() + g.AddNode("A", []string{"B"}) + g.AddNode("B", []string{"C"}) + g.AddNode("C", []string{"A"}) + + done := make(chan struct{}) + go func() { + defer close(done) + for range 100 { + g.GetAllCyclePaths() + } + }() + + for range 100 { + g.AddNode("D", nil) + g.RemoveNode("D") + } + + <-done +} + func BenchmarkGraph_DetectCycles(b *testing.B) { g := New() for i := 0; i < 100; i++ { diff --git a/internal/reflect/types.go b/internal/reflect/types.go index a16ac03..4d76841 100644 --- a/internal/reflect/types.go +++ b/internal/reflect/types.go @@ -2,6 +2,7 @@ package reflect import ( "reflect" + "strconv" "sync" ) @@ -38,12 +39,12 @@ func buildTypeKey(t reflect.Type) string { } switch t.Kind() { - case reflect.Ptr: + case reflect.Pointer: return "*" + buildTypeKey(t.Elem()) case reflect.Slice: return "[]" + buildTypeKey(t.Elem()) case reflect.Array: - return "[" + string(rune(t.Len())) + "]" + buildTypeKey(t.Elem()) + return "[" + strconv.Itoa(t.Len()) + "]" + buildTypeKey(t.Elem()) case reflect.Map: return "map[" + buildTypeKey(t.Key()) + "]" + buildTypeKey(t.Elem()) case reflect.Chan: diff --git a/internal/reflect/types_test.go b/internal/reflect/types_test.go index 1e3afac..97dc56b 100644 --- a/internal/reflect/types_test.go +++ b/internal/reflect/types_test.go @@ -73,6 +73,25 @@ func TestTypeKey(t *testing.T) { } } +func TestTypeKeyArray(t *testing.T) { + t.Parallel() + + key := TypeKey[[3]int]() + if key != "[3]int" { + t.Errorf("expected [3]int, got %s", key) + } + + key10 := TypeKey[[10]string]() + if key10 != "[10]string" { + t.Errorf("expected [10]string, got %s", key10) + } + + key256 := TypeKey[[256]byte]() + if key256 != "[256]uint8" { + t.Errorf("expected [256]uint8, got %s", key256) + } +} + func TestTypeKeyUnique(t *testing.T) { t.Parallel() diff --git a/observability.go b/observability.go index 15cdfea..056002a 100644 --- a/observability.go +++ b/observability.go @@ -62,53 +62,24 @@ func (c *Container) Health(ctx context.Context) []HealthReport { } func (c *Container) checkHealth(ctx context.Context) []HealthReport { - keys := c.internal.Keys() - var reports []HealthReport - var mu sync.Mutex - var wg sync.WaitGroup - - for _, key := range keys { - instance, ok := c.internal.GetInstance(key) - if !ok { - continue + return c.runChecks(ctx, func(instance any) func(context.Context) error { + if hc, ok := instance.(HealthChecker); ok { + return hc.HealthCheck } + return nil + }) +} - checker, ok := instance.(HealthChecker) - if !ok { - continue +func (c *Container) checkReadiness(ctx context.Context) []HealthReport { + return c.runChecks(ctx, func(instance any) func(context.Context) error { + if rc, ok := instance.(ReadinessChecker); ok { + return rc.ReadinessCheck } - - wg.Add(1) - go func(k string, hc HealthChecker) { - defer wg.Done() - - start := time.Now() - err := hc.HealthCheck(ctx) - latency := time.Since(start) - - report := HealthReport{ - Name: k, - Latency: latency, - } - - if err != nil { - report.Status = HealthStatusDown - report.Error = err - } else { - report.Status = HealthStatusUp - } - - mu.Lock() - reports = append(reports, report) - mu.Unlock() - }(key, checker) - } - - wg.Wait() - return reports + return nil + }) } -func (c *Container) checkReadiness(ctx context.Context) []HealthReport { +func (c *Container) runChecks(ctx context.Context, extractCheck func(any) func(context.Context) error) []HealthReport { keys := c.internal.Keys() var reports []HealthReport var mu sync.Mutex @@ -120,17 +91,17 @@ func (c *Container) checkReadiness(ctx context.Context) []HealthReport { continue } - checker, ok := instance.(ReadinessChecker) - if !ok { + check := extractCheck(instance) + if check == nil { continue } wg.Add(1) - go func(k string, rc ReadinessChecker) { + go func(k string, fn func(context.Context) error) { defer wg.Done() start := time.Now() - err := rc.ReadinessCheck(ctx) + err := fn(ctx) latency := time.Since(start) report := HealthReport{ @@ -148,7 +119,7 @@ func (c *Container) checkReadiness(ctx context.Context) []HealthReport { mu.Lock() reports = append(reports, report) mu.Unlock() - }(key, checker) + }(key, check) } wg.Wait() diff --git a/replace.go b/replace.go index 38a0d10..0af4102 100644 --- a/replace.go +++ b/replace.go @@ -2,8 +2,6 @@ package needle import ( "context" - "fmt" - reflectPkg "reflect" "github.com/danpasecinic/needle/internal/container" "github.com/danpasecinic/needle/internal/reflect" @@ -20,8 +18,8 @@ func Replace[T any](c *Container, provider Provider[T], opts ...ProviderOption) key = reflect.TypeKeyNamed[T](cfg.name) } + resolver := c.resolver wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - resolver := &resolverAdapter{container: c} return provider(ctx, resolver) } @@ -97,48 +95,12 @@ func MustReplaceValue[T any](c *Container, value T, opts ...ProviderOption) { } func ReplaceFunc[T any](c *Container, constructor any, opts ...ProviderOption) error { - params, returnType, err := reflect.FuncParams(constructor) + provider, depOpts, err := buildFuncProvider[T](c, constructor) if err != nil { return err } - if returnType == nil { - return fmt.Errorf("constructor must return at least one value") - } - - fnVal := reflectPkg.ValueOf(constructor) - fnType := fnVal.Type() - - hasError := fnType.NumOut() == 2 && - fnType.Out(1).Implements(reflectPkg.TypeOf((*error)(nil)).Elem()) - - deps := make([]string, len(params)) - for i, p := range params { - deps[i] = p.TypeKey - } - - provider := func(ctx context.Context, r Resolver) (T, error) { - var zero T - - args := make([]reflectPkg.Value, len(params)) - for i, p := range params { - instance, err := c.internal.Resolve(ctx, p.TypeKey) - if err != nil { - return zero, fmt.Errorf("failed to resolve parameter %d (%s): %w", i, p.TypeKey, err) - } - args[i] = reflectPkg.ValueOf(instance) - } - - results := fnVal.Call(args) - - if hasError && len(results) == 2 && !results[1].IsNil() { - return zero, results[1].Interface().(error) - } - - return results[0].Interface().(T), nil - } - - opts = append([]ProviderOption{WithDependencies(deps...)}, opts...) + opts = append(depOpts, opts...) return Replace(c, provider, opts...) } @@ -149,23 +111,8 @@ func MustReplaceFunc[T any](c *Container, constructor any, opts ...ProviderOptio } func ReplaceStruct[T any](c *Container, opts ...ProviderOption) error { - provider := func(ctx context.Context, r Resolver) (T, error) { - return InvokeStructCtx[T](ctx, c) - } - - fields, _ := reflect.StructFields[T](TagKey) - deps := make([]string, 0, len(fields)) - for _, f := range fields { - if !f.Optional { - if f.Named != "" { - deps = append(deps, f.TypeKey+"#"+f.Named) - } else { - deps = append(deps, f.TypeKey) - } - } - } - - opts = append([]ProviderOption{WithDependencies(deps...)}, opts...) + provider, depOpts := buildStructProvider[T](c) + opts = append(depOpts, opts...) return Replace(c, provider, opts...) } From f365af44b18decfd800e8ec8d34270bcd3ef5654 Mon Sep 17 00:00:00 2001 From: Dan Pasecinic Date: Sat, 7 Mar 2026 19:33:44 +0100 Subject: [PATCH 3/4] refactor: adjust error handling, improve locking & logging (#19) --- README.md | 46 +++++++ concurrent_test.go | 218 ++++++++++++++++++++++++++++++++ doc.go | 11 +- errors.go | 30 +---- examples/optional/main.go | 25 ++-- internal/container/container.go | 40 ++++-- internal/container/registry.go | 1 + internal/container/replace.go | 10 +- needle_test.go | 85 +++++++++++-- resolver.go | 48 +++---- 10 files changed, 418 insertions(+), 96 deletions(-) create mode 100644 concurrent_test.go diff --git a/README.md b/README.md index e1e2615..749e33f 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,52 @@ See the [examples](examples/) directory: - [optional](examples/optional/) - Optional dependencies with fallbacks - [parallel](examples/parallel/) - Parallel startup/shutdown +## Choosing a Scope + +| Scope | Lifetime | Use When | +|-------|----------|----------| +| **Singleton** (default) | One instance for the container lifetime | Stateful services: DB pools, config, caches, loggers | +| **Transient** | New instance every resolution | Stateless handlers, commands, lightweight value objects | +| **Request** | One instance per `WithRequestScope(ctx)` | Per-HTTP-request state: request loggers, auth context, transaction managers | +| **Pooled** | Reusable instances from a fixed-size pool | Expensive-to-create, stateless-between-uses resources: gRPC connections, worker objects | + +```go +needle.Provide(c, NewService) // Singleton (default) +needle.Provide(c, NewHandler, needle.WithScope(needle.Transient)) +needle.Provide(c, NewRequestLogger, needle.WithScope(needle.Request)) +needle.Provide(c, NewWorker, needle.WithPoolSize(10)) // Pooled with 10 slots +``` + +Pooled services must be released by the caller via `c.Release(key, instance)`. If the pool is full, the instance is dropped and a warning is logged. + +## Replacing Services + +Replace services at runtime without restarting the container. Useful for feature flags, A/B testing, test doubles, or configuration updates. + +```go +// Replace with a new value +needle.ReplaceValue(c, &Config{Port: 9090}) + +// Replace with a new provider +needle.Replace(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { + return &Server{Config: needle.MustInvoke[*Config](c)}, nil +}) + +// Replace with auto-wired constructor +needle.ReplaceFunc[*Service](c, NewService) + +// Replace with struct injection +needle.ReplaceStruct[*Service](c) + +// Named variants +needle.ReplaceNamedValue(c, "primary", &Config{Port: 5432}) +needle.ReplaceNamed(c, "primary", provider) +``` + +All Replace functions accept the same options as Provide (`WithScope`, `WithOnStart`, `WithOnStop`, `WithLazy`, `WithPoolSize`). If the service does not exist yet, Replace creates it. If it does exist, the old entry is removed from both the registry and the dependency graph before re-registering. + +`Must` variants (`MustReplace`, `MustReplaceValue`, `MustReplaceFunc`, `MustReplaceStruct`) panic on error. + ## Benchmarks Needle wins benchmark categories against uber/fx, samber/do, and uber/dig. diff --git a/concurrent_test.go b/concurrent_test.go new file mode 100644 index 0000000..9bbd0ba --- /dev/null +++ b/concurrent_test.go @@ -0,0 +1,218 @@ +package needle + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + "github.com/danpasecinic/needle/internal/reflect" +) + +func TestConcurrentSingletonResolve(t *testing.T) { + t.Parallel() + + c := New() + _ = ProvideValue(c, &testCounter{id: 42}) + + const n = 100 + results := make([]*testCounter, n) + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := Invoke[*testCounter](c) + if err != nil { + t.Errorf("goroutine %d: %v", idx, err) + return + } + results[idx] = val + }(i) + } + wg.Wait() + + for i := 1; i < n; i++ { + if results[i] != results[0] { + t.Fatal("singleton must return same instance across goroutines") + } + } +} + +func TestConcurrentNamedProvideAndInvoke(t *testing.T) { + t.Parallel() + + c := New() + const n = 50 + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + _ = ProvideNamedValue(c, fmt.Sprintf("s%d", idx), &concService{id: idx}) + }(i) + } + wg.Wait() + + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := InvokeNamed[*concService](c, fmt.Sprintf("s%d", idx)) + if err != nil { + t.Errorf("invoke s%d: %v", idx, err) + return + } + if val.id != idx { + t.Errorf("s%d: expected id %d, got %d", idx, idx, val.id) + } + }(i) + } + wg.Wait() +} + +func TestConcurrentPoolAcquireRelease(t *testing.T) { + t.Parallel() + + c := New() + var created atomic.Int32 + + _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, WithPoolSize(3)) + + key := reflect.TypeKey[*testCounter]() + + // Pre-fill: create 3 instances, then release all to pool + instances := make([]*testCounter, 3) + for i := range 3 { + inst, err := Invoke[*testCounter](c) + if err != nil { + t.Fatalf("pre-fill %d: %v", i, err) + } + instances[i] = inst + } + for _, inst := range instances { + c.Release(key, inst) + } + + // Concurrent acquire-release cycles from the pre-filled pool + const n = 20 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + inst, err := Invoke[*testCounter](c) + if err != nil { + return + } + c.Release(key, inst) + }() + } + wg.Wait() + + if created.Load() < 3 { + t.Errorf("expected at least 3 provider calls, got %d", created.Load()) + } +} + +func TestConcurrentTransientDifferentKeys(t *testing.T) { + t.Parallel() + + c := New() + const n = 50 + + for i := range n { + idx := i + _ = ProvideNamed(c, fmt.Sprintf("t%d", idx), func(_ context.Context, _ Resolver) (*concService, error) { + return &concService{id: idx}, nil + }, WithScope(Transient)) + } + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := InvokeNamed[*concService](c, fmt.Sprintf("t%d", idx)) + if err != nil { + t.Errorf("t%d: %v", idx, err) + return + } + if val == nil { + t.Errorf("t%d: got nil", idx) + } + }(i) + } + wg.Wait() +} + +func TestConcurrentRequestScopeIsolation(t *testing.T) { + t.Parallel() + + c := New() + var created atomic.Int32 + + _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, WithScope(Request)) + + const numContexts = 10 + const resolvesPerCtx = 5 + + distinct := make(map[*testCounter]bool) + + for ci := range numContexts { + ctx := WithRequestScope(context.Background()) + first, err := InvokeCtx[*testCounter](ctx, c) + if err != nil { + t.Fatalf("ctx %d: %v", ci, err) + } + for ri := 1; ri < resolvesPerCtx; ri++ { + val, err := InvokeCtx[*testCounter](ctx, c) + if err != nil { + t.Fatalf("ctx %d resolve %d: %v", ci, ri, err) + } + if val != first { + t.Errorf("ctx %d: resolve %d returned different instance", ci, ri) + } + } + distinct[first] = true + } + + if len(distinct) != numContexts { + t.Errorf("expected %d distinct instances, got %d", numContexts, len(distinct)) + } +} + +func TestConcurrentReplaceNoRace(t *testing.T) { + t.Parallel() + + c := New() + _ = ProvideValue(c, &testCounter{id: 0}) + + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + if idx%2 == 0 { + _ = ReplaceValue(c, &testCounter{id: idx}) + } else { + // Invoke may fail due to concurrent replace, + // we're verifying no panics or data races. + _, _ = Invoke[*testCounter](c) + } + }(i) + } + wg.Wait() +} + +type concService struct { + id int +} diff --git a/doc.go b/doc.go index 3a5b0eb..b7db138 100644 --- a/doc.go +++ b/doc.go @@ -64,16 +64,21 @@ // // Use Optional for dependencies that may or may not be registered: // -// opt := needle.InvokeOptional[*Cache](c) +// opt, err := needle.InvokeOptional[*Cache](c) +// if err != nil { +// // registered but resolution failed +// } // if opt.Present() { // cache := opt.Value() // } // // // Or use OrElse for default values -// cache := needle.InvokeOptional[*Cache](c).OrElse(defaultCache) +// opt, _ := needle.InvokeOptional[*Cache](c) +// cache := opt.OrElse(defaultCache) // // // OrElseFunc for lazy defaults -// cache := needle.InvokeOptional[*Cache](c).OrElseFunc(func() *Cache { +// opt, _ := needle.InvokeOptional[*Cache](c) +// cache := opt.OrElseFunc(func() *Cache { // return NewDefaultCache() // }) // diff --git a/errors.go b/errors.go index 903cc61..9c450f4 100644 --- a/errors.go +++ b/errors.go @@ -112,7 +112,7 @@ func newError(code ErrorCode, message string, cause error) *Error { } } -func errServiceNotFound(serviceType string) *Error { //nolint:unused // reserved for future use +func errServiceNotFound(serviceType string) *Error { return newError( ErrCodeServiceNotFound, fmt.Sprintf("no provider registered for type %s", serviceType), @@ -120,22 +120,6 @@ func errServiceNotFound(serviceType string) *Error { //nolint:unused // reserved ).WithService(serviceType) } -func errCircularDependency(chain []string) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeCircularDependency, - fmt.Sprintf("circular dependency detected: %s", strings.Join(chain, " -> ")), - nil, - ).WithStack(chain) -} - -func errDuplicateService(serviceType string) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeDuplicateService, - fmt.Sprintf("provider already registered for type %s", serviceType), - nil, - ).WithService(serviceType) -} - func errResolutionFailed(serviceType string, cause error) *Error { return newError( ErrCodeResolutionFailed, @@ -144,15 +128,7 @@ func errResolutionFailed(serviceType string, cause error) *Error { ).WithService(serviceType) } -func errProviderFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeProviderFailed, - fmt.Sprintf("provider for %s returned error", serviceType), - cause, - ).WithService(serviceType) -} - -func errStartupFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use +func errStartupFailed(serviceType string, cause error) *Error { return newError( ErrCodeStartupFailed, fmt.Sprintf("failed to start %s", serviceType), @@ -160,7 +136,7 @@ func errStartupFailed(serviceType string, cause error) *Error { //nolint:unused ).WithService(serviceType) } -func errShutdownFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use +func errShutdownFailed(serviceType string, cause error) *Error { return newError( ErrCodeShutdownFailed, fmt.Sprintf("failed to stop %s", serviceType), diff --git a/examples/optional/main.go b/examples/optional/main.go index 47e0fc4..545afa3 100644 --- a/examples/optional/main.go +++ b/examples/optional/main.go @@ -82,6 +82,13 @@ func (s *UserService) GetUser(id int) string { return user } +func mustOptional[T any](opt needle.Optional[T], err error) needle.Optional[T] { + if err != nil { + panic(err) + } + return opt +} + func main() { fmt.Println("=== Scenario 1: All dependencies available ===") runWithAllDeps() @@ -112,12 +119,12 @@ func runWithAllDeps() { _ = needle.Provide( c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { - cache := needle.InvokeOptional[Cache](c).OrElseFunc( + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() }, ) - metrics := needle.InvokeOptional[Metrics](c).OrElse(&NoOpMetrics{}) + metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, ) @@ -132,12 +139,12 @@ func runWithoutOptionalDeps() { _ = needle.Provide( c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { - cache := needle.InvokeOptional[Cache](c).OrElseFunc( + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() }, ) - metrics := needle.InvokeOptional[Metrics](c).OrElse(&NoOpMetrics{}) + metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, ) @@ -154,7 +161,7 @@ func demonstrateOptionalAPI() { _ = needle.Bind[Cache, *RedisCache](c) fmt.Println("--- Present() and Value() ---") - opt := needle.InvokeOptional[Cache](c) + opt := mustOptional(needle.InvokeOptional[Cache](c)) if opt.Present() { cache := opt.Value() cache.Set("foo", "bar") @@ -168,11 +175,11 @@ func demonstrateOptionalAPI() { } fmt.Println("\n--- OrElse() ---") - cache := needle.InvokeOptional[Cache](c).OrElse(NewInMemoryCache()) + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElse(NewInMemoryCache()) fmt.Printf("Cache type: %T\n", cache) fmt.Println("\n--- OrElseFunc() (lazy) ---") - cache = needle.InvokeOptional[Cache](c).OrElseFunc( + cache = mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { fmt.Println("This won't print because cache exists") return NewInMemoryCache() @@ -181,12 +188,12 @@ func demonstrateOptionalAPI() { fmt.Printf("Cache type: %T\n", cache) fmt.Println("\n--- Missing dependency ---") - optMetrics := needle.InvokeOptional[Metrics](c) + optMetrics := mustOptional(needle.InvokeOptional[Metrics](c)) fmt.Printf("Metrics present: %v\n", optMetrics.Present()) metrics := optMetrics.OrElse(&NoOpMetrics{}) fmt.Printf("Metrics type: %T\n", metrics) fmt.Println("\n--- Named optional ---") - optNamed := needle.InvokeOptionalNamed[Cache](c, "session") + optNamed := mustOptional(needle.InvokeOptionalNamed[Cache](c, "session")) fmt.Printf("Named cache present: %v\n", optNamed.Present()) } diff --git a/internal/container/container.go b/internal/container/container.go index 665f56b..96dc3b6 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -75,10 +75,22 @@ func New(cfg *Config) *Container { } func (c *Container) Register(key string, provider ProviderFunc, dependencies []string) error { + if err := c.registerLocked(key, provider, dependencies); err != nil { + return err + } + + for _, hook := range c.onProvide { + hook(key) + } + + return nil +} + +func (c *Container) registerLocked(key string, provider ProviderFunc, dependencies []string) error { c.mu.Lock() + defer c.mu.Unlock() if c.registry.HasUnsafe(key) { - c.mu.Unlock() return fmt.Errorf("service already registered: %s", key) } @@ -88,11 +100,16 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { c.registry.RemoveUnsafe(key) c.graph.RemoveNodeUnsafe(key) - c.mu.Unlock() return fmt.Errorf("circular dependency detected for: %s", key) } - c.mu.Unlock() + return nil +} + +func (c *Container) RegisterValue(key string, value any) error { + if err := c.registerValueLocked(key, value); err != nil { + return err + } for _, hook := range c.onProvide { hook(key) @@ -101,23 +118,16 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s return nil } -func (c *Container) RegisterValue(key string, value any) error { +func (c *Container) registerValueLocked(key string, value any) error { c.mu.Lock() + defer c.mu.Unlock() if c.registry.HasUnsafe(key) { - c.mu.Unlock() return fmt.Errorf("service already registered: %s", key) } c.registry.RegisterValueUnsafe(key, value) c.graph.AddNodeUnsafe(key, nil) - - c.mu.Unlock() - - for _, hook := range c.onProvide { - hook(key) - } - return nil } @@ -180,7 +190,11 @@ func (c *Container) State() State { } func (c *Container) Release(key string, instance any) bool { - return c.registry.ReleaseToPool(key, instance) + released := c.registry.ReleaseToPool(key, instance) + if !released { + c.logger.Warn("pool overflow: instance dropped", "service", key) + } + return released } func (c *Container) AddOnStart(key string, hook Hook) { diff --git a/internal/container/registry.go b/internal/container/registry.go index 5ad86e5..dd2930c 100644 --- a/internal/container/registry.go +++ b/internal/container/registry.go @@ -113,6 +113,7 @@ func (r *Registry) GetInstance(key string) (any, bool) { return entry.Instance, true } +// GetInstanceFast avoids defer for performance -- this is a hot path called on every Resolve. func (r *Registry) GetInstanceFast(key string) (any, bool) { r.mu.RLock() entry, exists := r.services[key] diff --git a/internal/container/replace.go b/internal/container/replace.go index 7c0573f..481d6ab 100644 --- a/internal/container/replace.go +++ b/internal/container/replace.go @@ -6,14 +6,14 @@ func (c *Container) Replace(key string, provider ProviderFunc, dependencies []st c.mu.Lock() defer c.mu.Unlock() - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) - c.registry.RegisterUnsafe(key, provider, dependencies) + _ = c.registry.Register(key, provider, dependencies) c.graph.AddNodeUnsafe(key, dependencies) if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) cyclePath := c.graph.FindCyclePathUnsafe(key) return fmt.Errorf("circular dependency detected: %v", cyclePath) @@ -26,10 +26,10 @@ func (c *Container) ReplaceValue(key string, value any) error { c.mu.Lock() defer c.mu.Unlock() - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) - c.registry.RegisterValueUnsafe(key, value) + _ = c.registry.RegisterValue(key, value) c.graph.AddNodeUnsafe(key, nil) return nil } diff --git a/needle_test.go b/needle_test.go index 7429bd2..ced1d17 100644 --- a/needle_test.go +++ b/needle_test.go @@ -391,7 +391,10 @@ func TestOptionalPresent(t *testing.T) { c := needle.New() _ = needle.ProvideValue(c, &Config{Port: 8080, Host: "localhost"}) - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !opt.Present() { t.Error("expected optional to be present") @@ -415,7 +418,10 @@ func TestOptionalNotPresent(t *testing.T) { c := needle.New() - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if opt.Present() { t.Error("expected optional to not be present") @@ -435,7 +441,10 @@ func TestOptionalOrElse(t *testing.T) { c := needle.New() - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defaultCfg := &Config{Port: 3000} result := opt.OrElse(defaultCfg) @@ -444,7 +453,10 @@ func TestOptionalOrElse(t *testing.T) { } _ = needle.ProvideValue(c, &Config{Port: 8080}) - opt2 := needle.InvokeOptional[*Config](c) + opt2, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result2 := opt2.OrElse(defaultCfg) if result2.Port != 8080 { @@ -458,7 +470,10 @@ func TestOptionalOrElseFunc(t *testing.T) { c := needle.New() callCount := 0 - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result := opt.OrElseFunc(func() *Config { callCount++ return &Config{Port: 9000} @@ -472,7 +487,10 @@ func TestOptionalOrElseFunc(t *testing.T) { } _ = needle.ProvideValue(c, &Config{Port: 8080}) - opt2 := needle.InvokeOptional[*Config](c) + opt2, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result2 := opt2.OrElseFunc(func() *Config { callCount++ return &Config{Port: 9000} @@ -492,7 +510,10 @@ func TestOptionalNamed(t *testing.T) { c := needle.New() _ = needle.ProvideNamedValue(c, "primary", &Config{Port: 5432}) - opt := needle.InvokeOptionalNamed[*Config](c, "primary") + opt, err := needle.InvokeOptionalNamed[*Config](c, "primary") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !opt.Present() { t.Error("expected primary config to be present") } @@ -500,7 +521,10 @@ func TestOptionalNamed(t *testing.T) { t.Errorf("expected port 5432, got %d", opt.Value().Port) } - optMissing := needle.InvokeOptionalNamed[*Config](c, "replica") + optMissing, err := needle.InvokeOptionalNamed[*Config](c, "replica") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if optMissing.Present() { t.Error("expected replica config to not be present") } @@ -520,7 +544,10 @@ func TestOptionalInProvider(t *testing.T) { } _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt := needle.InvokeOptional[*Cache](c) + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } return &Service{ Cache: cacheOpt.OrElse(nil), }, nil @@ -547,7 +574,10 @@ func TestOptionalInProviderWithValue(t *testing.T) { _ = needle.ProvideValue(c, &Cache{Enabled: true}) _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt := needle.InvokeOptional[*Cache](c) + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } return &Service{ Cache: cacheOpt.OrElse(nil), }, nil @@ -562,6 +592,41 @@ func TestOptionalInProviderWithValue(t *testing.T) { } } +func TestOptionalResolutionError(t *testing.T) { + t.Parallel() + + c := needle.New() + + _ = needle.Provide(c, func(_ context.Context, _ needle.Resolver) (*Config, error) { + return nil, errors.New("provider broken") + }) + + opt, err := needle.InvokeOptional[*Config](c) + if err == nil { + t.Fatal("expected error for broken provider") + } + if opt.Present() { + t.Error("expected optional to not be present on error") + } + if !needle.IsResolutionFailed(err) { + t.Errorf("expected resolution failed error, got: %v", err) + } +} + +func TestOptionalNotRegisteredNoError(t *testing.T) { + t.Parallel() + + c := needle.New() + + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("expected no error for unregistered service, got: %v", err) + } + if opt.Present() { + t.Error("expected optional to not be present") + } +} + func TestSomeNone(t *testing.T) { t.Parallel() diff --git a/resolver.go b/resolver.go index 1eacb4e..75aca54 100644 --- a/resolver.go +++ b/resolver.go @@ -2,6 +2,7 @@ package needle import ( "context" + "strings" "github.com/danpasecinic/needle/internal/reflect" ) @@ -156,50 +157,39 @@ func None[T any]() Optional[T] { return Optional[T]{} } -func InvokeOptional[T any](c *Container) Optional[T] { +func InvokeOptional[T any](c *Container) (Optional[T], error) { return InvokeOptionalCtx[T](context.Background(), c) } -func InvokeOptionalCtx[T any](ctx context.Context, c *Container) Optional[T] { - key := reflect.TypeKey[T]() - - if !c.internal.Has(key) { - return None[T]() - } - - instance, err := c.internal.Resolve(ctx, key) - if err != nil { - return None[T]() - } - - typed, ok := instance.(T) - if !ok { - return None[T]() - } - - return Some(typed) +func InvokeOptionalCtx[T any](ctx context.Context, c *Container) (Optional[T], error) { + return resolveOptional[T](ctx, c, reflect.TypeKey[T](), reflect.TypeName[T]()) } -func InvokeOptionalNamed[T any](c *Container, name string) Optional[T] { +func InvokeOptionalNamed[T any](c *Container, name string) (Optional[T], error) { return InvokeOptionalNamedCtx[T](context.Background(), c, name) } -func InvokeOptionalNamedCtx[T any](ctx context.Context, c *Container, name string) Optional[T] { - key := reflect.TypeKeyNamed[T](name) - - if !c.internal.Has(key) { - return None[T]() - } +func InvokeOptionalNamedCtx[T any](ctx context.Context, c *Container, name string) (Optional[T], error) { + return resolveOptional[T](ctx, c, reflect.TypeKeyNamed[T](name), reflect.TypeName[T]()+"#"+name) +} +func resolveOptional[T any](ctx context.Context, c *Container, key, displayName string) (Optional[T], error) { instance, err := c.internal.Resolve(ctx, key) if err != nil { - return None[T]() + if isServiceNotFound(err) { + return None[T](), nil + } + return None[T](), errResolutionFailed(displayName, err) } typed, ok := instance.(T) if !ok { - return None[T]() + return None[T](), errResolutionFailed(displayName, nil) } - return Some(typed) + return Some(typed), nil +} + +func isServiceNotFound(err error) bool { + return err != nil && strings.Contains(err.Error(), "service not found:") } From f76793c143ae47b3854a1c10403609aadb4dad31 Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:50:14 +0100 Subject: [PATCH 4/4] refactor: replace unsafe lock-free methods with safe locked variants --- README.md | 20 ++++++++++---------- internal/container/container.go | 18 +++++++++--------- internal/container/registry.go | 25 ------------------------- internal/container/replace.go | 14 +++++++------- internal/graph/cycle.go | 13 ------------- internal/graph/graph.go | 8 -------- 6 files changed, 26 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 749e33f..94a07a5 100644 --- a/README.md +++ b/README.md @@ -101,21 +101,21 @@ Needle wins benchmark categories against uber/fx, samber/do, and uber/dig. | Framework | Simple | Chain | Memory (Chain) | |------------|--------|-------|----------------| -| **Needle** | 780ns | 1.6μs | 3KB | -| Do | 1.9μs | 5.0μs | 4KB | -| Dig | 13μs | 28μs | 28KB | -| Fx | 42μs | 85μs | 70KB | +| **Needle** | 698ns | 1.5μs | 3KB | +| Do | 1.8μs | 4.4μs | 4KB | +| Dig | 13μs | 26μs | 28KB | +| Fx | 39μs | 78μs | 70KB | -Needle is **50x faster** than Fx for provider registration. +Needle is **56x faster** than Fx for provider registration. ### Service Resolution | Framework | Singleton | Chain | |------------|-----------|-------| | Fx | 0ns* | 0ns* | -| **Needle** | 17ns | 16ns | -| Do | 152ns | 159ns | -| Dig | 591ns | 586ns | +| **Needle** | 15ns | 17ns | +| Do | 150ns | 161ns | +| Dig | 614ns | 622ns | *Fx resolves at startup, not on-demand. @@ -125,8 +125,8 @@ When services have initialization work (database connections, HTTP clients, etc. | Scenario | Sequential | Parallel | Speedup | |-------------------|------------|----------|---------| -| 10 services × 1ms | 23ms | 2.4ms | **10x** | -| 50 services × 1ms | 116ms | 2.5ms | **45x** | +| 10 services × 1ms | 23ms | 2.3ms | **10x** | +| 50 services × 1ms | 113ms | 2.6ms | **44x** | Run benchmarks: `cd benchmark && make run` diff --git a/internal/container/container.go b/internal/container/container.go index 96dc3b6..9fab32b 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -90,16 +90,16 @@ func (c *Container) registerLocked(key string, provider ProviderFunc, dependenci c.mu.Lock() defer c.mu.Unlock() - if c.registry.HasUnsafe(key) { + if c.registry.Has(key) { return fmt.Errorf("service already registered: %s", key) } - c.registry.RegisterUnsafe(key, provider, dependencies) - c.graph.AddNodeUnsafe(key, dependencies) + _ = c.registry.Register(key, provider, dependencies) + c.graph.AddNode(key, dependencies) - if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { - c.registry.RemoveUnsafe(key) - c.graph.RemoveNodeUnsafe(key) + if len(dependencies) > 0 && c.graph.HasCycle() { + c.registry.Remove(key) + c.graph.RemoveNode(key) return fmt.Errorf("circular dependency detected for: %s", key) } @@ -122,12 +122,12 @@ func (c *Container) registerValueLocked(key string, value any) error { c.mu.Lock() defer c.mu.Unlock() - if c.registry.HasUnsafe(key) { + if c.registry.Has(key) { return fmt.Errorf("service already registered: %s", key) } - c.registry.RegisterValueUnsafe(key, value) - c.graph.AddNodeUnsafe(key, nil) + _ = c.registry.RegisterValue(key, value) + c.graph.AddNode(key, nil) return nil } diff --git a/internal/container/registry.go b/internal/container/registry.go index dd2930c..2c8771f 100644 --- a/internal/container/registry.go +++ b/internal/container/registry.go @@ -55,14 +55,6 @@ func (r *Registry) Register(key string, provider ProviderFunc, dependencies []st return nil } -func (r *Registry) RegisterUnsafe(key string, provider ProviderFunc, dependencies []string) { - r.services[key] = &ServiceEntry{ - Key: key, - Provider: provider, - Dependencies: dependencies, - } -} - func (r *Registry) RegisterValue(key string, value any) error { r.mu.Lock() defer r.mu.Unlock() @@ -74,14 +66,6 @@ func (r *Registry) RegisterValue(key string, value any) error { return nil } -func (r *Registry) RegisterValueUnsafe(key string, value any) { - r.services[key] = &ServiceEntry{ - Key: key, - Instance: value, - Instantiated: true, - } -} - func (r *Registry) Has(key string) bool { r.mu.RLock() defer r.mu.RUnlock() @@ -89,11 +73,6 @@ func (r *Registry) Has(key string) bool { return exists } -func (r *Registry) HasUnsafe(key string) bool { - _, exists := r.services[key] - return exists -} - func (r *Registry) Get(key string) (*ServiceEntry, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -171,10 +150,6 @@ func (r *Registry) Remove(key string) { delete(r.services, key) } -func (r *Registry) RemoveUnsafe(key string) { - delete(r.services, key) -} - func (r *Registry) Dependencies(key string) []string { r.mu.RLock() defer r.mu.RUnlock() diff --git a/internal/container/replace.go b/internal/container/replace.go index 481d6ab..33564e0 100644 --- a/internal/container/replace.go +++ b/internal/container/replace.go @@ -7,15 +7,15 @@ func (c *Container) Replace(key string, provider ProviderFunc, dependencies []st defer c.mu.Unlock() c.registry.Remove(key) - c.graph.RemoveNodeUnsafe(key) + c.graph.RemoveNode(key) _ = c.registry.Register(key, provider, dependencies) - c.graph.AddNodeUnsafe(key, dependencies) + c.graph.AddNode(key, dependencies) - if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { + if len(dependencies) > 0 && c.graph.HasCycle() { c.registry.Remove(key) - c.graph.RemoveNodeUnsafe(key) - cyclePath := c.graph.FindCyclePathUnsafe(key) + c.graph.RemoveNode(key) + cyclePath := c.graph.FindCyclePath(key) return fmt.Errorf("circular dependency detected: %v", cyclePath) } @@ -27,9 +27,9 @@ func (c *Container) ReplaceValue(key string, value any) error { defer c.mu.Unlock() c.registry.Remove(key) - c.graph.RemoveNodeUnsafe(key) + c.graph.RemoveNode(key) _ = c.registry.RegisterValue(key, value) - c.graph.AddNodeUnsafe(key, nil) + c.graph.AddNode(key, nil) return nil } diff --git a/internal/graph/cycle.go b/internal/graph/cycle.go index 470cd24..9f8cf87 100644 --- a/internal/graph/cycle.go +++ b/internal/graph/cycle.go @@ -109,15 +109,6 @@ func (g *Graph) HasCycle() bool { return g.hasCycle } -func (g *Graph) HasCycleUnsafe() bool { - if g.cycleValid { - return g.hasCycle - } - g.hasCycle = g.hasCycleUnsafe() - g.cycleValid = true - return g.hasCycle -} - func (g *Graph) hasCycleUnsafe() bool { white := make(map[string]bool, len(g.nodes)) gray := make(map[string]bool, len(g.nodes)) @@ -165,10 +156,6 @@ func (g *Graph) FindCyclePath(start string) []string { return g.findCyclePathUnsafe(start) } -func (g *Graph) FindCyclePathUnsafe(start string) []string { - return g.findCyclePathUnsafe(start) -} - func (g *Graph) findCyclePathUnsafe(start string) []string { visited := make(map[string]bool) path := make([]string, 0) diff --git a/internal/graph/graph.go b/internal/graph/graph.go index 27fed38..354ab7c 100644 --- a/internal/graph/graph.go +++ b/internal/graph/graph.go @@ -32,10 +32,6 @@ func (g *Graph) AddNode(id string, dependencies []string) { g.addNodeUnsafe(id, dependencies) } -func (g *Graph) AddNodeUnsafe(id string, dependencies []string) { - g.addNodeUnsafe(id, dependencies) -} - func (g *Graph) addNodeUnsafe(id string, dependencies []string) { g.nodes[id] = &Node{ ID: id, @@ -52,10 +48,6 @@ func (g *Graph) RemoveNode(id string) { g.removeNodeUnsafe(id) } -func (g *Graph) RemoveNodeUnsafe(id string) { - g.removeNodeUnsafe(id) -} - func (g *Graph) removeNodeUnsafe(id string) { delete(g.nodes, id) delete(g.edges, id)