From 1f103d4f739ec788f9c481d3666d076c2ed7d17d Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:17:56 +0100 Subject: [PATCH 1/7] refactor: remove unused error constructors Remove errCircularDependency, errDuplicateService, and errProviderFailed which were defined but never called. Remove //nolint:unused directives from errServiceNotFound, errStartupFailed, errShutdownFailed which are actually used. --- errors.go | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/errors.go b/errors.go index 903cc61..9c450f4 100644 --- a/errors.go +++ b/errors.go @@ -112,7 +112,7 @@ func newError(code ErrorCode, message string, cause error) *Error { } } -func errServiceNotFound(serviceType string) *Error { //nolint:unused // reserved for future use +func errServiceNotFound(serviceType string) *Error { return newError( ErrCodeServiceNotFound, fmt.Sprintf("no provider registered for type %s", serviceType), @@ -120,22 +120,6 @@ func errServiceNotFound(serviceType string) *Error { //nolint:unused // reserved ).WithService(serviceType) } -func errCircularDependency(chain []string) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeCircularDependency, - fmt.Sprintf("circular dependency detected: %s", strings.Join(chain, " -> ")), - nil, - ).WithStack(chain) -} - -func errDuplicateService(serviceType string) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeDuplicateService, - fmt.Sprintf("provider already registered for type %s", serviceType), - nil, - ).WithService(serviceType) -} - func errResolutionFailed(serviceType string, cause error) *Error { return newError( ErrCodeResolutionFailed, @@ -144,15 +128,7 @@ func errResolutionFailed(serviceType string, cause error) *Error { ).WithService(serviceType) } -func errProviderFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use - return newError( - ErrCodeProviderFailed, - fmt.Sprintf("provider for %s returned error", serviceType), - cause, - ).WithService(serviceType) -} - -func errStartupFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use +func errStartupFailed(serviceType string, cause error) *Error { return newError( ErrCodeStartupFailed, fmt.Sprintf("failed to start %s", serviceType), @@ -160,7 +136,7 @@ func errStartupFailed(serviceType string, cause error) *Error { //nolint:unused ).WithService(serviceType) } -func errShutdownFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use +func errShutdownFailed(serviceType string, cause error) *Error { return newError( ErrCodeShutdownFailed, fmt.Sprintf("failed to stop %s", serviceType), From 59ca5c769d46ae44b12214e0bc1b41b637e81486 Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:18:24 +0100 Subject: [PATCH 2/7] refactor: extract lock helpers in container internals Extract registerLocked/registerValueLocked from Register/RegisterValue to use defer-based unlock instead of manual multi-exit-point unlocks. Extract markResolving/unmarkResolving from resolveSlow for the same reason. Add comment to GetInstanceFast documenting why defer is intentionally skipped on that hot path. --- internal/container/container.go | 34 +++++++++++++++++++++------------ internal/container/registry.go | 1 + internal/container/resolve.go | 2 +- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/internal/container/container.go b/internal/container/container.go index 665f56b..bdec067 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -75,10 +75,22 @@ func New(cfg *Config) *Container { } func (c *Container) Register(key string, provider ProviderFunc, dependencies []string) error { + if err := c.registerLocked(key, provider, dependencies); err != nil { + return err + } + + for _, hook := range c.onProvide { + hook(key) + } + + return nil +} + +func (c *Container) registerLocked(key string, provider ProviderFunc, dependencies []string) error { c.mu.Lock() + defer c.mu.Unlock() if c.registry.HasUnsafe(key) { - c.mu.Unlock() return fmt.Errorf("service already registered: %s", key) } @@ -88,11 +100,16 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { c.registry.RemoveUnsafe(key) c.graph.RemoveNodeUnsafe(key) - c.mu.Unlock() return fmt.Errorf("circular dependency detected for: %s", key) } - c.mu.Unlock() + return nil +} + +func (c *Container) RegisterValue(key string, value any) error { + if err := c.registerValueLocked(key, value); err != nil { + return err + } for _, hook := range c.onProvide { hook(key) @@ -101,23 +118,16 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s return nil } -func (c *Container) RegisterValue(key string, value any) error { +func (c *Container) registerValueLocked(key string, value any) error { c.mu.Lock() + defer c.mu.Unlock() if c.registry.HasUnsafe(key) { - c.mu.Unlock() return fmt.Errorf("service already registered: %s", key) } c.registry.RegisterValueUnsafe(key, value) c.graph.AddNodeUnsafe(key, nil) - - c.mu.Unlock() - - for _, hook := range c.onProvide { - hook(key) - } - return nil } diff --git a/internal/container/registry.go b/internal/container/registry.go index 5ad86e5..dd2930c 100644 --- a/internal/container/registry.go +++ b/internal/container/registry.go @@ -113,6 +113,7 @@ func (r *Registry) GetInstance(key string) (any, bool) { return entry.Instance, true } +// GetInstanceFast avoids defer for performance -- this is a hot path called on every Resolve. func (r *Registry) GetInstanceFast(key string) (any, bool) { r.mu.RLock() entry, exists := r.services[key] diff --git a/internal/container/resolve.go b/internal/container/resolve.go index a790238..2b927b9 100644 --- a/internal/container/resolve.go +++ b/internal/container/resolve.go @@ -45,7 +45,7 @@ func (c *Container) resolveSlow(ctx context.Context, key string) (any, error) { err := fmt.Errorf("circular resolution detected for: %s", key) c.callResolveHooks(key, time.Since(start), err) return nil, err - } + } dd399ff (refactor: extract lock helpers in container internals) c.mu.RLock() entry, exists := c.registry.Get(key) From a7951830ee7722012068fce99ae201b259422c41 Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:18:30 +0100 Subject: [PATCH 3/7] feat: add pool overflow warning on release Log a warning via slog when a pooled instance is dropped because the pool channel is full. Logging lives in Container.Release rather than Registry to keep the data layer free of observability concerns. --- internal/container/container.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/container/container.go b/internal/container/container.go index bdec067..96dc3b6 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -190,7 +190,11 @@ func (c *Container) State() State { } func (c *Container) Release(key string, instance any) bool { - return c.registry.ReleaseToPool(key, instance) + released := c.registry.ReleaseToPool(key, instance) + if !released { + c.logger.Warn("pool overflow: instance dropped", "service", key) + } + return released } func (c *Container) AddOnStart(key string, hook Hook) { From cc0b0061229fe77023725d26138ebaa3e5ce246c Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:18:39 +0100 Subject: [PATCH 4/7] feat: return error from InvokeOptional to distinguish not-found from resolution failure Change InvokeOptional/InvokeOptionalNamed return type from Optional[T] to (Optional[T], error). Not-registered returns (None, nil); a broken provider returns (None, error). Extract shared resolveOptional helper to deduplicate the Ctx and NamedCtx variants. Remove TOCTOU Has()+Resolve() double-lock pattern in favor of a single Resolve() call with error inspection. --- doc.go | 11 +++-- examples/optional/main.go | 25 +++++++----- needle_test.go | 85 ++++++++++++++++++++++++++++++++++----- resolver.go | 48 +++++++++------------- 4 files changed, 118 insertions(+), 51 deletions(-) diff --git a/doc.go b/doc.go index 3a5b0eb..b7db138 100644 --- a/doc.go +++ b/doc.go @@ -64,16 +64,21 @@ // // Use Optional for dependencies that may or may not be registered: // -// opt := needle.InvokeOptional[*Cache](c) +// opt, err := needle.InvokeOptional[*Cache](c) +// if err != nil { +// // registered but resolution failed +// } // if opt.Present() { // cache := opt.Value() // } // // // Or use OrElse for default values -// cache := needle.InvokeOptional[*Cache](c).OrElse(defaultCache) +// opt, _ := needle.InvokeOptional[*Cache](c) +// cache := opt.OrElse(defaultCache) // // // OrElseFunc for lazy defaults -// cache := needle.InvokeOptional[*Cache](c).OrElseFunc(func() *Cache { +// opt, _ := needle.InvokeOptional[*Cache](c) +// cache := opt.OrElseFunc(func() *Cache { // return NewDefaultCache() // }) // diff --git a/examples/optional/main.go b/examples/optional/main.go index 47e0fc4..545afa3 100644 --- a/examples/optional/main.go +++ b/examples/optional/main.go @@ -82,6 +82,13 @@ func (s *UserService) GetUser(id int) string { return user } +func mustOptional[T any](opt needle.Optional[T], err error) needle.Optional[T] { + if err != nil { + panic(err) + } + return opt +} + func main() { fmt.Println("=== Scenario 1: All dependencies available ===") runWithAllDeps() @@ -112,12 +119,12 @@ func runWithAllDeps() { _ = needle.Provide( c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { - cache := needle.InvokeOptional[Cache](c).OrElseFunc( + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() }, ) - metrics := needle.InvokeOptional[Metrics](c).OrElse(&NoOpMetrics{}) + metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, ) @@ -132,12 +139,12 @@ func runWithoutOptionalDeps() { _ = needle.Provide( c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { - cache := needle.InvokeOptional[Cache](c).OrElseFunc( + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() }, ) - metrics := needle.InvokeOptional[Metrics](c).OrElse(&NoOpMetrics{}) + metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, ) @@ -154,7 +161,7 @@ func demonstrateOptionalAPI() { _ = needle.Bind[Cache, *RedisCache](c) fmt.Println("--- Present() and Value() ---") - opt := needle.InvokeOptional[Cache](c) + opt := mustOptional(needle.InvokeOptional[Cache](c)) if opt.Present() { cache := opt.Value() cache.Set("foo", "bar") @@ -168,11 +175,11 @@ func demonstrateOptionalAPI() { } fmt.Println("\n--- OrElse() ---") - cache := needle.InvokeOptional[Cache](c).OrElse(NewInMemoryCache()) + cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElse(NewInMemoryCache()) fmt.Printf("Cache type: %T\n", cache) fmt.Println("\n--- OrElseFunc() (lazy) ---") - cache = needle.InvokeOptional[Cache](c).OrElseFunc( + cache = mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { fmt.Println("This won't print because cache exists") return NewInMemoryCache() @@ -181,12 +188,12 @@ func demonstrateOptionalAPI() { fmt.Printf("Cache type: %T\n", cache) fmt.Println("\n--- Missing dependency ---") - optMetrics := needle.InvokeOptional[Metrics](c) + optMetrics := mustOptional(needle.InvokeOptional[Metrics](c)) fmt.Printf("Metrics present: %v\n", optMetrics.Present()) metrics := optMetrics.OrElse(&NoOpMetrics{}) fmt.Printf("Metrics type: %T\n", metrics) fmt.Println("\n--- Named optional ---") - optNamed := needle.InvokeOptionalNamed[Cache](c, "session") + optNamed := mustOptional(needle.InvokeOptionalNamed[Cache](c, "session")) fmt.Printf("Named cache present: %v\n", optNamed.Present()) } diff --git a/needle_test.go b/needle_test.go index 7429bd2..ced1d17 100644 --- a/needle_test.go +++ b/needle_test.go @@ -391,7 +391,10 @@ func TestOptionalPresent(t *testing.T) { c := needle.New() _ = needle.ProvideValue(c, &Config{Port: 8080, Host: "localhost"}) - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !opt.Present() { t.Error("expected optional to be present") @@ -415,7 +418,10 @@ func TestOptionalNotPresent(t *testing.T) { c := needle.New() - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if opt.Present() { t.Error("expected optional to not be present") @@ -435,7 +441,10 @@ func TestOptionalOrElse(t *testing.T) { c := needle.New() - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defaultCfg := &Config{Port: 3000} result := opt.OrElse(defaultCfg) @@ -444,7 +453,10 @@ func TestOptionalOrElse(t *testing.T) { } _ = needle.ProvideValue(c, &Config{Port: 8080}) - opt2 := needle.InvokeOptional[*Config](c) + opt2, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result2 := opt2.OrElse(defaultCfg) if result2.Port != 8080 { @@ -458,7 +470,10 @@ func TestOptionalOrElseFunc(t *testing.T) { c := needle.New() callCount := 0 - opt := needle.InvokeOptional[*Config](c) + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result := opt.OrElseFunc(func() *Config { callCount++ return &Config{Port: 9000} @@ -472,7 +487,10 @@ func TestOptionalOrElseFunc(t *testing.T) { } _ = needle.ProvideValue(c, &Config{Port: 8080}) - opt2 := needle.InvokeOptional[*Config](c) + opt2, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } result2 := opt2.OrElseFunc(func() *Config { callCount++ return &Config{Port: 9000} @@ -492,7 +510,10 @@ func TestOptionalNamed(t *testing.T) { c := needle.New() _ = needle.ProvideNamedValue(c, "primary", &Config{Port: 5432}) - opt := needle.InvokeOptionalNamed[*Config](c, "primary") + opt, err := needle.InvokeOptionalNamed[*Config](c, "primary") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !opt.Present() { t.Error("expected primary config to be present") } @@ -500,7 +521,10 @@ func TestOptionalNamed(t *testing.T) { t.Errorf("expected port 5432, got %d", opt.Value().Port) } - optMissing := needle.InvokeOptionalNamed[*Config](c, "replica") + optMissing, err := needle.InvokeOptionalNamed[*Config](c, "replica") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if optMissing.Present() { t.Error("expected replica config to not be present") } @@ -520,7 +544,10 @@ func TestOptionalInProvider(t *testing.T) { } _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt := needle.InvokeOptional[*Cache](c) + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } return &Service{ Cache: cacheOpt.OrElse(nil), }, nil @@ -547,7 +574,10 @@ func TestOptionalInProviderWithValue(t *testing.T) { _ = needle.ProvideValue(c, &Cache{Enabled: true}) _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt := needle.InvokeOptional[*Cache](c) + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } return &Service{ Cache: cacheOpt.OrElse(nil), }, nil @@ -562,6 +592,41 @@ func TestOptionalInProviderWithValue(t *testing.T) { } } +func TestOptionalResolutionError(t *testing.T) { + t.Parallel() + + c := needle.New() + + _ = needle.Provide(c, func(_ context.Context, _ needle.Resolver) (*Config, error) { + return nil, errors.New("provider broken") + }) + + opt, err := needle.InvokeOptional[*Config](c) + if err == nil { + t.Fatal("expected error for broken provider") + } + if opt.Present() { + t.Error("expected optional to not be present on error") + } + if !needle.IsResolutionFailed(err) { + t.Errorf("expected resolution failed error, got: %v", err) + } +} + +func TestOptionalNotRegisteredNoError(t *testing.T) { + t.Parallel() + + c := needle.New() + + opt, err := needle.InvokeOptional[*Config](c) + if err != nil { + t.Fatalf("expected no error for unregistered service, got: %v", err) + } + if opt.Present() { + t.Error("expected optional to not be present") + } +} + func TestSomeNone(t *testing.T) { t.Parallel() diff --git a/resolver.go b/resolver.go index 1eacb4e..75aca54 100644 --- a/resolver.go +++ b/resolver.go @@ -2,6 +2,7 @@ package needle import ( "context" + "strings" "github.com/danpasecinic/needle/internal/reflect" ) @@ -156,50 +157,39 @@ func None[T any]() Optional[T] { return Optional[T]{} } -func InvokeOptional[T any](c *Container) Optional[T] { +func InvokeOptional[T any](c *Container) (Optional[T], error) { return InvokeOptionalCtx[T](context.Background(), c) } -func InvokeOptionalCtx[T any](ctx context.Context, c *Container) Optional[T] { - key := reflect.TypeKey[T]() - - if !c.internal.Has(key) { - return None[T]() - } - - instance, err := c.internal.Resolve(ctx, key) - if err != nil { - return None[T]() - } - - typed, ok := instance.(T) - if !ok { - return None[T]() - } - - return Some(typed) +func InvokeOptionalCtx[T any](ctx context.Context, c *Container) (Optional[T], error) { + return resolveOptional[T](ctx, c, reflect.TypeKey[T](), reflect.TypeName[T]()) } -func InvokeOptionalNamed[T any](c *Container, name string) Optional[T] { +func InvokeOptionalNamed[T any](c *Container, name string) (Optional[T], error) { return InvokeOptionalNamedCtx[T](context.Background(), c, name) } -func InvokeOptionalNamedCtx[T any](ctx context.Context, c *Container, name string) Optional[T] { - key := reflect.TypeKeyNamed[T](name) - - if !c.internal.Has(key) { - return None[T]() - } +func InvokeOptionalNamedCtx[T any](ctx context.Context, c *Container, name string) (Optional[T], error) { + return resolveOptional[T](ctx, c, reflect.TypeKeyNamed[T](name), reflect.TypeName[T]()+"#"+name) +} +func resolveOptional[T any](ctx context.Context, c *Container, key, displayName string) (Optional[T], error) { instance, err := c.internal.Resolve(ctx, key) if err != nil { - return None[T]() + if isServiceNotFound(err) { + return None[T](), nil + } + return None[T](), errResolutionFailed(displayName, err) } typed, ok := instance.(T) if !ok { - return None[T]() + return None[T](), errResolutionFailed(displayName, nil) } - return Some(typed) + return Some(typed), nil +} + +func isServiceNotFound(err error) bool { + return err != nil && strings.Contains(err.Error(), "service not found:") } From ad332de730910d6a1f993b1e9f2f45cf4c5bf033 Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:18:46 +0100 Subject: [PATCH 5/7] test: add concurrent stress tests Add 6 tests covering parallel singleton resolve, named provide+invoke, pool acquire/release cycles, transient with different keys, request scope isolation, and concurrent replace with race detection. --- concurrent_test.go | 218 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 concurrent_test.go diff --git a/concurrent_test.go b/concurrent_test.go new file mode 100644 index 0000000..9bbd0ba --- /dev/null +++ b/concurrent_test.go @@ -0,0 +1,218 @@ +package needle + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + "github.com/danpasecinic/needle/internal/reflect" +) + +func TestConcurrentSingletonResolve(t *testing.T) { + t.Parallel() + + c := New() + _ = ProvideValue(c, &testCounter{id: 42}) + + const n = 100 + results := make([]*testCounter, n) + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := Invoke[*testCounter](c) + if err != nil { + t.Errorf("goroutine %d: %v", idx, err) + return + } + results[idx] = val + }(i) + } + wg.Wait() + + for i := 1; i < n; i++ { + if results[i] != results[0] { + t.Fatal("singleton must return same instance across goroutines") + } + } +} + +func TestConcurrentNamedProvideAndInvoke(t *testing.T) { + t.Parallel() + + c := New() + const n = 50 + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + _ = ProvideNamedValue(c, fmt.Sprintf("s%d", idx), &concService{id: idx}) + }(i) + } + wg.Wait() + + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := InvokeNamed[*concService](c, fmt.Sprintf("s%d", idx)) + if err != nil { + t.Errorf("invoke s%d: %v", idx, err) + return + } + if val.id != idx { + t.Errorf("s%d: expected id %d, got %d", idx, idx, val.id) + } + }(i) + } + wg.Wait() +} + +func TestConcurrentPoolAcquireRelease(t *testing.T) { + t.Parallel() + + c := New() + var created atomic.Int32 + + _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, WithPoolSize(3)) + + key := reflect.TypeKey[*testCounter]() + + // Pre-fill: create 3 instances, then release all to pool + instances := make([]*testCounter, 3) + for i := range 3 { + inst, err := Invoke[*testCounter](c) + if err != nil { + t.Fatalf("pre-fill %d: %v", i, err) + } + instances[i] = inst + } + for _, inst := range instances { + c.Release(key, inst) + } + + // Concurrent acquire-release cycles from the pre-filled pool + const n = 20 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + inst, err := Invoke[*testCounter](c) + if err != nil { + return + } + c.Release(key, inst) + }() + } + wg.Wait() + + if created.Load() < 3 { + t.Errorf("expected at least 3 provider calls, got %d", created.Load()) + } +} + +func TestConcurrentTransientDifferentKeys(t *testing.T) { + t.Parallel() + + c := New() + const n = 50 + + for i := range n { + idx := i + _ = ProvideNamed(c, fmt.Sprintf("t%d", idx), func(_ context.Context, _ Resolver) (*concService, error) { + return &concService{id: idx}, nil + }, WithScope(Transient)) + } + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + val, err := InvokeNamed[*concService](c, fmt.Sprintf("t%d", idx)) + if err != nil { + t.Errorf("t%d: %v", idx, err) + return + } + if val == nil { + t.Errorf("t%d: got nil", idx) + } + }(i) + } + wg.Wait() +} + +func TestConcurrentRequestScopeIsolation(t *testing.T) { + t.Parallel() + + c := New() + var created atomic.Int32 + + _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, WithScope(Request)) + + const numContexts = 10 + const resolvesPerCtx = 5 + + distinct := make(map[*testCounter]bool) + + for ci := range numContexts { + ctx := WithRequestScope(context.Background()) + first, err := InvokeCtx[*testCounter](ctx, c) + if err != nil { + t.Fatalf("ctx %d: %v", ci, err) + } + for ri := 1; ri < resolvesPerCtx; ri++ { + val, err := InvokeCtx[*testCounter](ctx, c) + if err != nil { + t.Fatalf("ctx %d resolve %d: %v", ci, ri, err) + } + if val != first { + t.Errorf("ctx %d: resolve %d returned different instance", ci, ri) + } + } + distinct[first] = true + } + + if len(distinct) != numContexts { + t.Errorf("expected %d distinct instances, got %d", numContexts, len(distinct)) + } +} + +func TestConcurrentReplaceNoRace(t *testing.T) { + t.Parallel() + + c := New() + _ = ProvideValue(c, &testCounter{id: 0}) + + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func(idx int) { + defer wg.Done() + if idx%2 == 0 { + _ = ReplaceValue(c, &testCounter{id: idx}) + } else { + // Invoke may fail due to concurrent replace, + // we're verifying no panics or data races. + _, _ = Invoke[*testCounter](c) + } + }(i) + } + wg.Wait() +} + +type concService struct { + id int +} From 5405bdbe424e3d5c77dc87a6c45f1c4828d9067e Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:18:52 +0100 Subject: [PATCH 6/7] docs: add scope selection guide and Replace API documentation Add a Choosing a Scope decision table and a Replacing Services section with usage examples to the README. --- README.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/README.md b/README.md index e1e2615..749e33f 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,52 @@ See the [examples](examples/) directory: - [optional](examples/optional/) - Optional dependencies with fallbacks - [parallel](examples/parallel/) - Parallel startup/shutdown +## Choosing a Scope + +| Scope | Lifetime | Use When | +|-------|----------|----------| +| **Singleton** (default) | One instance for the container lifetime | Stateful services: DB pools, config, caches, loggers | +| **Transient** | New instance every resolution | Stateless handlers, commands, lightweight value objects | +| **Request** | One instance per `WithRequestScope(ctx)` | Per-HTTP-request state: request loggers, auth context, transaction managers | +| **Pooled** | Reusable instances from a fixed-size pool | Expensive-to-create, stateless-between-uses resources: gRPC connections, worker objects | + +```go +needle.Provide(c, NewService) // Singleton (default) +needle.Provide(c, NewHandler, needle.WithScope(needle.Transient)) +needle.Provide(c, NewRequestLogger, needle.WithScope(needle.Request)) +needle.Provide(c, NewWorker, needle.WithPoolSize(10)) // Pooled with 10 slots +``` + +Pooled services must be released by the caller via `c.Release(key, instance)`. If the pool is full, the instance is dropped and a warning is logged. + +## Replacing Services + +Replace services at runtime without restarting the container. Useful for feature flags, A/B testing, test doubles, or configuration updates. + +```go +// Replace with a new value +needle.ReplaceValue(c, &Config{Port: 9090}) + +// Replace with a new provider +needle.Replace(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { + return &Server{Config: needle.MustInvoke[*Config](c)}, nil +}) + +// Replace with auto-wired constructor +needle.ReplaceFunc[*Service](c, NewService) + +// Replace with struct injection +needle.ReplaceStruct[*Service](c) + +// Named variants +needle.ReplaceNamedValue(c, "primary", &Config{Port: 5432}) +needle.ReplaceNamed(c, "primary", provider) +``` + +All Replace functions accept the same options as Provide (`WithScope`, `WithOnStart`, `WithOnStop`, `WithLazy`, `WithPoolSize`). If the service does not exist yet, Replace creates it. If it does exist, the old entry is removed from both the registry and the dependency graph before re-registering. + +`Must` variants (`MustReplace`, `MustReplaceValue`, `MustReplaceFunc`, `MustReplaceStruct`) panic on error. + ## Benchmarks Needle wins benchmark categories against uber/fx, samber/do, and uber/dig. From e02b5940fd0454a0c56ffde93513fcf26da1d545 Mon Sep 17 00:00:00 2001 From: danpasecinic Date: Sat, 7 Mar 2026 19:29:17 +0100 Subject: [PATCH 7/7] fix: use locked registry methods in Replace to prevent data races Replace and ReplaceValue were calling Unsafe registry methods (no registry lock) while holding only the container mutex. Concurrent Resolve calls via GetInstanceFast hold only the registry mutex, creating a race on the services map. Switch to locked registry methods so both paths coordinate through the same lock. Also remove leftover rebase conflict marker in resolve.go. --- internal/container/replace.go | 10 +++++----- internal/container/resolve.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/container/replace.go b/internal/container/replace.go index 7c0573f..481d6ab 100644 --- a/internal/container/replace.go +++ b/internal/container/replace.go @@ -6,14 +6,14 @@ func (c *Container) Replace(key string, provider ProviderFunc, dependencies []st c.mu.Lock() defer c.mu.Unlock() - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) - c.registry.RegisterUnsafe(key, provider, dependencies) + _ = c.registry.Register(key, provider, dependencies) c.graph.AddNodeUnsafe(key, dependencies) if len(dependencies) > 0 && c.graph.HasCycleUnsafe() { - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) cyclePath := c.graph.FindCyclePathUnsafe(key) return fmt.Errorf("circular dependency detected: %v", cyclePath) @@ -26,10 +26,10 @@ func (c *Container) ReplaceValue(key string, value any) error { c.mu.Lock() defer c.mu.Unlock() - c.registry.RemoveUnsafe(key) + c.registry.Remove(key) c.graph.RemoveNodeUnsafe(key) - c.registry.RegisterValueUnsafe(key, value) + _ = c.registry.RegisterValue(key, value) c.graph.AddNodeUnsafe(key, nil) return nil } diff --git a/internal/container/resolve.go b/internal/container/resolve.go index 2b927b9..a790238 100644 --- a/internal/container/resolve.go +++ b/internal/container/resolve.go @@ -45,7 +45,7 @@ func (c *Container) resolveSlow(ctx context.Context, key string) (any, error) { err := fmt.Errorf("circular resolution detected for: %s", key) c.callResolveHooks(key, time.Since(start), err) return nil, err - } dd399ff (refactor: extract lock helpers in container internals) + } c.mu.RLock() entry, exists := c.registry.Get(key)