@@ -18,13 +18,15 @@ import (
1818 "bytes"
1919 "errors"
2020 "fmt"
21+ "sync"
2122)
2223
2324// callSet represents a set of expected calls, indexed by receiver and method
2425// name.
2526type callSet struct {
2627 // Calls that are still expected.
27- expected map [callSetKey ][]* Call
28+ expected map [callSetKey ][]* Call
29+ expectedMu * sync.Mutex
2830 // Calls that have been exhausted.
2931 exhausted map [callSetKey ][]* Call
3032 // when set to true,
@@ -39,14 +41,16 @@ type callSetKey struct {
3941
4042func newCallSet () * callSet {
4143 return & callSet {
42- expected : make (map [callSetKey ][]* Call ),
43- exhausted : make (map [callSetKey ][]* Call ),
44+ expected : make (map [callSetKey ][]* Call ),
45+ expectedMu : & sync.Mutex {},
46+ exhausted : make (map [callSetKey ][]* Call ),
4447 }
4548}
4649
4750func newOverridableCallSet () * callSet {
4851 return & callSet {
4952 expected : make (map [callSetKey ][]* Call ),
53+ expectedMu : & sync.Mutex {},
5054 exhausted : make (map [callSetKey ][]* Call ),
5155 allowOverride : true ,
5256 }
@@ -55,6 +59,10 @@ func newOverridableCallSet() *callSet {
5559// Add adds a new expected call.
5660func (cs callSet ) Add (call * Call ) {
5761 key := callSetKey {call .receiver , call .method }
62+
63+ cs .expectedMu .Lock ()
64+ defer cs .expectedMu .Unlock ()
65+
5866 m := cs .expected
5967 if call .exhausted () {
6068 m = cs .exhausted
@@ -70,6 +78,10 @@ func (cs callSet) Add(call *Call) {
7078// Remove removes an expected call.
7179func (cs callSet ) Remove (call * Call ) {
7280 key := callSetKey {call .receiver , call .method }
81+
82+ cs .expectedMu .Lock ()
83+ defer cs .expectedMu .Unlock ()
84+
7385 calls := cs .expected [key ]
7486 for i , c := range calls {
7587 if c == call {
@@ -85,6 +97,9 @@ func (cs callSet) Remove(call *Call) {
8597func (cs callSet ) FindMatch (receiver interface {}, method string , args []interface {}) (* Call , error ) {
8698 key := callSetKey {receiver , method }
8799
100+ cs .expectedMu .Lock ()
101+ defer cs .expectedMu .Unlock ()
102+
88103 // Search through the expected calls.
89104 expected := cs .expected [key ]
90105 var callsErrors bytes.Buffer
@@ -119,6 +134,9 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac
119134
120135// Failures returns the calls that are not satisfied.
121136func (cs callSet ) Failures () []* Call {
137+ cs .expectedMu .Lock ()
138+ defer cs .expectedMu .Unlock ()
139+
122140 failures := make ([]* Call , 0 , len (cs .expected ))
123141 for _ , calls := range cs .expected {
124142 for _ , call := range calls {
@@ -129,3 +147,19 @@ func (cs callSet) Failures() []*Call {
129147 }
130148 return failures
131149}
150+
151+ // Satisfied returns true in case all expected calls in this callSet are satisfied.
152+ func (cs callSet ) Satisfied () bool {
153+ cs .expectedMu .Lock ()
154+ defer cs .expectedMu .Unlock ()
155+
156+ for _ , calls := range cs .expected {
157+ for _ , call := range calls {
158+ if ! call .satisfied () {
159+ return false
160+ }
161+ }
162+ }
163+
164+ return true
165+ }
0 commit comments