Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions broker/api/sse_broker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package api

import (
"encoding/json"
"fmt"
"github.com/indexdata/crosslink/broker/common"
"github.com/indexdata/crosslink/broker/events"
pr_db "github.com/indexdata/crosslink/broker/patron_request/db"
prservice "github.com/indexdata/crosslink/broker/patron_request/service"
"github.com/indexdata/crosslink/iso18626"
"net/http"
"sync"
)

type SseBroker struct {
input chan SseMessage
clients map[string]map[chan string]bool
mu sync.Mutex
ctx common.ExtendedContext
tenant common.Tenant
}

func NewSseBroker(ctx common.ExtendedContext, tenant common.Tenant) (broker *SseBroker) {
broker = &SseBroker{
input: make(chan SseMessage),
clients: make(map[string]map[chan string]bool),
ctx: ctx,
tenant: tenant,
}

// Start the single broadcaster goroutine
go broker.run()
return broker
}
func (b *SseBroker) run() {
b.ctx.Logger().Debug("SseBroker running...")
for {
// Wait for an event from the application logic
event := <-b.input

b.mu.Lock()
for clientChannel := range b.clients[event.receiver] {
select {
case clientChannel <- event.message:
// Successfully sent
default:
// Client is slow or disconnected, remove them to prevent memory leak
b.removeClient(event.receiver, clientChannel)
}
}
b.mu.Unlock()
}
}

func (b *SseBroker) removeClient(receiver string, clientChannel chan string) {
clients := b.clients[receiver]
if clients != nil {
delete(clients, clientChannel)
if len(clients) == 0 {
delete(b.clients, receiver)
}
}
close(clientChannel)
b.ctx.Logger().Debug("Client channel closed and removed.")
}

// ServeHTTP implements the http.Handler interface for the SSE endpoint.
func (b *SseBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
clientChannel := make(chan string, 10)
tenant := r.Header.Get("X-Okapi-Tenant")
var symbol string
if b.tenant.IsSpecified() && tenant != "" {
symbol = b.tenant.GetSymbol(tenant)
} else {
symbol = r.URL.Query().Get("symbol")
}

side := r.URL.Query().Get("side")
if side == "" || symbol == "" {
http.Error(w, "query parameter 'side' and 'symbol' must be specified", http.StatusBadRequest)
return
}
if side != string(prservice.SideBorrowing) && side != string(prservice.SideLending) {
http.Error(w, fmt.Sprintf("query parameter 'side' must be %s or %s", prservice.SideBorrowing, prservice.SideLending), http.StatusBadRequest)
return
}
b.mu.Lock()
receiver := side + symbol
clients := b.clients[receiver]
if clients != nil {
clients[clientChannel] = true
} else {
b.clients[receiver] = map[chan string]bool{clientChannel: true}
}
b.mu.Unlock()
b.ctx.Logger().Debug(fmt.Sprintf("new client registered: %s", receiver))

defer func() {
b.mu.Lock()
defer b.mu.Unlock()
b.removeClient(receiver, clientChannel)
}()

// Set SSE Headers and get Flusher
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

// Context for connection status check
ctx := r.Context()
for {
select {
case <-ctx.Done():
// Client connection closed
return

case event := <-clientChannel:
if _, err := fmt.Fprintf(w, "data: %s\n\n", event); err != nil {
return
}
flusher.Flush()
}
}
}

func (b *SseBroker) SubmitMessageToChannels(message SseMessage) {
b.input <- message
}

type SseMessage struct {
receiver string
message string
}

type SseIsoMessageEvent struct {
Event events.EventName `json:"event,omitempty"`
Data iso18626.ISO18626Message `json:"data,omitempty"`
}

func (b *SseBroker) IncomingIsoMessage(ctx common.ExtendedContext, event events.Event) {
if event.ResultData.OutgoingMessage != nil {
sseEvent := SseIsoMessageEvent{
Data: *event.ResultData.OutgoingMessage,
Event: event.EventName,
}
symbol := ""
var side pr_db.PatronRequestSide
if event.ResultData.OutgoingMessage.RequestingAgencyMessage != nil {
side = prservice.SideLending
symbol = getSymbol(event.ResultData.OutgoingMessage.RequestingAgencyMessage.Header.SupplyingAgencyId)
} else if event.ResultData.OutgoingMessage.SupplyingAgencyMessage != nil {
side = prservice.SideBorrowing
symbol = getSymbol(event.ResultData.OutgoingMessage.SupplyingAgencyMessage.Header.RequestingAgencyId)
} else {
return
}
updateMessageBytes, err := json.Marshal(sseEvent)
if err != nil {
ctx.Logger().Error("failed to parse event data", "error", err)
return
}
b.SubmitMessageToChannels(SseMessage{receiver: string(side) + symbol, message: string(updateMessageBytes)})
}
}

func getSymbol(agencyId iso18626.TypeAgencyId) string {
return agencyId.AgencyIdType.Text + ":" + agencyId.AgencyIdValue
}
13 changes: 11 additions & 2 deletions broker/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type Context struct {
DirAdapter adapter.DirectoryLookupAdapter
PrRepo pr_db.PrRepo
PrApiHandler prapi.PatronRequestApiHandler
SseBroker *api.SseBroker
}

func configLog() slog.Handler {
Expand Down Expand Up @@ -163,7 +164,9 @@ func Init(ctx context.Context) (Context, error) {
prActionService := prservice.CreatePatronRequestActionService(prRepo, eventBus, &iso18626Handler, lmsCreator)
prApiHandler := prapi.NewApiHandler(prRepo, eventBus, common.NewTenant(TENANT_TO_SYMBOL))

AddDefaultHandlers(eventBus, iso18626Client, supplierLocator, workflowManager, iso18626Handler, prActionService, prApiHandler, prMessageHandler)
sseBroker := api.NewSseBroker(appCtx, common.NewTenant(TENANT_TO_SYMBOL))

AddDefaultHandlers(eventBus, iso18626Client, supplierLocator, workflowManager, iso18626Handler, prActionService, prApiHandler, sseBroker)
err = StartEventBus(ctx, eventBus)
if err != nil {
return Context{}, err
Expand All @@ -175,6 +178,7 @@ func Init(ctx context.Context) (Context, error) {
DirAdapter: dirAdapter,
PrRepo: prRepo,
PrApiHandler: prApiHandler,
SseBroker: sseBroker,
}, nil
}

Expand Down Expand Up @@ -209,6 +213,9 @@ func StartServer(ctx Context) error {
proapi.HandlerFromMux(&ctx.PrApiHandler, ServeMux)
// TODO: proapi.HandlerFromMuxWithBaseURL(&ctx.PrApiHandler, ServeMux, "/broker")

// SSE Incoming message handler
ServeMux.HandleFunc("/sse/events", ctx.SseBroker.ServeHTTP)

signatureHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", vcs.GetSignature())
ServeMux.ServeHTTP(w, r)
Expand Down Expand Up @@ -278,7 +285,7 @@ func CreateEventBus(eventRepo events.EventRepo) events.EventBus {

func AddDefaultHandlers(eventBus events.EventBus, iso18626Client client.Iso18626Client,
supplierLocator service.SupplierLocator, workflowManager service.WorkflowManager, iso18626Handler handler.Iso18626Handler,
prActionService prservice.PatronRequestActionService, prApiHandler prapi.PatronRequestApiHandler, prMessageHandler prservice.PatronRequestMessageHandler) {
prActionService prservice.PatronRequestActionService, prApiHandler prapi.PatronRequestApiHandler, sseBroker *api.SseBroker) {
eventBus.HandleEventCreated(events.EventNameMessageSupplier, iso18626Client.MessageSupplier)
eventBus.HandleEventCreated(events.EventNameMessageRequester, iso18626Client.MessageRequester)
eventBus.HandleEventCreated(events.EventNameConfirmRequesterMsg, iso18626Handler.ConfirmRequesterMsg)
Expand All @@ -294,6 +301,8 @@ func AddDefaultHandlers(eventBus events.EventBus, iso18626Client client.Iso18626
eventBus.HandleTaskCompleted(events.EventNameSelectSupplier, workflowManager.OnSelectSupplierComplete)
eventBus.HandleTaskCompleted(events.EventNameMessageSupplier, workflowManager.OnMessageSupplierComplete)
eventBus.HandleTaskCompleted(events.EventNameMessageRequester, workflowManager.OnMessageRequesterComplete)
eventBus.HandleTaskCompleted(events.EventNameMessageSupplier, sseBroker.IncomingIsoMessage)
eventBus.HandleTaskCompleted(events.EventNameMessageRequester, sseBroker.IncomingIsoMessage)

eventBus.HandleEventCreated(events.EventNameInvokeAction, prActionService.InvokeAction)
eventBus.HandleTaskCompleted(events.EventNameInvokeAction, prApiHandler.ConfirmActionProcess)
Expand Down
7 changes: 5 additions & 2 deletions broker/test/api/api-handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ import (
"github.com/testcontainers/testcontainers-go/wait"
)

var eventBus events.EventBus
var illRepo ill_db.IllRepo
var eventRepo events.EventRepo
var sseBroker *api.SseBroker
var mockIllRepoError = new(mocks.MockIllRepositoryError)
var mockEventRepoError = new(mocks.MockEventRepositoryError)
var handlerMock = api.NewApiHandler(mockEventRepoError, mockIllRepoError, common.NewTenant(""), api.LIMIT_DEFAULT)
Expand All @@ -67,7 +67,10 @@ func TestMain(m *testing.M) {
app.HTTP_PORT = utils.Must(test.GetFreePort())

ctx, cancel := context.WithCancel(context.Background())
eventBus, illRepo, eventRepo, _ = apptest.StartApp(ctx)
appContext := apptest.StartAppReturnContext(ctx)
illRepo = appContext.IllRepo
eventRepo = appContext.EventRepo
sseBroker = appContext.SseBroker
test.WaitForServiceUp(app.HTTP_PORT)

defer cancel()
Expand Down
137 changes: 137 additions & 0 deletions broker/test/api/sse_broker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package api

import (
"bufio"
"context"
"errors"
"fmt"
"github.com/indexdata/crosslink/broker/common"
"github.com/indexdata/crosslink/broker/events"
"github.com/indexdata/crosslink/iso18626"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"strings"
"testing"
"time"
)

func TestSseEndpointSuccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
go sendMessages(ctx) //Send messages every 5 milliseconds
done := make(chan bool)
inErr := make(chan error)
go func() {
resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing&symbol=ISIL:REQ")
if err != nil {
inErr <- err
return
}
defer resp.Body.Close()

// Verify headers
if contentType := resp.Header.Get("Content-Type"); contentType != "text/event-stream" {
inErr <- errors.New("Expected text/event-stream, got " + contentType)
}

results := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
results <- strings.TrimPrefix(line, "data: ")
return // Exit after receiving the first event for this test
}
}
if err := scanner.Err(); err != nil {
errChan <- err
}
}()

select {
case data := <-results:
if data == "" {
t.Error("Received empty data from SSE")
}
t.Logf("Successfully received: %s", data)
assert.True(t, strings.Contains(data, "{\"event\":\"message-requester\",\"data\":{\"supplyingAgencyMessage\":"))
case err := <-errChan:
inErr <- err
}
cancel()
done <- true
}()

select {
case err := <-inErr:
assert.NoError(t, err)
default:
// No errors
}

select {
case <-done:
// Test finished successfully
case <-time.After(1 * time.Second):
t.Fatal("Test timed out")
}
}

func sendMessages(ctx context.Context) {
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
fmt.Println("Shutting down sendMessages...")
return
case t := <-ticker.C:
executeTask(t)
}
}
}

func TestSseEndpointNoSide(t *testing.T) {
resp, err := http.Get(getLocalhostWithPort() + "/sse/events?symbol=ISIL:REQ")
assert.NoError(t, err)
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes))
}

func TestSseEndpointNoSymbol(t *testing.T) {
resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing")
assert.NoError(t, err)
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes))
}

func executeTask(t time.Time) {
ctx := common.CreateExtCtxWithArgs(context.Background(), nil)
sseBroker.IncomingIsoMessage(ctx, events.Event{EventName: events.EventNameMessageRequester,
ResultData: events.EventResult{
CommonEventData: events.CommonEventData{
OutgoingMessage: &iso18626.ISO18626Message{
SupplyingAgencyMessage: &iso18626.SupplyingAgencyMessage{
Header: iso18626.Header{
RequestingAgencyId: iso18626.TypeAgencyId{
AgencyIdType: iso18626.TypeSchemeValuePair{
Text: "ISIL",
},
AgencyIdValue: "REQ",
},
},
MessageInfo: iso18626.MessageInfo{
Note: t.String(),
},
},
},
},
}})
}
Loading
Loading