From b9c71e60d9619f91df33f21da8dfd9d7f1876694 Mon Sep 17 00:00:00 2001 From: Robin Diddams Date: Wed, 23 Apr 2025 12:06:32 -0500 Subject: [PATCH] Fix dev mode for new sdk --- internal/dev/websocket.go | 90 ++++++++++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 24 deletions(-) diff --git a/internal/dev/websocket.go b/internal/dev/websocket.go index 34d75abc..cec78c50 100644 --- a/internal/dev/websocket.go +++ b/internal/dev/websocket.go @@ -29,23 +29,24 @@ import ( var propagator propagation.TraceContext type Websocket struct { - webSocketId string - conn *websocket.Conn - OtelToken string - OtelUrl string - Project project.ProjectContext - orgId string - done chan struct{} - apiKey string - websocketUrl string - maxRetries int - retryCount int - parentCtx context.Context - ctx context.Context - logger logger.Logger - cleanup func() - tracer trace.Tracer - version string + webSocketId string + conn *websocket.Conn + OtelToken string + OtelUrl string + Project project.ProjectContext + orgId string + done chan struct{} + apiKey string + websocketUrl string + maxRetries int + retryCount int + parentCtx context.Context + ctx context.Context + logger logger.Logger + cleanup func() + tracer trace.Tracer + version string + binaryProtocol bool } type OutputPayload struct { @@ -78,6 +79,29 @@ func (c *Websocket) Done() <-chan struct{} { return c.done } +func (c *Websocket) getAgentProtocol(ctx context.Context, port int) (bool, error) { + url := fmt.Sprintf("http://localhost:%d/_health", port) + for i := 0; i < 5; i++ { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return false, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + if strings.Contains(err.Error(), "connection refused") { + time.Sleep(time.Millisecond * time.Duration(100*i+1)) + continue + } + return false, err + } + defer resp.Body.Close() + if resp.StatusCode == 200 { + return resp.Header.Get("x-agentuity-binary") == "true", nil + } + } + return false, fmt.Errorf("failed to inspect agents after 5 attempts") +} + func (c *Websocket) getAgentWelcome(ctx context.Context, port int) (map[string]Welcome, error) { url := fmt.Sprintf("http://localhost:%d/welcome", port) for i := 0; i < 5; i++ { @@ -107,6 +131,13 @@ func (c *Websocket) getAgentWelcome(ctx context.Context, port int) (map[string]W } func (c *Websocket) StartReadingMessages(ctx context.Context, logger logger.Logger, port int) { + var err error + c.binaryProtocol, err = c.getAgentProtocol(ctx, port) + if err != nil { + logger.Error("failed to healthcheck agents: %s", err) + return + } + go func() { defer close(c.done) for { @@ -554,16 +585,27 @@ func processInputMessage(plogger logger.Logger, c *Websocket, m []byte, port int logger.Debug("response: %s (status code: %d)", string(body), resp.StatusCode) - output, lerr := isOutputPayload(body) - if lerr != nil { - err = fmt.Errorf("the Agent produced an error") - return + var trigger string + var contentType string + if c.binaryProtocol { + trigger = resp.Header.Get("x-agentuity-trigger") + contentType = resp.Header.Get("content-type") + } else { + // TODO: remove this once were all off the old protocol + output, lerr := isOutputPayload(body) + if lerr != nil { + err = fmt.Errorf("the Agent produced an error") + return + } + trigger = output.Trigger + contentType = output.ContentType + body = output.Payload } msg := NewOutputMessage(inputMsg.ID, c.Project.Project.ProjectId, OutputPayload{ - ContentType: output.ContentType, - Payload: output.Payload, - Trigger: output.Trigger, + ContentType: contentType, + Payload: body, + Trigger: trigger, }) outputMessage = &msg }