diff --git a/event_bus.go b/event_bus.go index d4cf80f..64bcf3f 100644 --- a/event_bus.go +++ b/event_bus.go @@ -1,6 +1,7 @@ package EventBus import ( + "context" "fmt" "reflect" "sync" @@ -10,6 +11,7 @@ import ( type BusSubscriber interface { Subscribe(topic string, fn interface{}) error SubscribeAsync(topic string, fn interface{}, transactional bool) error + SubscribeAsyncWithNewContext(topic string, fn interface{}, transactional bool) error SubscribeOnce(topic string, fn interface{}) error SubscribeOnceAsync(topic string, fn interface{}) error Unsubscribe(topic string, handler interface{}) error @@ -46,6 +48,7 @@ type eventHandler struct { async bool transactional bool sync.Mutex // lock for an event handler - useful for running async callbacks serially + withNewCtx bool } // New returns new EventBus with empty handlers. @@ -73,7 +76,7 @@ func (bus *EventBus) doSubscribe(topic string, fn interface{}, handler *eventHan // Returns error if `fn` is not a function. func (bus *EventBus) Subscribe(topic string, fn interface{}) error { return bus.doSubscribe(topic, fn, &eventHandler{ - reflect.ValueOf(fn), false, false, false, sync.Mutex{}, + reflect.ValueOf(fn), false, false, false, sync.Mutex{}, false, }) } @@ -83,7 +86,13 @@ func (bus *EventBus) Subscribe(topic string, fn interface{}) error { // Returns error if `fn` is not a function. func (bus *EventBus) SubscribeAsync(topic string, fn interface{}, transactional bool) error { return bus.doSubscribe(topic, fn, &eventHandler{ - reflect.ValueOf(fn), false, true, transactional, sync.Mutex{}, + reflect.ValueOf(fn), false, true, transactional, sync.Mutex{}, false, + }) +} + +func (bus *EventBus) SubscribeAsyncWithNewContext(topic string, fn interface{}, transactional bool) error { + return bus.doSubscribe(topic, fn, &eventHandler{ + reflect.ValueOf(fn), false, true, transactional, sync.Mutex{}, true, }) } @@ -91,7 +100,7 @@ func (bus *EventBus) SubscribeAsync(topic string, fn interface{}, transactional // Returns error if `fn` is not a function. func (bus *EventBus) SubscribeOnce(topic string, fn interface{}) error { return bus.doSubscribe(topic, fn, &eventHandler{ - reflect.ValueOf(fn), true, false, false, sync.Mutex{}, + reflect.ValueOf(fn), true, false, false, sync.Mutex{}, false, }) } @@ -100,7 +109,7 @@ func (bus *EventBus) SubscribeOnce(topic string, fn interface{}) error { // Returns error if `fn` is not a function. func (bus *EventBus) SubscribeOnceAsync(topic string, fn interface{}) error { return bus.doSubscribe(topic, fn, &eventHandler{ - reflect.ValueOf(fn), true, true, false, sync.Mutex{}, + reflect.ValueOf(fn), true, true, false, sync.Mutex{}, false, }) } @@ -147,7 +156,17 @@ func (bus *EventBus) Publish(topic string, args ...interface{}) { if handler.transactional { handler.Lock() } - go bus.doPublishAsync(handler, topic, args...) + var asyncArgs []interface{} + if handler.withNewCtx { + if len(args) == 0 { + panic("context expected, got no args") + } + asyncArgs = append(asyncArgs, context.Background()) + asyncArgs = append(asyncArgs, args[1:]...) + } else { + asyncArgs = args + } + go bus.doPublishAsync(handler, topic, asyncArgs...) } } } diff --git a/event_bus_test.go b/event_bus_test.go index 0cf196d..ede640b 100644 --- a/event_bus_test.go +++ b/event_bus_test.go @@ -1,6 +1,7 @@ package EventBus import ( + "context" "testing" "time" ) @@ -154,3 +155,85 @@ func TestSubscribeAsync(t *testing.T) { t.Fail() } } + +func TestSubscribeAsyncContextCancelled(t *testing.T) { + var isCancelledDone bool + isNotCancelledDone := true + + bus := New() + bus.SubscribeAsync("topic", func(ctx context.Context, done *bool) { + select { + case <-time.NewTimer(10*time.Millisecond).C: + *done = false + case <-ctx.Done(): + *done = true + } + }, false) + + notCancelled, notCancel := context.WithCancel(context.Background()) + cancelled, cancel := context.WithCancel(context.Background()) + bus.Publish("topic", notCancelled, &isNotCancelledDone) + bus.Publish("topic", cancelled, &isCancelledDone) + defer notCancel() + cancel() + + bus.WaitAsync() + + if isNotCancelledDone { + t.Fail() + } + if !isCancelledDone { + t.Fail() + } +} + +func TestSubscribeAsyncWithNewContextNoArgs(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if r != "context expected, got no args" { + t.Fail() + } + } + }() + bus := New() + bus.SubscribeAsyncWithNewContext("topic", func(ctx context.Context) {}, false) + bus.Publish("topic") +} + +func TestSubscribeAsyncWithNewContext(t *testing.T) { + isCancelledDone := true + isNotCancelledDone := true + + bus := New() + bus.SubscribeAsyncWithNewContext("topic", func(ctx context.Context, done *bool) { + select { + case <-time.NewTimer(10*time.Millisecond).C: + *done = false + case <-ctx.Done(): + *done = true + } + }, false) + + notCancelled, notCancel := context.WithCancel(context.Background()) + cancelled, cancel := context.WithCancel(context.Background()) + bus.Publish("topic", notCancelled, &isNotCancelledDone) + bus.Publish("topic", cancelled, &isCancelledDone) + defer notCancel() + cancel() + + bus.WaitAsync() + + if isCancelledDone { + t.Fail() + } + if isNotCancelledDone { + t.Fail() + } +} + +//func TestSubscribeAsyncNewContext(t *testing.T) { +// bus := New() +// bus.SubscribeAsyncWithNewContext("topic", func(context.Context, i int, k string) { +// +// }, false) +//} \ No newline at end of file