Skip to content

Commit 7eed21f

Browse files
committed
feat(shutdown): make shutdown async and parallel (shutdown in reverse invocation order)
1 parent 9170ba6 commit 7eed21f

File tree

12 files changed

+310
-111
lines changed

12 files changed

+310
-111
lines changed

dag.go

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,59 +23,82 @@ type EdgeService struct {
2323
// newDAG creates a new DAG (Directed Acyclic Graph) with initialized dependencies and dependents maps.
2424
func newDAG() *DAG {
2525
return &DAG{
26-
dependencies: new(sync.Map),
27-
dependents: new(sync.Map),
26+
mu: sync.RWMutex{},
27+
dependencies: map[EdgeService]map[EdgeService]struct{}{},
28+
dependents: map[EdgeService]map[EdgeService]struct{}{},
2829
}
2930
}
3031

3132
// DAG represents a Directed Acyclic Graph of services, tracking dependencies and dependents.
3233
type DAG struct {
33-
dependencies *sync.Map
34-
dependents *sync.Map
34+
mu sync.RWMutex
35+
dependencies map[EdgeService]map[EdgeService]struct{}
36+
dependents map[EdgeService]map[EdgeService]struct{}
3537
}
3638

3739
// addDependency adds a dependency relationship from one service to another in the DAG.
3840
func (d *DAG) addDependency(fromScopeID, fromScopeName, fromServiceName, toScopeID, toScopeName, toServiceName string) {
3941
from := newEdgeService(fromScopeID, fromScopeName, fromServiceName)
4042
to := newEdgeService(toScopeID, toScopeName, toServiceName)
4143

42-
d.addToMap(d.dependencies, from, to)
43-
d.addToMap(d.dependents, to, from)
44+
d.mu.Lock()
45+
defer d.mu.Unlock()
46+
47+
// from -> to
48+
if _, ok := d.dependencies[from]; !ok {
49+
d.dependencies[from] = map[EdgeService]struct{}{}
50+
}
51+
d.dependencies[from][to] = struct{}{}
52+
53+
// from <- to
54+
if _, ok := d.dependents[to]; !ok {
55+
d.dependents[to] = map[EdgeService]struct{}{}
56+
}
57+
d.dependents[to][from] = struct{}{}
4458
}
4559

46-
// addToMap is a helper function to add a key-value pair to a sync.Map, creating a new sync.Map for the value if necessary.
47-
func (d *DAG) addToMap(dependencyMap *sync.Map, key, value interface{}) {
48-
valueMap := new(sync.Map)
49-
valueMap.Store(value, struct{}{})
60+
// removeService removes a dependency relationship between services in the DAG.
61+
func (d *DAG) removeService(scopeID, scopeName, serviceName string) {
62+
edge := newEdgeService(scopeID, scopeName, serviceName)
63+
64+
d.mu.Lock()
65+
defer d.mu.Unlock()
5066

51-
if actual, loaded := dependencyMap.LoadOrStore(key, valueMap); loaded {
52-
actual.(*sync.Map).Store(value, struct{}{})
67+
dependencies, dependents := d.explainServiceImplem(edge)
68+
69+
for _, dependency := range dependencies {
70+
delete(d.dependents[dependency], edge)
5371
}
72+
73+
// should be empty, because we remove dependencies in the inverse invocation order
74+
for _, dependent := range dependents {
75+
delete(d.dependencies[dependent], edge)
76+
}
77+
78+
delete(d.dependencies, edge)
79+
delete(d.dependents, edge)
5480
}
5581

5682
// explainService provides information about a service's dependencies and dependents in the DAG.
5783
func (d *DAG) explainService(scopeID, scopeName, serviceName string) (dependencies, dependents []EdgeService) {
5884
edge := newEdgeService(scopeID, scopeName, serviceName)
5985

60-
dependencies = d.getServicesFromMap(d.dependencies, edge)
61-
dependents = d.getServicesFromMap(d.dependents, edge)
86+
d.mu.RLock()
87+
defer d.mu.RUnlock()
6288

63-
return dependencies, dependents
89+
return d.explainServiceImplem(edge)
6490
}
6591

66-
// getServicesFromMap is a helper function to retrieve services related to a specific key from a sync.Map.
67-
func (d *DAG) getServicesFromMap(serviceMap *sync.Map, edge EdgeService) []EdgeService {
68-
var services []EdgeService
69-
70-
if kv, ok := serviceMap.Load(edge); ok {
71-
kv.(*sync.Map).Range(func(key, value interface{}) bool {
72-
edgeService, ok := key.(EdgeService)
73-
if ok {
74-
services = append(services, edgeService)
75-
}
76-
return ok
77-
})
92+
func (d *DAG) explainServiceImplem(edge EdgeService) (dependencies, dependents []EdgeService) {
93+
dependencies, dependents = []EdgeService{}, []EdgeService{}
94+
95+
if kv, ok := d.dependencies[edge]; ok {
96+
dependencies = keys(kv)
7897
}
7998

80-
return services
99+
if kv, ok := d.dependents[edge]; ok {
100+
dependents = keys(kv)
101+
}
102+
103+
return dependencies, dependents
81104
}

dag_test.go

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package do
22

33
import (
4-
"sync"
54
"testing"
65

76
"github.com/stretchr/testify/assert"
@@ -24,11 +23,11 @@ func TestNewDAG(t *testing.T) {
2423
is := assert.New(t)
2524

2625
dag := newDAG()
27-
expectedDependencies := unSyncMap(new(sync.Map))
28-
expectedDependents := unSyncMap(new(sync.Map))
26+
expectedDependencies := map[EdgeService]map[EdgeService]struct{}{}
27+
expectedDependents := map[EdgeService]map[EdgeService]struct{}{}
2928

30-
is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
31-
is.Equal(expectedDependents, unSyncMap(dag.dependents))
29+
is.Equal(expectedDependencies, dag.dependencies)
30+
is.Equal(expectedDependents, dag.dependents)
3231
}
3332

3433
// TestDAG_addDependency checks the addition of dependencies to the DAG.
@@ -44,19 +43,42 @@ func TestDAG_addDependency(t *testing.T) {
4443

4544
dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2")
4645

47-
expectedDependencies := map[interface{}]interface{}{edge1: map[interface{}]interface{}{edge2: struct{}{}}}
48-
expectedDependents := map[interface{}]interface{}{edge2: map[interface{}]interface{}{edge1: struct{}{}}}
46+
expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}}
47+
expectedDependents := map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}}}
4948

50-
is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
51-
is.Equal(expectedDependents, unSyncMap(dag.dependents))
49+
is.Equal(expectedDependencies, dag.dependencies)
50+
is.Equal(expectedDependents, dag.dependents)
5251

5352
dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2")
5453

55-
expectedDependencies[edge3] = map[interface{}]interface{}{edge2: struct{}{}}
56-
expectedDependents[edge2] = map[interface{}]interface{}{edge1: struct{}{}, edge3: struct{}{}}
54+
expectedDependencies = map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}, edge3: {edge2: {}}}
55+
expectedDependents = map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}, edge3: {}}}
5756

58-
is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
59-
is.Equal(expectedDependents, unSyncMap(dag.dependents))
57+
is.Equal(expectedDependencies, dag.dependencies)
58+
is.Equal(expectedDependents, dag.dependents)
59+
}
60+
61+
// TestDAG_removeService checks the removal of dependencies to the DAG.
62+
func TestDAG_removeService(t *testing.T) {
63+
t.Parallel()
64+
is := assert.New(t)
65+
66+
edge1 := newEdgeService("scope1", "scope1", "service1")
67+
// edge2 := newEdgeService("scope2", "scope2", "service2")
68+
edge3 := newEdgeService("scope3", "scope3", "service3")
69+
70+
dag := newDAG()
71+
72+
dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2")
73+
dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2")
74+
75+
dag.removeService("scope2", "scope2", "service2")
76+
77+
expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {}, edge3: {}}
78+
expectedDependents := map[EdgeService]map[EdgeService]struct{}{}
79+
80+
is.Equal(expectedDependencies, dag.dependencies)
81+
is.Equal(expectedDependents, dag.dependents)
6082
}
6183

6284
// TestDAG_explainService checks the explanation of dependencies for a service in the DAG.
@@ -92,19 +114,3 @@ func TestDAG_explainService(t *testing.T) {
92114
is.ElementsMatch([]EdgeService{}, a)
93115
is.ElementsMatch([]EdgeService{}, b)
94116
}
95-
96-
func unSyncMap(syncMap *sync.Map) map[interface{}]interface{} {
97-
result := make(map[interface{}]interface{})
98-
99-
syncMap.Range(func(key, value interface{}) bool {
100-
if vSyncMap, ok := value.(*sync.Map); ok {
101-
result[key] = unSyncMap(vSyncMap)
102-
} else {
103-
result[key] = value
104-
}
105-
106-
return true
107-
})
108-
109-
return result
110-
}

docs/docs/service-lifecycle/shutdowner.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ A shutdown can be triggered on a root scope:
1818

1919
```go
2020
// on demand
21-
injector.Shutdown() map[string]error
22-
injector.ShutdownWithContext(context.Context) map[string]error
21+
injector.Shutdown() error
22+
injector.ShutdownWithContext(context.Context) error
2323

2424
// on signal
25-
injector.ShutdownOnSignals(...os.Signal) (os.Signal, map[string]error)
26-
injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, map[string]error)
25+
injector.ShutdownOnSignals(...os.Signal) (os.Signal, error)
26+
injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, error)
2727
```
2828

2929
...on a single service:
@@ -90,9 +90,7 @@ Invoke(i, ...)
9090

9191
ctx := context.WithTimeout(10 * time.Second)
9292
errors := i.ShutdownWithContext(ctx)
93-
for _, err := range errors {
94-
if err != nil {
95-
log.Println("shutdown error:", err)
96-
}
93+
if err != nil {
94+
log.Println("shutdown error:", err)
9795
}
9896
```

errors.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,65 @@
11
package do
22

3-
import "errors"
3+
import (
4+
"errors"
5+
"fmt"
6+
"strings"
7+
)
48

59
var ErrServiceNotFound = errors.New("DI: could not find service")
610
var ErrCircularDependency = errors.New("DI: circular dependency detected")
711
var ErrHealthCheckTimeout = errors.New("DI: health check timeout")
12+
13+
func newShutdownErrors() *ShutdownErrors {
14+
return &ShutdownErrors{}
15+
}
16+
17+
type ShutdownErrors map[EdgeService]error
18+
19+
func (e *ShutdownErrors) Add(scopeID string, scopeName string, serviceName string, err error) {
20+
if err != nil {
21+
(*e)[newEdgeService(scopeID, scopeName, serviceName)] = err
22+
}
23+
}
24+
25+
func (e ShutdownErrors) Len() int {
26+
out := 0
27+
for _, v := range e {
28+
if v != nil {
29+
out++
30+
}
31+
}
32+
return out
33+
}
34+
35+
func (e ShutdownErrors) Error() string {
36+
lines := []string{}
37+
for k, v := range e {
38+
if v != nil {
39+
lines = append(lines, fmt.Sprintf(" - %s > %s: %s", k.ScopeName, k.Service, v.Error()))
40+
}
41+
}
42+
43+
if len(lines) == 0 {
44+
return "DI: no shutdown errors"
45+
}
46+
47+
return "DI: shutdown errors:\n" + strings.Join(lines, "\n")
48+
}
49+
50+
func mergeShutdownErrors(ins ...*ShutdownErrors) *ShutdownErrors {
51+
out := newShutdownErrors()
52+
53+
for _, in := range ins {
54+
if in != nil {
55+
se := &ShutdownErrors{}
56+
if ok := errors.As(in, &se); ok {
57+
for k, v := range *se {
58+
(*out)[k] = v
59+
}
60+
}
61+
}
62+
}
63+
64+
return out
65+
}

errors_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package do
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestShutdownErrors_Add(t *testing.T) {
10+
is := assert.New(t)
11+
12+
se := newShutdownErrors()
13+
is.Equal(0, len(*se))
14+
is.Equal(0, se.Len())
15+
16+
se.Add("scope-1", "scope-a", "service-a", nil)
17+
is.Equal(0, len(*se))
18+
is.Equal(0, se.Len())
19+
is.EqualValues(&ShutdownErrors{}, se)
20+
21+
se.Add("scope-2", "scope-b", "service-b", assert.AnError)
22+
is.Equal(1, len(*se))
23+
is.Equal(1, se.Len())
24+
is.EqualValues(&ShutdownErrors{
25+
{ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError,
26+
}, se)
27+
}
28+
29+
func TestShutdownErrors_Error(t *testing.T) {
30+
is := assert.New(t)
31+
32+
se := newShutdownErrors()
33+
is.Equal(0, len(*se))
34+
is.Equal(0, se.Len())
35+
is.EqualValues("DI: no shutdown errors", se.Error())
36+
37+
se.Add("scope-1", "scope-a", "service-a", nil)
38+
is.Equal(0, len(*se))
39+
is.Equal(0, se.Len())
40+
is.EqualValues("DI: no shutdown errors", se.Error())
41+
42+
se.Add("scope-2", "scope-b", "service-b", assert.AnError)
43+
is.Equal(1, len(*se))
44+
is.Equal(1, se.Len())
45+
is.EqualValues("DI: shutdown errors:\n - scope-b > service-b: assert.AnError general error for testing", se.Error())
46+
}
47+
48+
func TestMergeShutdownErrors(t *testing.T) {
49+
is := assert.New(t)
50+
51+
se1 := newShutdownErrors()
52+
se2 := newShutdownErrors()
53+
se3 := newShutdownErrors()
54+
55+
se1.Add("scope-1", "scope-a", "service-a", assert.AnError)
56+
se2.Add("scope-2", "scope-b", "service-b", assert.AnError)
57+
58+
result := mergeShutdownErrors(se1, se2, se3, nil)
59+
is.Equal(2, result.Len())
60+
is.EqualValues(
61+
&ShutdownErrors{
62+
{ScopeID: "scope-1", ScopeName: "scope-a", Service: "service-a"}: assert.AnError,
63+
{ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError,
64+
},
65+
result,
66+
)
67+
}

examples/http/std/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package main
33
import (
44
"net/http"
55

6-
"github.com/samber/do/http/std"
6+
"github.com/samber/do/http/std/v2"
77
)
88

99
func main() {

0 commit comments

Comments
 (0)