Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,52 @@ See the [examples](examples/) directory:
- [optional](examples/optional/) - Optional dependencies with fallbacks
- [parallel](examples/parallel/) - Parallel startup/shutdown

## Choosing a Scope

| Scope | Lifetime | Use When |
|-------|----------|----------|
| **Singleton** (default) | One instance for the container lifetime | Stateful services: DB pools, config, caches, loggers |
| **Transient** | New instance every resolution | Stateless handlers, commands, lightweight value objects |
| **Request** | One instance per `WithRequestScope(ctx)` | Per-HTTP-request state: request loggers, auth context, transaction managers |
| **Pooled** | Reusable instances from a fixed-size pool | Expensive-to-create, stateless-between-uses resources: gRPC connections, worker objects |

```go
needle.Provide(c, NewService) // Singleton (default)
needle.Provide(c, NewHandler, needle.WithScope(needle.Transient))
needle.Provide(c, NewRequestLogger, needle.WithScope(needle.Request))
needle.Provide(c, NewWorker, needle.WithPoolSize(10)) // Pooled with 10 slots
```

Pooled services must be released by the caller via `c.Release(key, instance)`. If the pool is full, the instance is dropped and a warning is logged.

## Replacing Services

Replace services at runtime without restarting the container. Useful for feature flags, A/B testing, test doubles, or configuration updates.

```go
// Replace with a new value
needle.ReplaceValue(c, &Config{Port: 9090})

// Replace with a new provider
needle.Replace(c, func(ctx context.Context, r needle.Resolver) (*Server, error) {
return &Server{Config: needle.MustInvoke[*Config](c)}, nil
})

// Replace with auto-wired constructor
needle.ReplaceFunc[*Service](c, NewService)

// Replace with struct injection
needle.ReplaceStruct[*Service](c)

// Named variants
needle.ReplaceNamedValue(c, "primary", &Config{Port: 5432})
needle.ReplaceNamed(c, "primary", provider)
```

All Replace functions accept the same options as Provide (`WithScope`, `WithOnStart`, `WithOnStop`, `WithLazy`, `WithPoolSize`). If the service does not exist yet, Replace creates it. If it does exist, the old entry is removed from both the registry and the dependency graph before re-registering.

`Must` variants (`MustReplace`, `MustReplaceValue`, `MustReplaceFunc`, `MustReplaceStruct`) panic on error.

## Benchmarks

Needle wins benchmark categories against uber/fx, samber/do, and uber/dig.
Expand Down
218 changes: 218 additions & 0 deletions concurrent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package needle

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"

"github.com/danpasecinic/needle/internal/reflect"
)

func TestConcurrentSingletonResolve(t *testing.T) {
t.Parallel()

c := New()
_ = ProvideValue(c, &testCounter{id: 42})

const n = 100
results := make([]*testCounter, n)

var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func(idx int) {
defer wg.Done()
val, err := Invoke[*testCounter](c)
if err != nil {
t.Errorf("goroutine %d: %v", idx, err)
return
}
results[idx] = val
}(i)
}
wg.Wait()

for i := 1; i < n; i++ {
if results[i] != results[0] {
t.Fatal("singleton must return same instance across goroutines")
}
}
}

func TestConcurrentNamedProvideAndInvoke(t *testing.T) {
t.Parallel()

c := New()
const n = 50

var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func(idx int) {
defer wg.Done()
_ = ProvideNamedValue(c, fmt.Sprintf("s%d", idx), &concService{id: idx})
}(i)
}
wg.Wait()

wg.Add(n)
for i := range n {
go func(idx int) {
defer wg.Done()
val, err := InvokeNamed[*concService](c, fmt.Sprintf("s%d", idx))
if err != nil {
t.Errorf("invoke s%d: %v", idx, err)
return
}
if val.id != idx {
t.Errorf("s%d: expected id %d, got %d", idx, idx, val.id)
}
}(i)
}
wg.Wait()
}

func TestConcurrentPoolAcquireRelease(t *testing.T) {
t.Parallel()

c := New()
var created atomic.Int32

_ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) {
return &testCounter{id: int(created.Add(1))}, nil
}, WithPoolSize(3))

key := reflect.TypeKey[*testCounter]()

// Pre-fill: create 3 instances, then release all to pool
instances := make([]*testCounter, 3)
for i := range 3 {
inst, err := Invoke[*testCounter](c)
if err != nil {
t.Fatalf("pre-fill %d: %v", i, err)
}
instances[i] = inst
}
for _, inst := range instances {
c.Release(key, inst)
}

// Concurrent acquire-release cycles from the pre-filled pool
const n = 20
var wg sync.WaitGroup
wg.Add(n)
for range n {
go func() {
defer wg.Done()
inst, err := Invoke[*testCounter](c)
if err != nil {
return
}
c.Release(key, inst)
}()
}
wg.Wait()

if created.Load() < 3 {
t.Errorf("expected at least 3 provider calls, got %d", created.Load())
}
}

func TestConcurrentTransientDifferentKeys(t *testing.T) {
t.Parallel()

c := New()
const n = 50

for i := range n {
idx := i
_ = ProvideNamed(c, fmt.Sprintf("t%d", idx), func(_ context.Context, _ Resolver) (*concService, error) {
return &concService{id: idx}, nil
}, WithScope(Transient))
}

var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func(idx int) {
defer wg.Done()
val, err := InvokeNamed[*concService](c, fmt.Sprintf("t%d", idx))
if err != nil {
t.Errorf("t%d: %v", idx, err)
return
}
if val == nil {
t.Errorf("t%d: got nil", idx)
}
}(i)
}
wg.Wait()
}

func TestConcurrentRequestScopeIsolation(t *testing.T) {
t.Parallel()

c := New()
var created atomic.Int32

_ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) {
return &testCounter{id: int(created.Add(1))}, nil
}, WithScope(Request))

const numContexts = 10
const resolvesPerCtx = 5

distinct := make(map[*testCounter]bool)

for ci := range numContexts {
ctx := WithRequestScope(context.Background())
first, err := InvokeCtx[*testCounter](ctx, c)
if err != nil {
t.Fatalf("ctx %d: %v", ci, err)
}
for ri := 1; ri < resolvesPerCtx; ri++ {
val, err := InvokeCtx[*testCounter](ctx, c)
if err != nil {
t.Fatalf("ctx %d resolve %d: %v", ci, ri, err)
}
if val != first {
t.Errorf("ctx %d: resolve %d returned different instance", ci, ri)
}
}
distinct[first] = true
}

if len(distinct) != numContexts {
t.Errorf("expected %d distinct instances, got %d", numContexts, len(distinct))
}
}

func TestConcurrentReplaceNoRace(t *testing.T) {
t.Parallel()

c := New()
_ = ProvideValue(c, &testCounter{id: 0})

const n = 50
var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func(idx int) {
defer wg.Done()
if idx%2 == 0 {
_ = ReplaceValue(c, &testCounter{id: idx})
} else {
// Invoke may fail due to concurrent replace,
// we're verifying no panics or data races.
_, _ = Invoke[*testCounter](c)
}
}(i)
}
wg.Wait()
}

type concService struct {
id int
}
11 changes: 8 additions & 3 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@
//
// Use Optional for dependencies that may or may not be registered:
//
// opt := needle.InvokeOptional[*Cache](c)
// opt, err := needle.InvokeOptional[*Cache](c)
// if err != nil {
// // registered but resolution failed
// }
// if opt.Present() {
// cache := opt.Value()
// }
//
// // Or use OrElse for default values
// cache := needle.InvokeOptional[*Cache](c).OrElse(defaultCache)
// opt, _ := needle.InvokeOptional[*Cache](c)
// cache := opt.OrElse(defaultCache)
//
// // OrElseFunc for lazy defaults
// cache := needle.InvokeOptional[*Cache](c).OrElseFunc(func() *Cache {
// opt, _ := needle.InvokeOptional[*Cache](c)
// cache := opt.OrElseFunc(func() *Cache {
// return NewDefaultCache()
// })
//
Expand Down
30 changes: 3 additions & 27 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,14 @@ func newError(code ErrorCode, message string, cause error) *Error {
}
}

func errServiceNotFound(serviceType string) *Error { //nolint:unused // reserved for future use
func errServiceNotFound(serviceType string) *Error {
return newError(
ErrCodeServiceNotFound,
fmt.Sprintf("no provider registered for type %s", serviceType),
nil,
).WithService(serviceType)
}

func errCircularDependency(chain []string) *Error { //nolint:unused // reserved for future use
return newError(
ErrCodeCircularDependency,
fmt.Sprintf("circular dependency detected: %s", strings.Join(chain, " -> ")),
nil,
).WithStack(chain)
}

func errDuplicateService(serviceType string) *Error { //nolint:unused // reserved for future use
return newError(
ErrCodeDuplicateService,
fmt.Sprintf("provider already registered for type %s", serviceType),
nil,
).WithService(serviceType)
}

func errResolutionFailed(serviceType string, cause error) *Error {
return newError(
ErrCodeResolutionFailed,
Expand All @@ -144,23 +128,15 @@ func errResolutionFailed(serviceType string, cause error) *Error {
).WithService(serviceType)
}

func errProviderFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use
return newError(
ErrCodeProviderFailed,
fmt.Sprintf("provider for %s returned error", serviceType),
cause,
).WithService(serviceType)
}

func errStartupFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use
func errStartupFailed(serviceType string, cause error) *Error {
return newError(
ErrCodeStartupFailed,
fmt.Sprintf("failed to start %s", serviceType),
cause,
).WithService(serviceType)
}

func errShutdownFailed(serviceType string, cause error) *Error { //nolint:unused // reserved for future use
func errShutdownFailed(serviceType string, cause error) *Error {
return newError(
ErrCodeShutdownFailed,
fmt.Sprintf("failed to stop %s", serviceType),
Expand Down
Loading
Loading