diff --git a/broker/api/sse_broker.go b/broker/api/sse_broker.go new file mode 100644 index 00000000..890a7f72 --- /dev/null +++ b/broker/api/sse_broker.go @@ -0,0 +1,175 @@ +package api + +import ( + "encoding/json" + "fmt" + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/events" + pr_db "github.com/indexdata/crosslink/broker/patron_request/db" + prservice "github.com/indexdata/crosslink/broker/patron_request/service" + "github.com/indexdata/crosslink/iso18626" + "net/http" + "sync" +) + +type SseBroker struct { + input chan SseMessage + clients map[string]map[chan string]bool + mu sync.Mutex + ctx common.ExtendedContext + tenant common.Tenant +} + +func NewSseBroker(ctx common.ExtendedContext, tenant common.Tenant) (broker *SseBroker) { + broker = &SseBroker{ + input: make(chan SseMessage), + clients: make(map[string]map[chan string]bool), + ctx: ctx, + tenant: tenant, + } + + // Start the single broadcaster goroutine + go broker.run() + return broker +} +func (b *SseBroker) run() { + b.ctx.Logger().Debug("SseBroker running...") + for { + // Wait for an event from the application logic + event := <-b.input + + b.mu.Lock() + for clientChannel := range b.clients[event.receiver] { + select { + case clientChannel <- event.message: + // Successfully sent + default: + // Client is slow or disconnected, remove them to prevent memory leak + b.removeClient(event.receiver, clientChannel) + } + } + b.mu.Unlock() + } +} + +func (b *SseBroker) removeClient(receiver string, clientChannel chan string) { + clients := b.clients[receiver] + if clients != nil { + delete(clients, clientChannel) + if len(clients) == 0 { + delete(b.clients, receiver) + } + } + close(clientChannel) + b.ctx.Logger().Debug("Client channel closed and removed.") +} + +// ServeHTTP implements the http.Handler interface for the SSE endpoint. +func (b *SseBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) { + clientChannel := make(chan string, 10) + tenant := r.Header.Get("X-Okapi-Tenant") + var symbol string + if b.tenant.IsSpecified() && tenant != "" { + symbol = b.tenant.GetSymbol(tenant) + } else { + symbol = r.URL.Query().Get("symbol") + } + + side := r.URL.Query().Get("side") + if side == "" || symbol == "" { + http.Error(w, "query parameter 'side' and 'symbol' must be specified", http.StatusBadRequest) + return + } + if side != string(prservice.SideBorrowing) && side != string(prservice.SideLending) { + http.Error(w, fmt.Sprintf("query parameter 'side' must be %s or %s", prservice.SideBorrowing, prservice.SideLending), http.StatusBadRequest) + return + } + b.mu.Lock() + receiver := side + symbol + clients := b.clients[receiver] + if clients != nil { + clients[clientChannel] = true + } else { + b.clients[receiver] = map[chan string]bool{clientChannel: true} + } + b.mu.Unlock() + b.ctx.Logger().Debug(fmt.Sprintf("new client registered: %s", receiver)) + + defer func() { + b.mu.Lock() + defer b.mu.Unlock() + b.removeClient(receiver, clientChannel) + }() + + // Set SSE Headers and get Flusher + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + // Context for connection status check + ctx := r.Context() + for { + select { + case <-ctx.Done(): + // Client connection closed + return + + case event := <-clientChannel: + if _, err := fmt.Fprintf(w, "data: %s\n\n", event); err != nil { + return + } + flusher.Flush() + } + } +} + +func (b *SseBroker) SubmitMessageToChannels(message SseMessage) { + b.input <- message +} + +type SseMessage struct { + receiver string + message string +} + +type SseIsoMessageEvent struct { + Event events.EventName `json:"event,omitempty"` + Data iso18626.ISO18626Message `json:"data,omitempty"` +} + +func (b *SseBroker) IncomingIsoMessage(ctx common.ExtendedContext, event events.Event) { + if event.ResultData.OutgoingMessage != nil { + sseEvent := SseIsoMessageEvent{ + Data: *event.ResultData.OutgoingMessage, + Event: event.EventName, + } + symbol := "" + var side pr_db.PatronRequestSide + if event.ResultData.OutgoingMessage.RequestingAgencyMessage != nil { + side = prservice.SideLending + symbol = getSymbol(event.ResultData.OutgoingMessage.RequestingAgencyMessage.Header.SupplyingAgencyId) + } else if event.ResultData.OutgoingMessage.SupplyingAgencyMessage != nil { + side = prservice.SideBorrowing + symbol = getSymbol(event.ResultData.OutgoingMessage.SupplyingAgencyMessage.Header.RequestingAgencyId) + } else { + return + } + updateMessageBytes, err := json.Marshal(sseEvent) + if err != nil { + ctx.Logger().Error("failed to parse event data", "error", err) + return + } + b.SubmitMessageToChannels(SseMessage{receiver: string(side) + symbol, message: string(updateMessageBytes)}) + } +} + +func getSymbol(agencyId iso18626.TypeAgencyId) string { + return agencyId.AgencyIdType.Text + ":" + agencyId.AgencyIdValue +} diff --git a/broker/app/app.go b/broker/app/app.go index 0510a59f..a0f40101 100644 --- a/broker/app/app.go +++ b/broker/app/app.go @@ -85,6 +85,7 @@ type Context struct { DirAdapter adapter.DirectoryLookupAdapter PrRepo pr_db.PrRepo PrApiHandler prapi.PatronRequestApiHandler + SseBroker *api.SseBroker } func configLog() slog.Handler { @@ -163,7 +164,9 @@ func Init(ctx context.Context) (Context, error) { prActionService := prservice.CreatePatronRequestActionService(prRepo, eventBus, &iso18626Handler, lmsCreator) prApiHandler := prapi.NewApiHandler(prRepo, eventBus, common.NewTenant(TENANT_TO_SYMBOL)) - AddDefaultHandlers(eventBus, iso18626Client, supplierLocator, workflowManager, iso18626Handler, prActionService, prApiHandler, prMessageHandler) + sseBroker := api.NewSseBroker(appCtx, common.NewTenant(TENANT_TO_SYMBOL)) + + AddDefaultHandlers(eventBus, iso18626Client, supplierLocator, workflowManager, iso18626Handler, prActionService, prApiHandler, sseBroker) err = StartEventBus(ctx, eventBus) if err != nil { return Context{}, err @@ -175,6 +178,7 @@ func Init(ctx context.Context) (Context, error) { DirAdapter: dirAdapter, PrRepo: prRepo, PrApiHandler: prApiHandler, + SseBroker: sseBroker, }, nil } @@ -209,6 +213,9 @@ func StartServer(ctx Context) error { proapi.HandlerFromMux(&ctx.PrApiHandler, ServeMux) // TODO: proapi.HandlerFromMuxWithBaseURL(&ctx.PrApiHandler, ServeMux, "/broker") + // SSE Incoming message handler + ServeMux.HandleFunc("/sse/events", ctx.SseBroker.ServeHTTP) + signatureHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Server", vcs.GetSignature()) ServeMux.ServeHTTP(w, r) @@ -278,7 +285,7 @@ func CreateEventBus(eventRepo events.EventRepo) events.EventBus { func AddDefaultHandlers(eventBus events.EventBus, iso18626Client client.Iso18626Client, supplierLocator service.SupplierLocator, workflowManager service.WorkflowManager, iso18626Handler handler.Iso18626Handler, - prActionService prservice.PatronRequestActionService, prApiHandler prapi.PatronRequestApiHandler, prMessageHandler prservice.PatronRequestMessageHandler) { + prActionService prservice.PatronRequestActionService, prApiHandler prapi.PatronRequestApiHandler, sseBroker *api.SseBroker) { eventBus.HandleEventCreated(events.EventNameMessageSupplier, iso18626Client.MessageSupplier) eventBus.HandleEventCreated(events.EventNameMessageRequester, iso18626Client.MessageRequester) eventBus.HandleEventCreated(events.EventNameConfirmRequesterMsg, iso18626Handler.ConfirmRequesterMsg) @@ -294,6 +301,8 @@ func AddDefaultHandlers(eventBus events.EventBus, iso18626Client client.Iso18626 eventBus.HandleTaskCompleted(events.EventNameSelectSupplier, workflowManager.OnSelectSupplierComplete) eventBus.HandleTaskCompleted(events.EventNameMessageSupplier, workflowManager.OnMessageSupplierComplete) eventBus.HandleTaskCompleted(events.EventNameMessageRequester, workflowManager.OnMessageRequesterComplete) + eventBus.HandleTaskCompleted(events.EventNameMessageSupplier, sseBroker.IncomingIsoMessage) + eventBus.HandleTaskCompleted(events.EventNameMessageRequester, sseBroker.IncomingIsoMessage) eventBus.HandleEventCreated(events.EventNameInvokeAction, prActionService.InvokeAction) eventBus.HandleTaskCompleted(events.EventNameInvokeAction, prApiHandler.ConfirmActionProcess) diff --git a/broker/test/api/api-handler_test.go b/broker/test/api/api-handler_test.go index c042c8f6..96da124c 100644 --- a/broker/test/api/api-handler_test.go +++ b/broker/test/api/api-handler_test.go @@ -38,9 +38,9 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) -var eventBus events.EventBus var illRepo ill_db.IllRepo var eventRepo events.EventRepo +var sseBroker *api.SseBroker var mockIllRepoError = new(mocks.MockIllRepositoryError) var mockEventRepoError = new(mocks.MockEventRepositoryError) var handlerMock = api.NewApiHandler(mockEventRepoError, mockIllRepoError, common.NewTenant(""), api.LIMIT_DEFAULT) @@ -67,7 +67,10 @@ func TestMain(m *testing.M) { app.HTTP_PORT = utils.Must(test.GetFreePort()) ctx, cancel := context.WithCancel(context.Background()) - eventBus, illRepo, eventRepo, _ = apptest.StartApp(ctx) + appContext := apptest.StartAppReturnContext(ctx) + illRepo = appContext.IllRepo + eventRepo = appContext.EventRepo + sseBroker = appContext.SseBroker test.WaitForServiceUp(app.HTTP_PORT) defer cancel() diff --git a/broker/test/api/sse_broker_test.go b/broker/test/api/sse_broker_test.go new file mode 100644 index 00000000..6f8d6a73 --- /dev/null +++ b/broker/test/api/sse_broker_test.go @@ -0,0 +1,137 @@ +package api + +import ( + "bufio" + "context" + "errors" + "fmt" + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/events" + "github.com/indexdata/crosslink/iso18626" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "strings" + "testing" + "time" +) + +func TestSseEndpointSuccess(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go sendMessages(ctx) //Send messages every 5 milliseconds + done := make(chan bool) + inErr := make(chan error) + go func() { + resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing&symbol=ISIL:REQ") + if err != nil { + inErr <- err + return + } + defer resp.Body.Close() + + // Verify headers + if contentType := resp.Header.Get("Content-Type"); contentType != "text/event-stream" { + inErr <- errors.New("Expected text/event-stream, got " + contentType) + } + + results := make(chan string, 1) + errChan := make(chan error, 1) + go func() { + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + results <- strings.TrimPrefix(line, "data: ") + return // Exit after receiving the first event for this test + } + } + if err := scanner.Err(); err != nil { + errChan <- err + } + }() + + select { + case data := <-results: + if data == "" { + t.Error("Received empty data from SSE") + } + t.Logf("Successfully received: %s", data) + assert.True(t, strings.Contains(data, "{\"event\":\"message-requester\",\"data\":{\"supplyingAgencyMessage\":")) + case err := <-errChan: + inErr <- err + } + cancel() + done <- true + }() + + select { + case err := <-inErr: + assert.NoError(t, err) + default: + // No errors + } + + select { + case <-done: + // Test finished successfully + case <-time.After(1 * time.Second): + t.Fatal("Test timed out") + } +} + +func sendMessages(ctx context.Context) { + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + fmt.Println("Shutting down sendMessages...") + return + case t := <-ticker.C: + executeTask(t) + } + } +} + +func TestSseEndpointNoSide(t *testing.T) { + resp, err := http.Get(getLocalhostWithPort() + "/sse/events?symbol=ISIL:REQ") + assert.NoError(t, err) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes)) +} + +func TestSseEndpointNoSymbol(t *testing.T) { + resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing") + assert.NoError(t, err) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes)) +} + +func executeTask(t time.Time) { + ctx := common.CreateExtCtxWithArgs(context.Background(), nil) + sseBroker.IncomingIsoMessage(ctx, events.Event{EventName: events.EventNameMessageRequester, + ResultData: events.EventResult{ + CommonEventData: events.CommonEventData{ + OutgoingMessage: &iso18626.ISO18626Message{ + SupplyingAgencyMessage: &iso18626.SupplyingAgencyMessage{ + Header: iso18626.Header{ + RequestingAgencyId: iso18626.TypeAgencyId{ + AgencyIdType: iso18626.TypeSchemeValuePair{ + Text: "ISIL", + }, + AgencyIdValue: "REQ", + }, + }, + MessageInfo: iso18626.MessageInfo{ + Note: t.String(), + }, + }, + }, + }, + }}) +} diff --git a/broker/test/apputils/apputils.go b/broker/test/apputils/apputils.go index 29348f3a..7aaa489d 100644 --- a/broker/test/apputils/apputils.go +++ b/broker/test/apputils/apputils.go @@ -24,13 +24,18 @@ import ( const EventRecordFormat = "%v, %v = %v" func StartApp(ctx context.Context) (events.EventBus, ill_db.IllRepo, events.EventRepo, pr_db.PrRepo) { - context, err := app.Init(ctx) + appContext := StartAppReturnContext(ctx) + return appContext.EventBus, appContext.IllRepo, appContext.EventRepo, appContext.PrRepo +} + +func StartAppReturnContext(ctx context.Context) app.Context { + appContext, err := app.Init(ctx) utils.Expect(err, "failed to init app") go func() { - err := app.StartServer(context) + err := app.StartServer(appContext) utils.Expect(err, "failed to start server") }() - return context.EventBus, context.IllRepo, context.EventRepo, context.PrRepo + return appContext } func CreatePgText(value string) pgtype.Text {