Skip to content

Commit 718bcbb

Browse files
committed
Fix deadlock when progressive invocation handler errors (gammazero#319)
1 parent 5cfa511 commit 718bcbb

File tree

6 files changed

+285
-29
lines changed

6 files changed

+285
-29
lines changed

client/client.go

+83-23
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ type Client struct {
6969
closed bool
7070

7171
routerGoodbye *wamp.Goodbye
72-
idGen *wamp.SyncIDGen
7372
}
7473

7574
// InvokeResult represents the result of invoking a procedure.
@@ -256,7 +255,6 @@ func NewClient(p wamp.Peer, cfg Config) (*Client, error) {
256255
log: cfg.Logger,
257256
debug: cfg.Debug,
258257
cancelMode: wamp.CancelModeKillNoWait,
259-
idGen: new(wamp.SyncIDGen),
260258
}
261259
c.ctx, c.cancel = context.WithCancel(context.Background())
262260
go c.run() // start the core goroutine
@@ -313,7 +311,7 @@ func (c *Client) Subscribe(topic string, fn EventHandler, options wamp.Dict) err
313311
if options == nil {
314312
options = wamp.Dict{}
315313
}
316-
id := c.idGen.Next()
314+
id := c.sess.IdGen.Next()
317315
c.expectReply(id)
318316
c.sess.Send() <- &wamp.Subscribe{
319317
Request: id,
@@ -384,7 +382,7 @@ func (c *Client) Unsubscribe(topic string) error {
384382
return ErrNotConn
385383
}
386384

387-
id := c.idGen.Next()
385+
id := c.sess.IdGen.Next()
388386
c.expectReply(id)
389387
c.sess.Send() <- &wamp.Unsubscribe{
390388
Request: id,
@@ -453,7 +451,7 @@ func (c *Client) Publish(topic string, options wamp.Dict, args wamp.List, kwargs
453451
return ErrNotConn
454452
}
455453

456-
id := c.idGen.Next()
454+
id := c.sess.IdGen.Next()
457455

458456
var pubAck bool
459457
if options == nil {
@@ -570,7 +568,7 @@ func (c *Client) Register(procedure string, fn InvocationHandler, options wamp.D
570568
if !c.Connected() {
571569
return ErrNotConn
572570
}
573-
id := c.idGen.Next()
571+
id := c.sess.IdGen.Next()
574572
c.expectReply(id)
575573
if options == nil {
576574
options = wamp.Dict{}
@@ -636,7 +634,7 @@ func (c *Client) Unregister(procedure string) error {
636634
return ErrNotConn
637635
}
638636

639-
id := c.idGen.Next()
637+
id := c.sess.IdGen.Next()
640638
c.expectReply(id)
641639
c.sess.Send() <- &wamp.Unregister{
642640
Request: id,
@@ -758,7 +756,7 @@ func (c *Client) Call(ctx context.Context, procedure string, options wamp.Dict,
758756
}()
759757
}
760758

761-
id := c.idGen.Next()
759+
id := c.sess.IdGen.Next()
762760
c.expectReply(id)
763761
message := &wamp.Call{
764762
Request: id,
@@ -857,7 +855,7 @@ func (c *Client) CallProgressive(ctx context.Context, procedure string, sendProg
857855
}()
858856
}
859857

860-
id := c.idGen.Next()
858+
id := c.sess.IdGen.Next()
861859
c.expectReply(id)
862860
message := &wamp.Call{
863861
Request: id,
@@ -1445,14 +1443,17 @@ func (c *Client) runReceiveFromRouter(msg wamp.Message) bool {
14451443
if c.debug {
14461444
c.log.Println("Client", c.sess, "received", msg.MessageType())
14471445
}
1446+
14481447
switch msg := msg.(type) {
14491448
case *wamp.Event:
14501449
c.runHandleEvent(msg)
14511450

14521451
case *wamp.Invocation:
14531452
c.runHandleInvocation(msg)
1453+
c.sess.UpdateLastRecvID(msg.Request)
14541454
case *wamp.Interrupt:
14551455
c.runHandleInterrupt(msg)
1456+
c.sess.UpdateLastRecvID(msg.Request)
14561457

14571458
case *wamp.Registered:
14581459
c.runSignalReply(msg, msg.Request)
@@ -1535,6 +1536,38 @@ func (c *Client) runHandleEvent(msg *wamp.Event) {
15351536
handler(msg)
15361537
}
15371538

1539+
func (c *Client) cleanupInvHandlersQueue(cliInvocation clientInvocation) {
1540+
c.sess.Lock()
1541+
defer c.sess.Unlock()
1542+
1543+
handlerQueue, _ := c.invHandlersQueues[cliInvocation]
1544+
delete(c.invHandlersQueues, cliInvocation)
1545+
delete(c.invHandlersCtxs, cliInvocation)
1546+
1547+
if nil == handlerQueue {
1548+
return
1549+
}
1550+
if c.debug {
1551+
c.log.Println("Running cleanupInvHandlersQueue cleanup for", cliInvocation)
1552+
}
1553+
1554+
// Try to get any remaining values off the chan (in case anyone is blocked)
1555+
for {
1556+
select {
1557+
case msg := <-handlerQueue:
1558+
if nil == msg {
1559+
return // chan closed
1560+
}
1561+
continue
1562+
default:
1563+
return
1564+
}
1565+
}
1566+
// Don't bother closing in the very rare event that someone still has
1567+
// a reference and tries to send
1568+
// close(c.invHandlersQueues[cliInvocation])
1569+
}
1570+
15381571
// runHandleInvocation processes an INVOCATION message from the router
15391572
// requesting a call to a registered RPC procedure.
15401573
func (c *Client) runHandleInvocation(msg *wamp.Invocation) {
@@ -1616,6 +1649,19 @@ func (c *Client) runHandleInvocation(msg *wamp.Invocation) {
16161649
handlerQueue, queueExists := c.invHandlersQueues[cliInvocation]
16171650
ctx := c.invHandlersCtxs[cliInvocation]
16181651
if !queueExists {
1652+
// Only create the queue if this is a new request
1653+
if !c.sess.UpdateLastRecvIDCallerHasSessionLock(reqID) {
1654+
c.sess.Unlock()
1655+
if c.debug {
1656+
c.log.Println("Ignoring Invocation with expired reqID=", reqID)
1657+
}
1658+
// discard silently
1659+
return
1660+
}
1661+
1662+
if c.debug {
1663+
c.log.Println("Creating new handlerQueue reqID=", reqID)
1664+
}
16191665
handlerQueue = make(chan *wamp.Invocation, 1)
16201666
c.invHandlersQueues[cliInvocation] = handlerQueue
16211667

@@ -1638,7 +1684,6 @@ func (c *Client) runHandleInvocation(msg *wamp.Invocation) {
16381684
ctx = context.WithValue(ctx, invocationIDCtxKey{}, reqID)
16391685
}
16401686
}
1641-
16421687
c.sess.Unlock()
16431688

16441689
handlerQueue <- msg
@@ -1652,21 +1697,36 @@ func (c *Client) runHandleInvocation(msg *wamp.Invocation) {
16521697
// Otherwise, canceling the call will leak the goroutine that is
16531698
// blocked forever waiting to send the result to the channel.
16541699
resChan := make(chan InvokeResult, 1)
1700+
1701+
// Start goroutine to process inbound Invocation messages
16551702
go func() {
1656-
for msg := range handlerQueue {
1657-
1658-
if isInProgress, _ := msg.Details[wamp.OptProgress].(bool); !isInProgress {
1659-
c.sess.Lock()
1660-
close(c.invHandlersQueues[cliInvocation])
1661-
delete(c.invHandlersQueues, cliInvocation)
1662-
delete(c.invHandlersCtxs, cliInvocation)
1663-
c.sess.Unlock()
1664-
}
1703+
defer c.cleanupInvHandlersQueue(cliInvocation)
1704+
1705+
processMessages := true
1706+
for processMessages {
1707+
select {
1708+
case msg := <-handlerQueue:
1709+
if msg == nil { // chan closed
1710+
return
1711+
}
1712+
if isInProgress, _ := msg.Details[wamp.OptProgress].(bool); !isInProgress {
1713+
c.cleanupInvHandlersQueue(cliInvocation)
1714+
processMessages = false
1715+
}
16651716

1666-
// The Context is passed into the handler to tell the client
1667-
// application to stop whatever it is doing if it cares to pay
1668-
// attention.
1669-
resChan <- handler(ctx, msg)
1717+
// The Context is passed into the handler to tell the client
1718+
// application to stop whatever it is doing if it cares to pay
1719+
// attention.
1720+
result := handler(ctx, msg)
1721+
resChan <- result
1722+
if result.Err != "" && result.Err != wamp.InternalProgressiveOmitResult {
1723+
processMessages = false
1724+
}
1725+
case <-c.Done():
1726+
return
1727+
case <-ctx.Done():
1728+
return
1729+
}
16701730
}
16711731
}()
16721732

client/client_test.go

+96-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ const (
3030

3131
testTopic = "test.topic1"
3232
testTopic2 = "test.topic2"
33+
34+
debugClientEnv = "TEST_DEBUG_CLIENT"
35+
debugRouterEnv = "TEST_DEBUG_ROUTER"
3336
)
3437

3538
var logger stdlog.StdLog
@@ -47,6 +50,7 @@ func checkGoLeaks(t *testing.T) {
4750
func getTestRouter(t *testing.T, realmConfig *router.RealmConfig) router.Router {
4851
config := &router.Config{
4952
RealmConfigs: []*router.RealmConfig{realmConfig},
53+
Debug: os.Getenv(debugRouterEnv) != "",
5054
}
5155
r, err := router.NewRouter(config, logger)
5256
require.NoError(t, err)
@@ -95,7 +99,7 @@ func newTestClientConfig(realmName string, fns ...clientConfigMutator) *Config {
9599
Realm: realmName,
96100
ResponseTimeout: 500 * time.Millisecond,
97101
Logger: logger,
98-
Debug: false,
102+
Debug: os.Getenv(debugClientEnv) != "",
99103
}
100104
for _, fn := range fns {
101105
fn(clientConfig)
@@ -704,6 +708,97 @@ func TestProgressiveCallInvocations(t *testing.T) {
704708
require.NoError(t, err)
705709
}
706710

711+
// Tests that a callee can return error while caller IsInProgress
712+
// Test for #319 to ensure messages in-flight do not re-open the handlerQueue
713+
// or call the handler function after the callee has errored.
714+
func TestProgressiveCallInvocationCalleeError(t *testing.T) {
715+
// Connect two clients to the same server
716+
// t.Setenv(debugRouterEnv, "1")
717+
t.Setenv(debugClientEnv, "1")
718+
callee, caller, rooter := connectedTestClients(t)
719+
720+
const forcedError = wamp.URI("error.forced")
721+
moreArgsSent := make(chan struct{})
722+
errorRaised := false
723+
724+
invocationHandler := func(ctx context.Context, inv *wamp.Invocation) InvokeResult {
725+
switch inv.Arguments[0].(int) {
726+
case 1:
727+
// Eat the first arg
728+
t.Log("n=1 Returning OmitResult")
729+
return InvokeResult{Err: wamp.InternalProgressiveOmitResult}
730+
case 2:
731+
t.Log("n=2 Waiting for moreArgsSent")
732+
// Wait till the 4th arg is sent which means 3 should already
733+
// be waiting
734+
<-moreArgsSent
735+
time.Sleep(100 * time.Millisecond)
736+
errorRaised = true
737+
t.Log("n=2 Returning error (as expected)")
738+
// Error
739+
return InvokeResult{Err: forcedError}
740+
default:
741+
// BUG: The handler function should never be called again
742+
t.Error("Handler should not have been called after error returned")
743+
return InvokeResult{Err: wamp.ErrInvalidArgument}
744+
}
745+
}
746+
747+
const procName = "nexus.test.progprocerr"
748+
749+
// Register procedure
750+
err := callee.Register(procName, invocationHandler, nil)
751+
require.NoError(t, err)
752+
753+
// Test calling the procedure.
754+
callArgs := [...]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
755+
ctx := context.Background()
756+
757+
sendCount := 0
758+
sendProgDataCb := func(ctx context.Context) (options wamp.Dict, args wamp.List, kwargs wamp.Dict, err error) {
759+
options = wamp.Dict{
760+
wamp.OptProgress: sendCount < (len(callArgs) - 1),
761+
}
762+
763+
args = wamp.List{callArgs[sendCount]}
764+
sendCount++
765+
766+
// signal the handler should return its error
767+
if 4 == sendCount {
768+
close(moreArgsSent)
769+
}
770+
t.Logf("Sending n=%v", sendCount)
771+
772+
return options, args, nil, nil
773+
}
774+
775+
result, err := caller.CallProgressive(ctx, procName, sendProgDataCb, nil)
776+
require.Error(t, err, "Expected call to return an error")
777+
require.Nil(t, result, "Expected call to return no result")
778+
var rErr RPCError
779+
if errors.As(err, &rErr) {
780+
require.Equal(t, forcedError, rErr.Err.Error, "Unexpected error URI")
781+
} else {
782+
t.Error("Unexpected error type")
783+
}
784+
require.GreaterOrEqual(t, sendCount, 4)
785+
require.True(t, errorRaised, "Error was never raised in handler")
786+
787+
// #319: Show deadlock
788+
t.Log("Closing rooter")
789+
rooter.Close()
790+
t.Log("Closing caller")
791+
require.NoError(t, caller.Close())
792+
goleak.VerifyNone(t)
793+
t.Log("Closing callee")
794+
// #319: We never get past here
795+
require.NoError(t, callee.Close())
796+
797+
t.Log("All closed")
798+
goleak.VerifyNone(t)
799+
t.Log("Done")
800+
}
801+
707802
func TestProgressiveCallsAndResults(t *testing.T) {
708803
// Connect two clients to the same server
709804
callee, caller, _ := connectedTestClients(t)

wamp/idgen.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import (
66
"time"
77
)
88

9-
const maxID int64 = 1 << 53
9+
const MaxID uint64 = 1 << 53
1010

1111
func init() {
1212
rand.Seed(time.Now().UnixNano())
1313
}
1414

1515
// NewID generates a random WAMP ID.
1616
func GlobalID() ID {
17-
return ID(rand.Int63n(maxID)) //nolint:gosec
17+
return ID(rand.Int63n(int64(MaxID))) //nolint:gosec
1818
}
1919

2020
// IDGen is generator for WAMP request IDs. Create with new(IDGen).
@@ -29,13 +29,13 @@ func GlobalID() ID {
2929
//
3030
// See https://github.com/wamp-proto/wamp-proto/blob/master/spec/basic.md#ids
3131
type IDGen struct {
32-
next int64
32+
next uint64
3333
}
3434

3535
// Next returns next ID.
3636
func (g *IDGen) Next() ID {
3737
g.next++
38-
if g.next > maxID {
38+
if g.next > MaxID {
3939
g.next = 1
4040
}
4141
return ID(g.next)

wamp/idgen_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestIDGen(t *testing.T) {
2626
require.Equal(t, ID(2), id2, errMsg)
2727
require.Equal(t, ID(3), id3, errMsg)
2828

29-
idgen.next = int64(1) << 53
29+
idgen.next = uint64(1) << 53
3030
id1 = idgen.Next()
3131
require.Equal(t, ID(1), id1, "Sequential IDs should wrap at 1 << 53")
3232
}

0 commit comments

Comments
 (0)