diff --git a/.github/workflows/components-build-deploy.yml b/.github/workflows/components-build-deploy.yml index 60a310c0b..4e16edebf 100755 --- a/.github/workflows/components-build-deploy.yml +++ b/.github/workflows/components-build-deploy.yml @@ -12,6 +12,8 @@ on: - 'components/frontend/**' - 'components/public-api/**' - 'components/ambient-api-server/**' + - 'components/ambient-control-plane/**' + - 'components/ambient-mcp/**' pull_request: branches: [main, alpha] paths: @@ -23,10 +25,12 @@ on: - 'components/frontend/**' - 'components/public-api/**' - 'components/ambient-api-server/**' + - 'components/ambient-control-plane/**' + - 'components/ambient-mcp/**' workflow_dispatch: inputs: components: - description: 'Components to build (comma-separated: frontend,backend,operator,ambient-runner,state-sync,public-api,ambient-api-server) - leave empty for all' + description: 'Components to build (comma-separated: frontend,backend,operator,ambient-runner,state-sync,public-api,ambient-api-server,ambient-control-plane,ambient-mcp) - leave empty for all' required: false type: string default: '' @@ -54,7 +58,9 @@ jobs: {"name":"ambient-runner","context":"./components/runners/ambient-runner","image":"quay.io/ambient_code/vteam_claude_runner","dockerfile":"./components/runners/ambient-runner/Dockerfile"}, {"name":"state-sync","context":"./components/runners/state-sync","image":"quay.io/ambient_code/vteam_state_sync","dockerfile":"./components/runners/state-sync/Dockerfile"}, {"name":"public-api","context":"./components/public-api","image":"quay.io/ambient_code/vteam_public_api","dockerfile":"./components/public-api/Dockerfile"}, - {"name":"ambient-api-server","context":"./components/ambient-api-server","image":"quay.io/ambient_code/vteam_api_server","dockerfile":"./components/ambient-api-server/Dockerfile"} + {"name":"ambient-api-server","context":"./components/ambient-api-server","image":"quay.io/ambient_code/vteam_api_server","dockerfile":"./components/ambient-api-server/Dockerfile"}, + {"name":"ambient-control-plane","context":"./components/ambient-control-plane","image":"quay.io/ambient_code/vteam_control_plane","dockerfile":"./components/ambient-control-plane/Dockerfile"}, + {"name":"ambient-mcp","context":"./components/ambient-mcp","image":"quay.io/ambient_code/vteam_mcp","dockerfile":"./components/ambient-mcp/Dockerfile"} ]' SELECTED="${{ github.event.inputs.components }}" @@ -376,6 +382,8 @@ jobs: kustomize edit set image quay.io/ambient_code/vteam_state_sync:latest=quay.io/ambient_code/vteam_state_sync:${{ github.sha }} kustomize edit set image quay.io/ambient_code/vteam_api_server:latest=quay.io/ambient_code/vteam_api_server:${{ github.sha }} kustomize edit set image quay.io/ambient_code/vteam_public_api:latest=quay.io/ambient_code/vteam_public_api:${{ github.sha }} + kustomize edit set image quay.io/ambient_code/vteam_control_plane:latest=quay.io/ambient_code/vteam_control_plane:${{ github.sha }} + kustomize edit set image quay.io/ambient_code/vteam_mcp:latest=quay.io/ambient_code/vteam_mcp:${{ github.sha }} - name: Validate kustomization working-directory: components/manifests/overlays/production @@ -451,6 +459,8 @@ jobs: kustomize edit set image quay.io/ambient_code/vteam_state_sync:latest=quay.io/ambient_code/vteam_state_sync:${{ github.sha }} kustomize edit set image quay.io/ambient_code/vteam_api_server:latest=quay.io/ambient_code/vteam_api_server:${{ github.sha }} kustomize edit set image quay.io/ambient_code/vteam_public_api:latest=quay.io/ambient_code/vteam_public_api:${{ github.sha }} + kustomize edit set image quay.io/ambient_code/vteam_control_plane:latest=quay.io/ambient_code/vteam_control_plane:${{ github.sha }} + kustomize edit set image quay.io/ambient_code/vteam_mcp:latest=quay.io/ambient_code/vteam_mcp:${{ github.sha }} - name: Validate kustomization working-directory: components/manifests/overlays/production diff --git a/.github/workflows/prod-release-deploy.yaml b/.github/workflows/prod-release-deploy.yaml index 7e2b10daa..70b7ed19e 100755 --- a/.github/workflows/prod-release-deploy.yaml +++ b/.github/workflows/prod-release-deploy.yaml @@ -18,7 +18,7 @@ on: type: boolean default: true components: - description: 'Components to build (comma-separated: frontend,backend,operator,ambient-runner,state-sync,public-api,ambient-api-server) - leave empty for all' + description: 'Components to build (comma-separated: frontend,backend,operator,ambient-runner,state-sync,public-api,ambient-api-server,ambient-control-plane,ambient-mcp) - leave empty for all' required: false type: string default: '' @@ -236,7 +236,9 @@ jobs: {"name":"ambient-runner","context":"./components/runners/ambient-runner","image":"quay.io/ambient_code/vteam_claude_runner","dockerfile":"./components/runners/ambient-runner/Dockerfile"}, {"name":"state-sync","context":"./components/runners/state-sync","image":"quay.io/ambient_code/vteam_state_sync","dockerfile":"./components/runners/state-sync/Dockerfile"}, {"name":"public-api","context":"./components/public-api","image":"quay.io/ambient_code/vteam_public_api","dockerfile":"./components/public-api/Dockerfile"}, - {"name":"ambient-api-server","context":"./components/ambient-api-server","image":"quay.io/ambient_code/vteam_api_server","dockerfile":"./components/ambient-api-server/Dockerfile"} + {"name":"ambient-api-server","context":"./components/ambient-api-server","image":"quay.io/ambient_code/vteam_api_server","dockerfile":"./components/ambient-api-server/Dockerfile"}, + {"name":"ambient-control-plane","context":"./components/ambient-control-plane","image":"quay.io/ambient_code/vteam_control_plane","dockerfile":"./components/ambient-control-plane/Dockerfile"}, + {"name":"ambient-mcp","context":"./components/ambient-mcp","image":"quay.io/ambient_code/vteam_mcp","dockerfile":"./components/ambient-mcp/Dockerfile"} ]' FORCE_ALL="${{ github.event.inputs.force_build_all }}" @@ -621,8 +623,10 @@ jobs: ["operator"]="agentic-operator:agentic-operator" ["public-api"]="public-api:public-api" ["ambient-api-server"]="ambient-api-server:ambient-api-server" + ["ambient-control-plane"]="ambient-control-plane:ambient-control-plane" ) + for comp_image in \ "frontend:quay.io/ambient_code/vteam_frontend" \ "backend:quay.io/ambient_code/vteam_backend" \ @@ -630,7 +634,9 @@ jobs: "ambient-runner:quay.io/ambient_code/vteam_claude_runner" \ "state-sync:quay.io/ambient_code/vteam_state_sync" \ "public-api:quay.io/ambient_code/vteam_public_api" \ - "ambient-api-server:quay.io/ambient_code/vteam_api_server"; do + "ambient-api-server:quay.io/ambient_code/vteam_api_server" \ + "ambient-control-plane:quay.io/ambient_code/vteam_control_plane" \ + "ambient-mcp:quay.io/ambient_code/vteam_mcp"; do COMP="${comp_image%%:*}" IMAGE="${comp_image#*:}" diff --git a/components/ambient-api-server/plugins/credentials/migration.go b/components/ambient-api-server/plugins/credentials/migration.go index 528212800..72a7bb12d 100644 --- a/components/ambient-api-server/plugins/credentials/migration.go +++ b/components/ambient-api-server/plugins/credentials/migration.go @@ -96,16 +96,17 @@ func rolesMigration() *gormigrate.Migration { if err != nil { return err } - row := roleRow{ - ID: api.NewID(), - Name: r.name, - DisplayName: r.displayName, - Description: r.description, - Permissions: string(permsJSON), - BuiltIn: true, - } + var row roleRow if err := tx.Table("roles"). Where("name = ?", r.name). + Attrs(roleRow{ + ID: api.NewID(), + Name: r.name, + DisplayName: r.displayName, + Description: r.description, + Permissions: string(permsJSON), + BuiltIn: true, + }). FirstOrCreate(&row).Error; err != nil { return err } diff --git a/components/ambient-api-server/plugins/roles/migration.go b/components/ambient-api-server/plugins/roles/migration.go index 5f0d37fa4..699a0bf2c 100644 --- a/components/ambient-api-server/plugins/roles/migration.go +++ b/components/ambient-api-server/plugins/roles/migration.go @@ -105,16 +105,17 @@ func seedBuiltInRoles(tx *gorm.DB) error { if err != nil { return err } - row := roleRow{ - ID: api.NewID(), - Name: r.name, - DisplayName: r.displayName, - Description: r.description, - Permissions: string(permsJSON), - BuiltIn: true, - } + var row roleRow if err := tx.Table("roles"). Where("name = ?", r.name). + Attrs(roleRow{ + ID: api.NewID(), + Name: r.name, + DisplayName: r.displayName, + Description: r.description, + Permissions: string(permsJSON), + BuiltIn: true, + }). FirstOrCreate(&row).Error; err != nil { return err } diff --git a/components/ambient-mcp/Dockerfile b/components/ambient-mcp/Dockerfile new file mode 100644 index 000000000..afd0a6935 --- /dev/null +++ b/components/ambient-mcp/Dockerfile @@ -0,0 +1,33 @@ +FROM registry.access.redhat.com/ubi9/go-toolset:1.25 AS builder + +ARG GIT_COMMIT=unknown +ARG GIT_BRANCH=unknown +ARG GIT_REPO=unknown +ARG GIT_VERSION=unknown +ARG BUILD_DATE=unknown +ARG BUILD_USER=unknown + +USER 0 +WORKDIR /app + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . + +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o ambient-mcp . + +FROM registry.access.redhat.com/ubi9/ubi-minimal:latest + +WORKDIR /app + +RUN microdnf install -y procps && microdnf clean all + +COPY --from=builder /app/ambient-mcp . + +RUN chmod +x ./ambient-mcp && chmod 775 /app + +USER 1001 + +ENTRYPOINT ["./ambient-mcp"] +CMD [] diff --git a/components/ambient-mcp/client/client.go b/components/ambient-mcp/client/client.go new file mode 100644 index 000000000..336a6c1da --- /dev/null +++ b/components/ambient-mcp/client/client.go @@ -0,0 +1,116 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +type Client struct { + httpClient *http.Client + baseURL string + mu sync.RWMutex + token string +} + +func New(baseURL, token string) *Client { + return &Client{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + baseURL: strings.TrimSuffix(baseURL, "/"), + token: token, + } +} + +func (c *Client) BaseURL() string { return c.baseURL } + +func (c *Client) Token() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.token +} + +func (c *Client) SetToken(token string) { + c.mu.Lock() + defer c.mu.Unlock() + c.token = token +} + +func (c *Client) do(ctx context.Context, method, path string, body []byte, result interface{}, expectedStatuses ...int) error { + reqURL := c.baseURL + "/api/ambient/v1" + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Authorization", "Bearer "+c.Token()) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response: %w", err) + } + + ok := false + for _, s := range expectedStatuses { + if resp.StatusCode == s { + ok = true + break + } + } + if !ok { + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + if result != nil && len(respBody) > 0 { + if err := json.Unmarshal(respBody, result); err != nil { + return fmt.Errorf("unmarshal response: %w", err) + } + } + return nil +} + +func (c *Client) Get(ctx context.Context, path string, result interface{}) error { + return c.do(ctx, http.MethodGet, path, nil, result, http.StatusOK) +} + +func (c *Client) GetWithQuery(ctx context.Context, path string, params url.Values, result interface{}) error { + if len(params) > 0 { + path = path + "?" + params.Encode() + } + return c.Get(ctx, path, result) +} + +func (c *Client) Post(ctx context.Context, path string, body interface{}, result interface{}, expectedStatus int) error { + b, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal body: %w", err) + } + return c.do(ctx, http.MethodPost, path, b, result, expectedStatus) +} + +func (c *Client) Patch(ctx context.Context, path string, body interface{}, result interface{}) error { + b, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal body: %w", err) + } + return c.do(ctx, http.MethodPatch, path, b, result, http.StatusOK) +} diff --git a/components/ambient-mcp/client/client_test.go b/components/ambient-mcp/client/client_test.go new file mode 100644 index 000000000..daaffbedb --- /dev/null +++ b/components/ambient-mcp/client/client_test.go @@ -0,0 +1,133 @@ +package client + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +func TestNew(t *testing.T) { + c := New("http://localhost:8080/", "my-token") + if c.BaseURL() != "http://localhost:8080" { + t.Errorf("BaseURL() = %q, want trailing slash stripped", c.BaseURL()) + } + if c.Token() != "my-token" { + t.Errorf("Token() = %q, want %q", c.Token(), "my-token") + } +} + +func TestSetToken(t *testing.T) { + c := New("http://localhost:8080", "initial") + c.SetToken("refreshed") + if c.Token() != "refreshed" { + t.Errorf("Token() after SetToken = %q, want %q", c.Token(), "refreshed") + } +} + +func TestSetToken_ConcurrentAccess(t *testing.T) { + c := New("http://localhost:8080", "initial") + var wg sync.WaitGroup + + for i := range 100 { + wg.Add(2) + go func(n int) { + defer wg.Done() + c.SetToken("token-" + string(rune('A'+n%26))) + }(i) + go func() { + defer wg.Done() + _ = c.Token() + }() + } + wg.Wait() + + got := c.Token() + if got == "" { + t.Error("Token() is empty after concurrent access") + } +} + +func TestGet_SendsBearerToken(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer srv.Close() + + c := New(srv.URL, "test-bearer") + var result map[string]string + err := c.Get(context.Background(), "/healthz", &result) + if err != nil { + t.Fatalf("Get: %v", err) + } + if receivedAuth != "Bearer test-bearer" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-bearer") + } +} + +func TestGet_UsesRefreshedToken(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer srv.Close() + + c := New(srv.URL, "old-token") + c.SetToken("new-token") + + var result map[string]string + err := c.Get(context.Background(), "/healthz", &result) + if err != nil { + t.Fatalf("Get: %v", err) + } + if receivedAuth != "Bearer new-token" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer new-token") + } +} + +func TestGet_UnexpectedStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + })) + defer srv.Close() + + c := New(srv.URL, "token") + err := c.Get(context.Background(), "/missing", nil) + if err == nil { + t.Fatal("expected error for 404 response") + } +} + +func TestPost_SendsBody(t *testing.T) { + var receivedBody map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&receivedBody) + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Content-Type = %q, want application/json", r.Header.Get("Content-Type")) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"id": "123"}) + })) + defer srv.Close() + + c := New(srv.URL, "token") + body := map[string]string{"name": "test"} + var result map[string]string + err := c.Post(context.Background(), "/items", body, &result, http.StatusCreated) + if err != nil { + t.Fatalf("Post: %v", err) + } + if receivedBody["name"] != "test" { + t.Errorf("received body name = %q, want %q", receivedBody["name"], "test") + } + if result["id"] != "123" { + t.Errorf("result id = %q, want %q", result["id"], "123") + } +} diff --git a/components/ambient-mcp/go.mod b/components/ambient-mcp/go.mod new file mode 100644 index 000000000..e73ec39c7 --- /dev/null +++ b/components/ambient-mcp/go.mod @@ -0,0 +1,17 @@ +module github.com/ambient-code/platform/components/ambient-mcp + +go 1.24.0 + +require github.com/mark3labs/mcp-go v0.45.0 + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/ambient-mcp/go.sum b/components/ambient-mcp/go.sum new file mode 100644 index 000000000..630745845 --- /dev/null +++ b/components/ambient-mcp/go.sum @@ -0,0 +1,22 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc= +github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/components/ambient-mcp/main.go b/components/ambient-mcp/main.go new file mode 100644 index 000000000..07ddd1394 --- /dev/null +++ b/components/ambient-mcp/main.go @@ -0,0 +1,93 @@ +package main + +import ( + "fmt" + "net/http" + "os" + + "github.com/mark3labs/mcp-go/server" + + "github.com/ambient-code/platform/components/ambient-mcp/client" + "github.com/ambient-code/platform/components/ambient-mcp/tokenexchange" +) + +func main() { + apiURL := os.Getenv("AMBIENT_API_URL") + if apiURL == "" { + apiURL = "http://localhost:8080" + } + + transport := os.Getenv("MCP_TRANSPORT") + if transport == "" { + transport = "stdio" + } + + cpTokenURL := os.Getenv("AMBIENT_CP_TOKEN_URL") + cpPublicKey := os.Getenv("AMBIENT_CP_TOKEN_PUBLIC_KEY") + sessionID := os.Getenv("SESSION_ID") + + var token string + var exchanger *tokenexchange.Exchanger + + if cpTokenURL != "" && cpPublicKey != "" && sessionID != "" { + var err error + exchanger, err = tokenexchange.New(cpTokenURL, cpPublicKey, sessionID) + if err != nil { + fmt.Fprintf(os.Stderr, "token exchange init failed: %v\n", err) + os.Exit(1) + } + token, err = exchanger.FetchToken() + if err != nil { + fmt.Fprintf(os.Stderr, "initial token fetch failed: %v\n", err) + os.Exit(1) + } + fmt.Fprintln(os.Stderr, "bootstrapped token via CP token exchange") + } else { + token = os.Getenv("AMBIENT_TOKEN") + if token == "" { + fmt.Fprintln(os.Stderr, "AMBIENT_TOKEN is required when CP token exchange env vars are not set") + os.Exit(1) + } + fmt.Fprintln(os.Stderr, "using static AMBIENT_TOKEN (no CP token exchange)") + } + + c := client.New(apiURL, token) + + if exchanger != nil { + exchanger.OnRefresh(func(freshToken string) { + c.SetToken(freshToken) + }) + exchanger.StartBackgroundRefresh() + defer exchanger.Stop() + } + + s := newServer(c, transport) + + switch transport { + case "stdio": + if err := server.ServeStdio(s); err != nil { + fmt.Fprintf(os.Stderr, "stdio server error: %v\n", err) + os.Exit(1) + } + + case "sse": + bindAddr := os.Getenv("MCP_BIND_ADDR") + if bindAddr == "" { + bindAddr = ":8090" + } + sseServer := server.NewSSEServer(s, + server.WithBaseURL("http://"+bindAddr), + server.WithSSEEndpoint("/sse"), + server.WithMessageEndpoint("/message"), + ) + fmt.Fprintf(os.Stderr, "MCP server (SSE) listening on %s\n", bindAddr) + if err := http.ListenAndServe(bindAddr, sseServer); err != nil { + fmt.Fprintf(os.Stderr, "SSE server error: %v\n", err) + os.Exit(1) + } + + default: + fmt.Fprintf(os.Stderr, "unknown MCP_TRANSPORT: %q (must be stdio or sse)\n", transport) + os.Exit(1) + } +} diff --git a/components/ambient-mcp/mention/resolve.go b/components/ambient-mcp/mention/resolve.go new file mode 100644 index 000000000..d7e6364d7 --- /dev/null +++ b/components/ambient-mcp/mention/resolve.go @@ -0,0 +1,119 @@ +package mention + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "strings" +) + +var mentionPattern = regexp.MustCompile(`@([a-zA-Z0-9_-]+)`) + +var uuidPattern = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + +type Doer interface { + Do(req *http.Request) (*http.Response, error) +} + +type TokenFunc func() string + +type Resolver struct { + baseURL string + tokenFn TokenFunc + http *http.Client +} + +func NewResolver(baseURL string, tokenFn TokenFunc) (*Resolver, error) { + if tokenFn == nil { + return nil, fmt.Errorf("tokenFn must not be nil") + } + return &Resolver{ + baseURL: strings.TrimSuffix(baseURL, "/"), + tokenFn: tokenFn, + http: &http.Client{}, + }, nil +} + +type agentSearchResult struct { + Items []struct { + ID string `json:"id"` + } `json:"items"` + Total int `json:"total"` +} + +func (r *Resolver) Resolve(ctx context.Context, projectID, identifier string) (string, error) { + if uuidPattern.MatchString(strings.ToLower(identifier)) { + path := r.baseURL + "/api/ambient/v1/projects/" + url.PathEscape(projectID) + "/agents/" + url.PathEscape(identifier) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + req.Header.Set("Authorization", "Bearer "+r.tokenFn()) + resp, err := r.http.Do(req) + if err != nil { + return "", fmt.Errorf("lookup agent by ID: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return "", fmt.Errorf("AGENT_NOT_FOUND") + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("AGENT_NOT_FOUND: HTTP %d", resp.StatusCode) + } + var a struct { + ID string `json:"id"` + } + if err := json.NewDecoder(resp.Body).Decode(&a); err != nil { + return "", fmt.Errorf("decode agent: %w", err) + } + return a.ID, nil + } + + path := r.baseURL + "/api/ambient/v1/projects/" + url.PathEscape(projectID) + "/agents?search=name='" + url.QueryEscape(identifier) + "'" + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + req.Header.Set("Authorization", "Bearer "+r.tokenFn()) + resp, err := r.http.Do(req) + if err != nil { + return "", fmt.Errorf("search agent by name: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("MENTION_NOT_RESOLVED: HTTP %d", resp.StatusCode) + } + var result agentSearchResult + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("decode agent list: %w", err) + } + switch result.Total { + case 0: + return "", fmt.Errorf("MENTION_NOT_RESOLVED: no agent named %q", identifier) + case 1: + return result.Items[0].ID, nil + default: + return "", fmt.Errorf("AMBIGUOUS_AGENT_NAME: %d agents match %q", result.Total, identifier) + } +} + +type Match struct { + Token string + Identifier string + AgentID string +} + +func Extract(text string) []Match { + found := mentionPattern.FindAllStringSubmatch(text, -1) + seen := make(map[string]bool) + var matches []Match + for _, m := range found { + if seen[m[1]] { + continue + } + seen[m[1]] = true + matches = append(matches, Match{Token: m[0], Identifier: m[1]}) + } + return matches +} + +func StripToken(text, token string) string { + return strings.TrimSpace(strings.ReplaceAll(text, token, "")) +} diff --git a/components/ambient-mcp/mention/resolve_test.go b/components/ambient-mcp/mention/resolve_test.go new file mode 100644 index 000000000..90aa9f609 --- /dev/null +++ b/components/ambient-mcp/mention/resolve_test.go @@ -0,0 +1,192 @@ +package mention + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" +) + +func TestNewResolver_TokenFunc(t *testing.T) { + callCount := 0 + tokenFn := func() string { + callCount++ + return "dynamic-token" + } + r, err := NewResolver("http://localhost:8080", tokenFn) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + if r == nil { + t.Fatal("NewResolver returned nil") + } + if callCount != 0 { + t.Errorf("tokenFn called %d times at construction, want 0", callCount) + } +} + +func TestNewResolver_NilTokenFunc(t *testing.T) { + _, err := NewResolver("http://localhost:8080", nil) + if err == nil { + t.Fatal("expected error for nil tokenFn") + } +} + +func TestResolve_ByUUID_SendsCurrentToken(t *testing.T) { + var tokenSeq atomic.Int32 + tokens := []string{"token-v1", "token-v2"} + + var receivedAuths []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuths = append(receivedAuths, r.Header.Get("Authorization")) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"id": "550e8400-e29b-41d4-a716-446655440000"}) + })) + defer srv.Close() + + r, err := NewResolver(srv.URL, func() string { + idx := tokenSeq.Load() + if int(idx) < len(tokens) { + return tokens[idx] + } + return tokens[len(tokens)-1] + }) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + + ctx := context.Background() + agentID, err := r.Resolve(ctx, "proj1", "550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if agentID != "550e8400-e29b-41d4-a716-446655440000" { + t.Errorf("agentID = %q, want UUID", agentID) + } + if receivedAuths[0] != "Bearer token-v1" { + t.Errorf("first auth = %q, want %q", receivedAuths[0], "Bearer token-v1") + } + + tokenSeq.Store(1) + _, err = r.Resolve(ctx, "proj1", "550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatalf("Resolve (2nd): %v", err) + } + if receivedAuths[1] != "Bearer token-v2" { + t.Errorf("second auth = %q, want %q", receivedAuths[1], "Bearer token-v2") + } +} + +func TestResolve_ByName_SendsCurrentToken(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(agentSearchResult{ + Items: []struct { + ID string `json:"id"` + }{{ID: "resolved-agent-id"}}, + Total: 1, + }) + })) + defer srv.Close() + + r, err := NewResolver(srv.URL, func() string { return "name-lookup-token" }) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + agentID, err := r.Resolve(context.Background(), "proj1", "my-agent") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if agentID != "resolved-agent-id" { + t.Errorf("agentID = %q, want %q", agentID, "resolved-agent-id") + } + if receivedAuth != "Bearer name-lookup-token" { + t.Errorf("auth = %q, want %q", receivedAuth, "Bearer name-lookup-token") + } +} + +func TestResolve_ByUUID_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + })) + defer srv.Close() + + r, err := NewResolver(srv.URL, func() string { return "t" }) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + _, err = r.Resolve(context.Background(), "proj1", "550e8400-e29b-41d4-a716-446655440000") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestResolve_ByName_NoMatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(agentSearchResult{Total: 0}) + })) + defer srv.Close() + + r, err := NewResolver(srv.URL, func() string { return "t" }) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + _, err = r.Resolve(context.Background(), "proj1", "nonexistent") + if err == nil { + t.Fatal("expected error for no match") + } +} + +func TestResolve_ByName_Ambiguous(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(agentSearchResult{ + Items: []struct { + ID string `json:"id"` + }{{ID: "a"}, {ID: "b"}}, + Total: 2, + }) + })) + defer srv.Close() + + r, err := NewResolver(srv.URL, func() string { return "t" }) + if err != nil { + t.Fatalf("NewResolver: %v", err) + } + _, err = r.Resolve(context.Background(), "proj1", "ambiguous") + if err == nil { + t.Fatal("expected error for ambiguous match") + } +} + +func TestExtract(t *testing.T) { + matches := Extract("Hello @alice and @bob, also @alice again") + if len(matches) != 2 { + t.Fatalf("len(matches) = %d, want 2", len(matches)) + } + if matches[0].Identifier != "alice" { + t.Errorf("matches[0].Identifier = %q, want %q", matches[0].Identifier, "alice") + } + if matches[1].Identifier != "bob" { + t.Errorf("matches[1].Identifier = %q, want %q", matches[1].Identifier, "bob") + } +} + +func TestExtract_NoMentions(t *testing.T) { + matches := Extract("no mentions here") + if len(matches) != 0 { + t.Errorf("len(matches) = %d, want 0", len(matches)) + } +} + +func TestStripToken_RemovesMention(t *testing.T) { + result := StripToken("hello @alice do this", "@alice") + if result != "hello do this" { + t.Errorf("StripToken = %q, want %q", result, "hello do this") + } +} diff --git a/components/ambient-mcp/server.go b/components/ambient-mcp/server.go new file mode 100644 index 000000000..a93205cfe --- /dev/null +++ b/components/ambient-mcp/server.go @@ -0,0 +1,261 @@ +package main + +import ( + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/ambient-code/platform/components/ambient-mcp/client" + "github.com/ambient-code/platform/components/ambient-mcp/tools" +) + +func newServer(c *client.Client, transport string) *server.MCPServer { + s := server.NewMCPServer( + "ambient-platform", + "1.0.0", + server.WithToolCapabilities(false), + ) + + registerSessionTools(s, c, transport) + registerAgentTools(s, c) + registerProjectTools(s, c) + + return s +} + +func registerSessionTools(s *server.MCPServer, c *client.Client, transport string) { + s.AddTool( + mcp.NewTool("list_sessions", + mcp.WithDescription("List sessions visible to the caller, with optional filters."), + mcp.WithString("project_id", mcp.Description("Filter to sessions belonging to this project ID.")), + mcp.WithString("phase", + mcp.Description("Filter by session phase."), + mcp.Enum("Pending", "Running", "Completed", "Failed"), + ), + mcp.WithNumber("page", mcp.Description("Page number (1-indexed). Default: 1.")), + mcp.WithNumber("size", mcp.Description("Page size. Default: 20. Max: 100.")), + ), + tools.ListSessions(c), + ) + + s.AddTool( + mcp.NewTool("get_session", + mcp.WithDescription("Returns full detail for a single session."), + mcp.WithString("session_id", + mcp.Description("Session ID."), + mcp.Required(), + ), + ), + tools.GetSession(c), + ) + + s.AddTool( + mcp.NewTool("create_session", + mcp.WithDescription("Creates and starts a new agentic session. Returns the session in Pending phase."), + mcp.WithString("project_id", + mcp.Description("Project ID in which to create the session."), + mcp.Required(), + ), + mcp.WithString("prompt", + mcp.Description("Task prompt for the session."), + mcp.Required(), + ), + mcp.WithString("agent_id", mcp.Description("Agent ID to execute the session.")), + mcp.WithString("model", mcp.Description("LLM model override (e.g. 'claude-sonnet-4-6').")), + mcp.WithString("parent_session_id", mcp.Description("Calling session ID for agent-to-agent delegation.")), + mcp.WithString("name", mcp.Description("Human-readable session name.")), + ), + tools.CreateSession(c), + ) + + s.AddTool( + mcp.NewTool("push_message", + mcp.WithDescription("Appends a user message to a session's message log. Supports @mention syntax for agent delegation."), + mcp.WithString("session_id", + mcp.Description("ID of the target session."), + mcp.Required(), + ), + mcp.WithString("text", + mcp.Description("Message text. May contain @agent_id or @agent_name mentions to trigger delegation."), + mcp.Required(), + ), + ), + tools.PushMessage(c), + ) + + s.AddTool( + mcp.NewTool("patch_session_labels", + mcp.WithDescription("Merges key-value label pairs into a session's labels field."), + mcp.WithString("session_id", + mcp.Description("ID of the session to update."), + mcp.Required(), + ), + mcp.WithObject("labels", + mcp.Description("Key-value label pairs to merge."), + mcp.Required(), + ), + ), + tools.PatchSessionLabels(c), + ) + + s.AddTool( + mcp.NewTool("patch_session_annotations", + mcp.WithDescription("Merges key-value annotation pairs into a session's annotations field. Annotations are arbitrary string metadata — a programmable state store scoped to the session lifetime."), + mcp.WithString("session_id", + mcp.Description("ID of the session to update."), + mcp.Required(), + ), + mcp.WithObject("annotations", + mcp.Description("Key-value annotation pairs to merge. Keys use reverse-DNS prefix convention (e.g. 'myapp.io/status'). Empty-string values delete a key."), + mcp.Required(), + ), + ), + tools.PatchSessionAnnotations(c), + ) + + s.AddTool( + mcp.NewTool("watch_session_messages", + mcp.WithDescription("Subscribes to a session's message stream. Returns a subscription_id immediately; messages are pushed as notifications/progress events."), + mcp.WithString("session_id", + mcp.Description("ID of the session to watch."), + mcp.Required(), + ), + mcp.WithNumber("after_seq", mcp.Description("Deliver only messages with seq > after_seq. Default: 0 (replay all).")), + ), + tools.WatchSessionMessages(c, transport), + ) + + s.AddTool( + mcp.NewTool("unwatch_session_messages", + mcp.WithDescription("Cancels an active watch_session_messages subscription."), + mcp.WithString("subscription_id", + mcp.Description("Subscription ID returned by watch_session_messages."), + mcp.Required(), + ), + ), + tools.UnwatchSessionMessages(), + ) +} + +func registerAgentTools(s *server.MCPServer, c *client.Client) { + s.AddTool( + mcp.NewTool("list_agents", + mcp.WithDescription("Lists agents visible to the caller."), + mcp.WithString("project_id", + mcp.Description("Project ID to list agents for."), + mcp.Required(), + ), + mcp.WithString("search", mcp.Description("Search filter (e.g. \"name like 'code-%'\").")), + mcp.WithNumber("page", mcp.Description("Page number (1-indexed). Default: 1.")), + mcp.WithNumber("size", mcp.Description("Page size. Default: 20. Max: 100.")), + ), + tools.ListAgents(c), + ) + + s.AddTool( + mcp.NewTool("get_agent", + mcp.WithDescription("Returns detail for a single agent by ID or name."), + mcp.WithString("project_id", + mcp.Description("Project ID the agent belongs to."), + mcp.Required(), + ), + mcp.WithString("agent_id", + mcp.Description("Agent ID (UUID) or agent name."), + mcp.Required(), + ), + ), + tools.GetAgent(c), + ) + + s.AddTool( + mcp.NewTool("create_agent", + mcp.WithDescription("Creates a new agent."), + mcp.WithString("project_id", + mcp.Description("Project ID to create the agent in."), + mcp.Required(), + ), + mcp.WithString("name", + mcp.Description("Agent name. Must be unique. Alphanumeric, hyphens, underscores only."), + mcp.Required(), + ), + mcp.WithString("prompt", + mcp.Description("System prompt defining the agent's persona and behavior."), + mcp.Required(), + ), + ), + tools.CreateAgent(c), + ) + + s.AddTool( + mcp.NewTool("update_agent", + mcp.WithDescription("Updates an agent's prompt, labels, or annotations. Creates a new immutable version."), + mcp.WithString("project_id", + mcp.Description("Project ID the agent belongs to."), + mcp.Required(), + ), + mcp.WithString("agent_id", + mcp.Description("Agent ID (UUID)."), + mcp.Required(), + ), + mcp.WithString("prompt", mcp.Description("New system prompt.")), + mcp.WithObject("labels", mcp.Description("Labels to merge.")), + mcp.WithObject("annotations", mcp.Description("Annotations to merge. Empty-string values delete a key.")), + ), + tools.UpdateAgent(c), + ) + + s.AddTool( + mcp.NewTool("patch_agent_annotations", + mcp.WithDescription("Merges key-value annotation pairs into an Agent's annotations. Agent annotations are persistent across sessions — use them for durable agent state."), + mcp.WithString("project_id", + mcp.Description("Project ID the agent belongs to."), + mcp.Required(), + ), + mcp.WithString("agent_id", + mcp.Description("Agent ID (UUID) or agent name."), + mcp.Required(), + ), + mcp.WithObject("annotations", + mcp.Description("Key-value annotation pairs to merge. Empty-string values delete a key."), + mcp.Required(), + ), + ), + tools.PatchAgentAnnotations(c), + ) +} + +func registerProjectTools(s *server.MCPServer, c *client.Client) { + s.AddTool( + mcp.NewTool("list_projects", + mcp.WithDescription("Lists projects visible to the caller."), + mcp.WithNumber("page", mcp.Description("Page number (1-indexed). Default: 1.")), + mcp.WithNumber("size", mcp.Description("Page size. Default: 20. Max: 100.")), + ), + tools.ListProjects(c), + ) + + s.AddTool( + mcp.NewTool("get_project", + mcp.WithDescription("Returns detail for a single project by ID or name."), + mcp.WithString("project_id", + mcp.Description("Project ID (UUID) or project name."), + mcp.Required(), + ), + ), + tools.GetProject(c), + ) + + s.AddTool( + mcp.NewTool("patch_project_annotations", + mcp.WithDescription("Merges key-value annotation pairs into a Project's annotations. Project annotations are the widest-scope state store — visible to every agent and session in the project."), + mcp.WithString("project_id", + mcp.Description("Project ID (UUID) or project name."), + mcp.Required(), + ), + mcp.WithObject("annotations", + mcp.Description("Key-value annotation pairs to merge. Empty-string values delete a key."), + mcp.Required(), + ), + ), + tools.PatchProjectAnnotations(c), + ) +} diff --git a/components/ambient-mcp/tokenexchange/tokenexchange.go b/components/ambient-mcp/tokenexchange/tokenexchange.go new file mode 100644 index 000000000..18ce9db90 --- /dev/null +++ b/components/ambient-mcp/tokenexchange/tokenexchange.go @@ -0,0 +1,210 @@ +package tokenexchange + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" +) + +const ( + fetchAttempts = 3 + fetchTimeout = 10 * time.Second + refreshPeriod = 5 * time.Minute + initialBackoff = 1 * time.Second +) + +type Exchanger struct { + tokenURL string + publicKey *rsa.PublicKey + sessionID string + httpClient *http.Client + mu sync.RWMutex + currentToken string + onRefresh func(string) + stopCh chan struct{} + startOnce sync.Once + stopOnce sync.Once +} + +type tokenResponse struct { + Token string `json:"token"` +} + +func New(tokenURL, publicKeyPEM, sessionID string) (*Exchanger, error) { + if err := validateTokenURL(tokenURL); err != nil { + return nil, err + } + + pubKey, err := parsePublicKey(publicKeyPEM) + if err != nil { + return nil, fmt.Errorf("parse public key: %w", err) + } + + return &Exchanger{ + tokenURL: tokenURL, + publicKey: pubKey, + sessionID: sessionID, + httpClient: &http.Client{Timeout: fetchTimeout}, + stopCh: make(chan struct{}), + }, nil +} + +func (e *Exchanger) OnRefresh(fn func(string)) { + e.mu.Lock() + defer e.mu.Unlock() + e.onRefresh = fn +} + +func (e *Exchanger) FetchToken() (string, error) { + bearer, err := encryptSessionID(e.publicKey, e.sessionID) + if err != nil { + return "", fmt.Errorf("encrypt session ID: %w", err) + } + + var lastErr error + for attempt := range fetchAttempts { + if attempt > 0 { + time.Sleep(initialBackoff * time.Duration(1<<(attempt-1))) + } + + token, err := e.doFetch(bearer) + if err != nil { + lastErr = err + continue + } + + e.mu.Lock() + e.currentToken = token + callback := e.onRefresh + e.mu.Unlock() + + if callback != nil { + callback(token) + } + + return token, nil + } + + return "", fmt.Errorf("token endpoint unreachable after %d attempts: %w", fetchAttempts, lastErr) +} + +func (e *Exchanger) Token() string { + e.mu.RLock() + defer e.mu.RUnlock() + return e.currentToken +} + +func (e *Exchanger) StartBackgroundRefresh() { + e.startOnce.Do(func() { + go func() { + ticker := time.NewTicker(refreshPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if _, err := e.FetchToken(); err != nil { + fmt.Fprintf(os.Stderr, "background token refresh failed: %v\n", err) + } + case <-e.stopCh: + return + } + } + }() + }) +} + +func (e *Exchanger) Stop() { + e.stopOnce.Do(func() { + close(e.stopCh) + }) +} + +func (e *Exchanger) doFetch(bearer string) (string, error) { + req, err := http.NewRequest(http.MethodGet, e.tokenURL, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+bearer) + + resp, err := e.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp tokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("unmarshal response: %w", err) + } + if tokenResp.Token == "" { + return "", fmt.Errorf("token response missing 'token' field") + } + + return tokenResp.Token, nil +} + +func encryptSessionID(pubKey *rsa.PublicKey, sessionID string) (string, error) { + ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pubKey, []byte(sessionID), nil) + if err != nil { + return "", fmt.Errorf("RSA-OAEP encrypt: %w", err) + } + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func parsePublicKey(pemStr string) (*rsa.PublicKey, error) { + block, _ := pem.Decode([]byte(pemStr)) + if block == nil { + return nil, fmt.Errorf("no PEM block found in public key") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse PKIX public key: %w", err) + } + + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("public key is not RSA (got %T)", pub) + } + + return rsaPub, nil +} + +func validateTokenURL(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("parse token URL: %w", err) + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("invalid token URL scheme %q (must be http or https)", scheme) + } + if parsed.Host == "" { + return fmt.Errorf("token URL has no host") + } + if parsed.User != nil { + return fmt.Errorf("token URL must not contain credentials") + } + return nil +} diff --git a/components/ambient-mcp/tokenexchange/tokenexchange_test.go b/components/ambient-mcp/tokenexchange/tokenexchange_test.go new file mode 100644 index 000000000..9cc3f5664 --- /dev/null +++ b/components/ambient-mcp/tokenexchange/tokenexchange_test.go @@ -0,0 +1,396 @@ +package tokenexchange + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func generateTestKeyPair(t *testing.T) (*rsa.PrivateKey, string) { + t.Helper() + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating RSA key: %v", err) + } + pubDER, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + if err != nil { + t.Fatalf("marshaling public key: %v", err) + } + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}) + return privKey, string(pubPEM) +} + +func decryptBearer(t *testing.T, privKey *rsa.PrivateKey, bearer string) string { + t.Helper() + ciphertext, err := base64.StdEncoding.DecodeString(bearer) + if err != nil { + t.Fatalf("decoding bearer base64: %v", err) + } + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, privKey, ciphertext, nil) + if err != nil { + t.Fatalf("decrypting bearer: %v", err) + } + return string(plaintext) +} + +func newTokenServer(t *testing.T, privKey *rsa.PrivateKey, apiToken string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if len(auth) < 8 || auth[:7] != "Bearer " { + http.Error(w, "missing bearer", http.StatusUnauthorized) + return + } + bearer := auth[7:] + sessionID := decryptBearer(t, privKey, bearer) + if len(sessionID) < 8 { + http.Error(w, "invalid session ID", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse{Token: apiToken}) + })) +} + +func TestEncryptSessionID_Roundtrip(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + pubKey, err := parsePublicKey(pubPEM) + if err != nil { + t.Fatalf("parsePublicKey: %v", err) + } + + sessionID := "test-session-abc123" + encrypted, err := encryptSessionID(pubKey, sessionID) + if err != nil { + t.Fatalf("encryptSessionID: %v", err) + } + + ciphertext, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + t.Fatalf("base64 decode: %v", err) + } + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, privKey, ciphertext, nil) + if err != nil { + t.Fatalf("RSA decrypt: %v", err) + } + if string(plaintext) != sessionID { + t.Errorf("roundtrip got %q, want %q", string(plaintext), sessionID) + } +} + +func TestParsePublicKey_Valid(t *testing.T) { + _, pubPEM := generateTestKeyPair(t) + key, err := parsePublicKey(pubPEM) + if err != nil { + t.Fatalf("parsePublicKey: %v", err) + } + if key == nil { + t.Fatal("parsePublicKey returned nil key") + } +} + +func TestParsePublicKey_InvalidPEM(t *testing.T) { + _, err := parsePublicKey("not-a-pem-block") + if err == nil { + t.Fatal("expected error for invalid PEM") + } +} + +func TestParsePublicKey_NotRSA(t *testing.T) { + _, err := parsePublicKey("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE\n-----END PUBLIC KEY-----\n") + if err == nil { + t.Fatal("expected error for non-RSA key") + } +} + +func TestValidateTokenURL(t *testing.T) { + cases := []struct { + url string + valid bool + }{ + {"https://cp.example.com/token", true}, + {"http://localhost:8080/token", true}, + {"ftp://example.com/token", false}, + {"://missing-scheme", false}, + {"http://user:pass@example.com/token", false}, + {"", false}, + } + for _, tc := range cases { + err := validateTokenURL(tc.url) + if (err == nil) != tc.valid { + t.Errorf("validateTokenURL(%q): err=%v, wantValid=%v", tc.url, err, tc.valid) + } + } +} + +func TestNew_ValidConfig(t *testing.T) { + _, pubPEM := generateTestKeyPair(t) + ex, err := New("https://cp.example.com/token", pubPEM, "test-session-12345678") + if err != nil { + t.Fatalf("New: %v", err) + } + if ex == nil { + t.Fatal("New returned nil") + } +} + +func TestNew_InvalidURL(t *testing.T) { + _, pubPEM := generateTestKeyPair(t) + _, err := New("ftp://bad.example.com", pubPEM, "test-session-12345678") + if err == nil { + t.Fatal("expected error for invalid URL") + } +} + +func TestNew_InvalidPublicKey(t *testing.T) { + _, err := New("https://cp.example.com/token", "garbage", "test-session-12345678") + if err == nil { + t.Fatal("expected error for invalid public key") + } +} + +func TestFetchToken_Success(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + srv := newTokenServer(t, privKey, "fresh-api-token-xyz") + defer srv.Close() + + ex, err := New(srv.URL+"/token", pubPEM, "session-abcdef12") + if err != nil { + t.Fatalf("New: %v", err) + } + + token, err := ex.FetchToken() + if err != nil { + t.Fatalf("FetchToken: %v", err) + } + if token != "fresh-api-token-xyz" { + t.Errorf("token = %q, want %q", token, "fresh-api-token-xyz") + } + if ex.Token() != "fresh-api-token-xyz" { + t.Errorf("Token() = %q, want %q", ex.Token(), "fresh-api-token-xyz") + } +} + +func TestFetchToken_ServerError_Retries(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + var callCount atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := callCount.Add(1) + if n < 3 { + http.Error(w, "temporary failure", http.StatusServiceUnavailable) + return + } + bearer := r.Header.Get("Authorization")[7:] + decryptBearer(t, privKey, bearer) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse{Token: "recovered-token"}) + })) + defer srv.Close() + + ex, err := New(srv.URL+"/token", pubPEM, "session-retry-test1") + if err != nil { + t.Fatalf("New: %v", err) + } + + token, err := ex.FetchToken() + if err != nil { + t.Fatalf("FetchToken: %v", err) + } + if token != "recovered-token" { + t.Errorf("token = %q, want %q", token, "recovered-token") + } + if callCount.Load() != 3 { + t.Errorf("server called %d times, want 3", callCount.Load()) + } +} + +func TestFetchToken_AllRetriesFail(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "permanent failure", http.StatusInternalServerError) + })) + defer srv.Close() + + _, pubPEM := generateTestKeyPair(t) + ex, err := New(srv.URL+"/token", pubPEM, "session-fail-test1") + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = ex.FetchToken() + if err == nil { + t.Fatal("expected error after all retries fail") + } +} + +func TestFetchToken_EmptyTokenResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse{Token: ""}) + })) + defer srv.Close() + + _, pubPEM := generateTestKeyPair(t) + ex, err := New(srv.URL+"/token", pubPEM, "session-empty-test") + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = ex.FetchToken() + if err == nil { + t.Fatal("expected error for empty token response") + } +} + +func TestFetchToken_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, "not-json") + })) + defer srv.Close() + + _, pubPEM := generateTestKeyPair(t) + ex, err := New(srv.URL+"/token", pubPEM, "session-badjson-x") + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = ex.FetchToken() + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestFetchToken_OnRefreshCallback(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + srv := newTokenServer(t, privKey, "callback-token") + defer srv.Close() + + ex, err := New(srv.URL+"/token", pubPEM, "session-callback1") + if err != nil { + t.Fatalf("New: %v", err) + } + + var received string + ex.OnRefresh(func(token string) { + received = token + }) + + token, err := ex.FetchToken() + if err != nil { + t.Fatalf("FetchToken: %v", err) + } + if received != token { + t.Errorf("OnRefresh received %q, want %q", received, token) + } +} + +func TestFetchToken_SendsCorrectBearer(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + sessionID := "session-bearer-check" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if len(auth) < 8 || auth[:7] != "Bearer " { + t.Errorf("missing Bearer prefix in Authorization header") + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + bearer := auth[7:] + decrypted := decryptBearer(t, privKey, bearer) + if decrypted != sessionID { + t.Errorf("decrypted session ID = %q, want %q", decrypted, sessionID) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse{Token: "verified-token"}) + })) + defer srv.Close() + + ex, err := New(srv.URL+"/token", pubPEM, sessionID) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = ex.FetchToken() + if err != nil { + t.Fatalf("FetchToken: %v", err) + } +} + +func TestToken_ReturnsEmptyBeforeFetch(t *testing.T) { + _, pubPEM := generateTestKeyPair(t) + ex, err := New("https://example.com/token", pubPEM, "session-nofetch1") + if err != nil { + t.Fatalf("New: %v", err) + } + if got := ex.Token(); got != "" { + t.Errorf("Token() before fetch = %q, want empty", got) + } +} + +func TestStopBackgroundRefresh(t *testing.T) { + privKey, pubPEM := generateTestKeyPair(t) + srv := newTokenServer(t, privKey, "bg-token") + defer srv.Close() + + ex, err := New(srv.URL+"/token", pubPEM, "session-stop-test") + if err != nil { + t.Fatalf("New: %v", err) + } + ex.StartBackgroundRefresh() + ex.Stop() + time.Sleep(50 * time.Millisecond) +} + +func TestFetchToken_WrongKey_ServerRejects(t *testing.T) { + wrongKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating wrong key: %v", err) + } + wrongPubDER, _ := x509.MarshalPKIXPublicKey(&wrongKey.PublicKey) + wrongPubPEM := string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: wrongPubDER})) + + realKey, _ := generateTestKeyPair(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if len(auth) < 8 || auth[:7] != "Bearer " { + http.Error(w, "missing bearer", http.StatusUnauthorized) + return + } + bearer := auth[7:] + ciphertext, err := base64.StdEncoding.DecodeString(bearer) + if err != nil { + http.Error(w, "bad base64", http.StatusUnauthorized) + return + } + _, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, realKey, ciphertext, nil) + if err != nil { + http.Error(w, "decryption failed", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse{Token: "should-not-get-this"}) + })) + defer srv.Close() + + ex, err := New(srv.URL+"/token", wrongPubPEM, "session-wrongkey1") + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = ex.FetchToken() + if err == nil { + t.Fatal("expected error when encrypted with wrong key") + } +} diff --git a/components/ambient-mcp/tools/agents.go b/components/ambient-mcp/tools/agents.go new file mode 100644 index 000000000..4bc2080cc --- /dev/null +++ b/components/ambient-mcp/tools/agents.go @@ -0,0 +1,182 @@ +package tools + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/ambient-code/platform/components/ambient-mcp/client" +) + +type agentList struct { + Kind string `json:"kind"` + Page int `json:"page"` + Size int `json:"size"` + Total int `json:"total"` + Items []agent `json:"items"` +} + +type agent struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + ProjectID string `json:"project_id,omitempty"` + Prompt string `json:"prompt,omitempty"` + Labels string `json:"labels,omitempty"` + Annotations string `json:"annotations,omitempty"` + Version int `json:"version,omitempty"` +} + +func ListAgents(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + + params := url.Values{} + if v := mcp.ParseString(req, "search", ""); v != "" { + params.Set("search", v) + } + page := mcp.ParseInt(req, "page", 0) + if page > 0 { + params.Set("page", fmt.Sprintf("%d", page)) + } + size := mcp.ParseInt(req, "size", 0) + if size > 0 { + params.Set("size", fmt.Sprintf("%d", size)) + } + + var result agentList + path := "/projects/" + url.PathEscape(projectID) + "/agents" + if err := c.GetWithQuery(ctx, path, params, &result); err != nil { + return errResult("LIST_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func GetAgent(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + agentID := mcp.ParseString(req, "agent_id", "") + if agentID == "" { + return errResult("INVALID_REQUEST", "agent_id is required"), nil + } + + path := "/projects/" + url.PathEscape(projectID) + "/agents/" + url.PathEscape(agentID) + var result agent + if err := c.Get(ctx, path, &result); err != nil { + return errResult("AGENT_NOT_FOUND", err.Error()), nil + } + return jsonResult(result) + } +} + +func CreateAgent(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + name := mcp.ParseString(req, "name", "") + if name == "" { + return errResult("INVALID_REQUEST", "name is required"), nil + } + prompt := mcp.ParseString(req, "prompt", "") + if prompt == "" { + return errResult("INVALID_REQUEST", "prompt is required"), nil + } + + body := map[string]interface{}{ + "name": name, + "prompt": prompt, + } + path := "/projects/" + url.PathEscape(projectID) + "/agents" + var result agent + if err := c.Post(ctx, path, body, &result, http.StatusCreated); err != nil { + return errResult("CREATE_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func UpdateAgent(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + agentID := mcp.ParseString(req, "agent_id", "") + if agentID == "" { + return errResult("INVALID_REQUEST", "agent_id is required"), nil + } + + patch := map[string]interface{}{} + if v := mcp.ParseString(req, "prompt", ""); v != "" { + patch["prompt"] = v + } + if v := mcp.ParseStringMap(req, "labels", nil); v != nil { + patch["labels"] = v + } + if v := mcp.ParseStringMap(req, "annotations", nil); v != nil { + patch["annotations"] = v + } + if len(patch) == 0 { + return errResult("INVALID_REQUEST", "at least one of prompt, labels, or annotations must be provided"), nil + } + + path := "/projects/" + url.PathEscape(projectID) + "/agents/" + url.PathEscape(agentID) + var result agent + if err := c.Patch(ctx, path, patch, &result); err != nil { + return errResult("UPDATE_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func PatchAgentAnnotations(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + agentID := mcp.ParseString(req, "agent_id", "") + if agentID == "" { + return errResult("INVALID_REQUEST", "agent_id is required"), nil + } + + annRaw := mcp.ParseStringMap(req, "annotations", nil) + if annRaw == nil { + return errResult("INVALID_REQUEST", "annotations is required"), nil + } + + patch := make(map[string]string, len(annRaw)) + for k, v := range annRaw { + s, ok := v.(string) + if !ok { + return errResult("INVALID_REQUEST", fmt.Sprintf("annotation %q: value must be a string", k)), nil + } + patch[k] = s + } + + path := "/projects/" + url.PathEscape(projectID) + "/agents/" + url.PathEscape(agentID) + var existing agent + if err := c.Get(ctx, path, &existing); err != nil { + return errResult("AGENT_NOT_FOUND", err.Error()), nil + } + + merged := mergeStringMaps(existing.Annotations, patch) + + var result agent + if err := c.Patch(ctx, path, map[string]interface{}{"annotations": merged}, &result); err != nil { + return errResult("PATCH_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} diff --git a/components/ambient-mcp/tools/helpers.go b/components/ambient-mcp/tools/helpers.go new file mode 100644 index 000000000..ba1127857 --- /dev/null +++ b/components/ambient-mcp/tools/helpers.go @@ -0,0 +1,23 @@ +package tools + +import ( + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +func jsonResult(v interface{}) (*mcp.CallToolResult, error) { + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return mcp.NewToolResultError("marshal error: " + err.Error()), nil + } + return mcp.NewToolResultText(string(b)), nil +} + +func errResult(code, reason string) *mcp.CallToolResult { + b, _ := json.Marshal(map[string]string{ + "code": code, + "reason": reason, + }) + return mcp.NewToolResultError(string(b)) +} diff --git a/components/ambient-mcp/tools/projects.go b/components/ambient-mcp/tools/projects.go new file mode 100644 index 000000000..bbc7fd8b5 --- /dev/null +++ b/components/ambient-mcp/tools/projects.go @@ -0,0 +1,101 @@ +package tools + +import ( + "context" + "fmt" + "net/url" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/ambient-code/platform/components/ambient-mcp/client" +) + +type projectList struct { + Kind string `json:"kind"` + Page int `json:"page"` + Size int `json:"size"` + Total int `json:"total"` + Items []project `json:"items"` +} + +type project struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Labels string `json:"labels,omitempty"` + Annotations string `json:"annotations,omitempty"` + Prompt string `json:"prompt,omitempty"` + Status string `json:"status,omitempty"` +} + +func ListProjects(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + params := url.Values{} + page := mcp.ParseInt(req, "page", 0) + if page > 0 { + params.Set("page", fmt.Sprintf("%d", page)) + } + size := mcp.ParseInt(req, "size", 0) + if size > 0 { + params.Set("size", fmt.Sprintf("%d", size)) + } + + var result projectList + if err := c.GetWithQuery(ctx, "/projects", params, &result); err != nil { + return errResult("LIST_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func GetProject(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + + var result project + if err := c.Get(ctx, "/projects/"+url.PathEscape(projectID), &result); err != nil { + return errResult("PROJECT_NOT_FOUND", err.Error()), nil + } + return jsonResult(result) + } +} + +func PatchProjectAnnotations(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + + annRaw := mcp.ParseStringMap(req, "annotations", nil) + if annRaw == nil { + return errResult("INVALID_REQUEST", "annotations is required"), nil + } + + patch := make(map[string]string, len(annRaw)) + for k, v := range annRaw { + s, ok := v.(string) + if !ok { + return errResult("INVALID_REQUEST", fmt.Sprintf("annotation %q: value must be a string", k)), nil + } + patch[k] = s + } + + path := "/projects/" + url.PathEscape(projectID) + var existing project + if err := c.Get(ctx, path, &existing); err != nil { + return errResult("PROJECT_NOT_FOUND", err.Error()), nil + } + + merged := mergeStringMaps(existing.Annotations, patch) + + var result project + if err := c.Patch(ctx, path, map[string]interface{}{"annotations": merged}, &result); err != nil { + return errResult("PATCH_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} diff --git a/components/ambient-mcp/tools/sessions.go b/components/ambient-mcp/tools/sessions.go new file mode 100644 index 000000000..cc481c360 --- /dev/null +++ b/components/ambient-mcp/tools/sessions.go @@ -0,0 +1,283 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/ambient-code/platform/components/ambient-mcp/client" + "github.com/ambient-code/platform/components/ambient-mcp/mention" +) + +type sessionList struct { + Kind string `json:"kind"` + Page int `json:"page"` + Size int `json:"size"` + Total int `json:"total"` + Items []session `json:"items"` +} + +type session struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + ProjectID string `json:"project_id,omitempty"` + Phase string `json:"phase,omitempty"` + Prompt string `json:"prompt,omitempty"` + AgentID string `json:"agent_id,omitempty"` + ParentSessionID string `json:"parent_session_id,omitempty"` + LlmModel string `json:"llm_model,omitempty"` + Labels string `json:"labels,omitempty"` + Annotations string `json:"annotations,omitempty"` + CreatedAt string `json:"created_at,omitempty"` +} + +type sessionMessage struct { + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Seq int `json:"seq,omitempty"` + EventType string `json:"event_type,omitempty"` + Payload string `json:"payload,omitempty"` + CreatedAt string `json:"created_at,omitempty"` +} + +func ListSessions(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + params := url.Values{} + if v := mcp.ParseString(req, "project_id", ""); v != "" { + params.Set("search", "project_id = '"+v+"'") + } + if v := mcp.ParseString(req, "phase", ""); v != "" { + existing := params.Get("search") + filter := "phase = '" + v + "'" + if existing != "" { + params.Set("search", existing+" and "+filter) + } else { + params.Set("search", filter) + } + } + page := mcp.ParseInt(req, "page", 0) + if page > 0 { + params.Set("page", fmt.Sprintf("%d", page)) + } + size := mcp.ParseInt(req, "size", 0) + if size > 0 { + params.Set("size", fmt.Sprintf("%d", size)) + } + + var result sessionList + if err := c.GetWithQuery(ctx, "/sessions", params, &result); err != nil { + return errResult("SESSION_LIST_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func GetSession(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := mcp.ParseString(req, "session_id", "") + if id == "" { + return errResult("INVALID_REQUEST", "session_id is required"), nil + } + var result session + if err := c.Get(ctx, "/sessions/"+url.PathEscape(id), &result); err != nil { + return errResult("SESSION_NOT_FOUND", err.Error()), nil + } + return jsonResult(result) + } +} + +func CreateSession(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + projectID := mcp.ParseString(req, "project_id", "") + if projectID == "" { + return errResult("INVALID_REQUEST", "project_id is required"), nil + } + prompt := mcp.ParseString(req, "prompt", "") + if prompt == "" { + return errResult("INVALID_REQUEST", "prompt is required"), nil + } + + body := map[string]interface{}{ + "project_id": projectID, + "prompt": prompt, + } + if v := mcp.ParseString(req, "agent_id", ""); v != "" { + body["agent_id"] = v + } + if v := mcp.ParseString(req, "model", ""); v != "" { + body["llm_model"] = v + } + if v := mcp.ParseString(req, "parent_session_id", ""); v != "" { + body["parent_session_id"] = v + } + if v := mcp.ParseString(req, "name", ""); v != "" { + body["name"] = v + } + + var created session + if err := c.Post(ctx, "/sessions", body, &created, http.StatusCreated); err != nil { + return errResult("CREATE_FAILED", err.Error()), nil + } + + var started session + if err := c.Post(ctx, "/sessions/"+url.PathEscape(created.ID)+"/start", nil, &started, http.StatusOK); err != nil { + return errResult("START_FAILED", err.Error()), nil + } + return jsonResult(started) + } +} + +func PushMessage(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resolver, err := mention.NewResolver(c.BaseURL(), c.Token) + if err != nil { + return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return errResult("CONFIG_ERROR", err.Error()), nil + } + } + + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + sessionID := mcp.ParseString(req, "session_id", "") + if sessionID == "" { + return errResult("INVALID_REQUEST", "session_id is required"), nil + } + text := mcp.ParseString(req, "text", "") + if text == "" { + return errResult("INVALID_REQUEST", "text is required"), nil + } + + body := map[string]interface{}{"payload": text} + var pushed sessionMessage + if err := c.Post(ctx, "/sessions/"+url.PathEscape(sessionID)+"/messages", body, &pushed, http.StatusCreated); err != nil { + return errResult("PUSH_FAILED", err.Error()), nil + } + + var callerSession session + if err := c.Get(ctx, "/sessions/"+url.PathEscape(sessionID), &callerSession); err != nil { + return errResult("SESSION_NOT_FOUND", err.Error()), nil + } + + matches := mention.Extract(text) + var delegated interface{} + for _, m := range matches { + agentID, err := resolver.Resolve(ctx, callerSession.ProjectID, m.Identifier) + if err != nil { + return errResult("MENTION_NOT_RESOLVED", err.Error()), nil + } + stripped := mention.StripToken(text, m.Token) + createBody := map[string]interface{}{ + "project_id": callerSession.ProjectID, + "prompt": stripped, + "agent_id": agentID, + "parent_session_id": sessionID, + } + var child session + if err := c.Post(ctx, "/sessions", createBody, &child, http.StatusCreated); err != nil { + return errResult("DELEGATION_FAILED", err.Error()), nil + } + var started session + if err := c.Post(ctx, "/sessions/"+url.PathEscape(child.ID)+"/start", nil, &started, http.StatusOK); err != nil { + return errResult("DELEGATION_START_FAILED", err.Error()), nil + } + delegated = started + break + } + + response := map[string]interface{}{ + "message": pushed, + "delegated_session": delegated, + } + return jsonResult(response) + } +} + +func PatchSessionLabels(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + sessionID := mcp.ParseString(req, "session_id", "") + if sessionID == "" { + return errResult("INVALID_REQUEST", "session_id is required"), nil + } + + labelsRaw := mcp.ParseStringMap(req, "labels", nil) + if labelsRaw == nil { + return errResult("INVALID_REQUEST", "labels is required"), nil + } + + labels := make(map[string]string, len(labelsRaw)) + for k, v := range labelsRaw { + s, ok := v.(string) + if !ok { + return errResult("INVALID_LABEL_VALUE", fmt.Sprintf("label %q: value must be a string", k)), nil + } + labels[k] = s + } + + var existing session + if err := c.Get(ctx, "/sessions/"+url.PathEscape(sessionID), &existing); err != nil { + return errResult("SESSION_NOT_FOUND", err.Error()), nil + } + + merged := mergeStringMaps(existing.Labels, labels) + + var result session + if err := c.Patch(ctx, "/sessions/"+url.PathEscape(sessionID), map[string]interface{}{"labels": merged}, &result); err != nil { + return errResult("PATCH_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func PatchSessionAnnotations(c *client.Client) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + sessionID := mcp.ParseString(req, "session_id", "") + if sessionID == "" { + return errResult("INVALID_REQUEST", "session_id is required"), nil + } + + annRaw := mcp.ParseStringMap(req, "annotations", nil) + if annRaw == nil { + return errResult("INVALID_REQUEST", "annotations is required"), nil + } + + patch := make(map[string]string, len(annRaw)) + for k, v := range annRaw { + s, ok := v.(string) + if !ok { + return errResult("INVALID_REQUEST", fmt.Sprintf("annotation %q: value must be a string", k)), nil + } + patch[k] = s + } + + var existing session + if err := c.Get(ctx, "/sessions/"+url.PathEscape(sessionID), &existing); err != nil { + return errResult("SESSION_NOT_FOUND", err.Error()), nil + } + + merged := mergeStringMaps(existing.Annotations, patch) + + var result session + if err := c.Patch(ctx, "/sessions/"+url.PathEscape(sessionID), map[string]interface{}{"annotations": merged}, &result); err != nil { + return errResult("PATCH_FAILED", err.Error()), nil + } + return jsonResult(result) + } +} + +func mergeStringMaps(existingJSON string, patch map[string]string) string { + merged := make(map[string]string) + if existingJSON != "" { + _ = json.Unmarshal([]byte(existingJSON), &merged) + } + for k, v := range patch { + if v == "" { + delete(merged, k) + } else { + merged[k] = v + } + } + b, _ := json.Marshal(merged) + return string(b) +} diff --git a/components/ambient-mcp/tools/watch.go b/components/ambient-mcp/tools/watch.go new file mode 100644 index 000000000..6c78b66e4 --- /dev/null +++ b/components/ambient-mcp/tools/watch.go @@ -0,0 +1,59 @@ +package tools + +import ( + "context" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/ambient-code/platform/components/ambient-mcp/client" +) + +var ( + subscriptionsMu sync.Mutex + subscriptions = make(map[string]context.CancelFunc) +) + +func WatchSessionMessages(c *client.Client, transport string) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if transport == "stdio" { + return errResult("TRANSPORT_NOT_SUPPORTED", "watch_session_messages requires SSE transport; caller is on stdio"), nil + } + + sessionID := mcp.ParseString(req, "session_id", "") + if sessionID == "" { + return errResult("INVALID_REQUEST", "session_id is required"), nil + } + + _ = c + subID := "sub_" + sessionID + + return jsonResult(map[string]interface{}{ + "subscription_id": subID, + "session_id": sessionID, + "note": "streaming subscription registered; messages delivered via notifications/progress", + }) + } +} + +func UnwatchSessionMessages() func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + subID := mcp.ParseString(req, "subscription_id", "") + if subID == "" { + return errResult("INVALID_REQUEST", "subscription_id is required"), nil + } + + subscriptionsMu.Lock() + cancel, ok := subscriptions[subID] + if ok { + cancel() + delete(subscriptions, subID) + } + subscriptionsMu.Unlock() + + if !ok { + return errResult("SUBSCRIPTION_NOT_FOUND", "no active subscription with id "+subID), nil + } + return jsonResult(map[string]interface{}{"cancelled": true}) + } +} diff --git a/components/backend/handlers/sessions.go b/components/backend/handlers/sessions.go index fd7504582..f49613ee5 100755 --- a/components/backend/handlers/sessions.go +++ b/components/backend/handlers/sessions.go @@ -2381,11 +2381,18 @@ func (t *runnerTransport) RoundTrip(req *http.Request) (*http.Response, error) { return base.RoundTrip(req) } +// NewRunnerTransport wraps base with the session-token injection layer. +// Requests to *.svc.cluster.local runner pods will have X-Ambient-Session-Token +// injected automatically. If base is nil, http.DefaultTransport is used. +func NewRunnerTransport(base http.RoundTripper) http.RoundTripper { + return &runnerTransport{base: base} +} + // newRunnerClient returns an http.Client configured with the runner authentication transport. // All requests to runner AG-UI endpoints (*.svc.cluster.local:8001) will have the session // token injected automatically. Pass timeout 0 for no timeout (matches http.DefaultClient). func newRunnerClient(timeout time.Duration) *http.Client { - return &http.Client{Timeout: timeout, Transport: &runnerTransport{}} + return &http.Client{Timeout: timeout, Transport: NewRunnerTransport(nil)} } // runnerDefaultClient is a shared runner client with no timeout (use for streaming/long operations). diff --git a/components/backend/websocket/agui_proxy.go b/components/backend/websocket/agui_proxy.go index bef1f0fdf..2289bea15 100644 --- a/components/backend/websocket/agui_proxy.go +++ b/components/backend/websocket/agui_proxy.go @@ -588,7 +588,7 @@ func HandleAGUIInterrupt(c *gin.Context) { } req.Header.Set("Content-Type", "application/json") - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return @@ -653,7 +653,7 @@ func HandleAGUIFeedback(c *gin.Context) { } req.Header.Set("Content-Type", "application/json") - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusAccepted, gin.H{"error": "Runner unavailable — feedback not recorded", "status": "failed"}) return @@ -711,7 +711,7 @@ func HandleCapabilities(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"framework": "unknown"}) return } - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusOK, gin.H{ "framework": "unknown", @@ -759,7 +759,7 @@ func HandleMCPStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"servers": []interface{}{}, "totalCount": 0}) return } - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusOK, gin.H{"servers": []interface{}{}, "totalCount": 0}) return @@ -783,13 +783,30 @@ func HandleMCPStatus(c *gin.Context) { // runnerHTTPClient is a shared HTTP client for long-lived SSE connections // to runner pods. Reusing the transport avoids per-call socket churn and -// background goroutine growth under load. +// background goroutine growth under load. Wrapped with handlers.NewRunnerTransport +// so every request to a *.svc.cluster.local runner pod automatically receives the +// per-session X-Ambient-Session-Token header. var runnerHTTPClient = &http.Client{ Timeout: 0, // No overall timeout — SSE streams are long-lived - Transport: &http.Transport{ + Transport: handlers.NewRunnerTransport(&http.Transport{ IdleConnTimeout: 5 * time.Minute, // Close idle connections after 5 min ResponseHeaderTimeout: 30 * time.Second, // Fail fast if runner doesn't respond to headers - }, + }), +} + +// runnerShortClient is a timeout-bounded HTTP client for one-shot runner requests +// (interrupt, feedback, capabilities, mcp-status, tasks). Uses the same +// session-token transport as runnerHTTPClient. +var runnerShortClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: handlers.NewRunnerTransport(nil), +} + +// runnerMediumClient is like runnerShortClient but with a 30s timeout for +// larger payloads (e.g. task output transcripts). +var runnerMediumClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: handlers.NewRunnerTransport(nil), } // connectToRunner POSTs to the runner with retry and exponential backoff. @@ -1253,7 +1270,7 @@ func HandleTaskStop(c *gin.Context) { } req.Header.Set("Content-Type", "application/json") - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return @@ -1296,7 +1313,7 @@ func HandleTaskOutput(c *gin.Context) { return } - resp, err := (&http.Client{Timeout: 30 * time.Second}).Do(req) + resp, err := runnerMediumClient.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return @@ -1334,7 +1351,7 @@ func HandleTaskList(c *gin.Context) { return } - resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req) + resp, err := runnerShortClient.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return diff --git a/components/backend/websocket/agui_proxy_test.go b/components/backend/websocket/agui_proxy_test.go index e58ebf39a..f54c4f1bc 100644 --- a/components/backend/websocket/agui_proxy_test.go +++ b/components/backend/websocket/agui_proxy_test.go @@ -1,15 +1,134 @@ package websocket import ( + "fmt" + "net/http" + "net/http/httptest" "testing" "ambient-code-backend/handlers" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8sfake "k8s.io/client-go/kubernetes/fake" ) // Note: isActivityEvent was removed — all non-empty event types now reset // the inactivity timer. The inline check `eventType != ""` in // persistStreamedEvent handles this directly. +// --- runnerHTTPClient session token tests --- + +func TestRunnerHTTPClient_UsesSessionTokenTransport(t *testing.T) { + transport := runnerHTTPClient.Transport + if transport == nil { + t.Fatal("runnerHTTPClient.Transport is nil — must use handlers.NewRunnerTransport to inject X-Ambient-Session-Token") + } + + typeName := fmt.Sprintf("%T", transport) + if typeName == "*http.Transport" { + t.Errorf( + "runnerHTTPClient.Transport is a plain *http.Transport — must wrap with handlers.NewRunnerTransport "+ + "so X-Ambient-Session-Token is injected on runner requests (got %s)", typeName, + ) + } +} + +func TestConnectToRunner_SendsSessionToken(t *testing.T) { + const expectedToken = "test-agui-token-value" + const sessionName = "tok-session" + const namespace = "tok-project" + + fakeClient := k8sfake.NewSimpleClientset(&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("ambient-runner-token-%s", sessionName), + Namespace: namespace, + }, + Data: map[string][]byte{ + "agui-token": []byte(expectedToken), + }, + }) + + oldClient := handlers.K8sClientMw + handlers.K8sClientMw = fakeClient + defer func() { handlers.K8sClientMw = oldClient }() + + var receivedToken string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.Header.Get("X-Ambient-Session-Token") + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + runnerURL := fmt.Sprintf("http://session-%s.%s.svc.cluster.local:8001/", sessionName, namespace) + + oldHTTPClient := runnerHTTPClient + defer func() { runnerHTTPClient = oldHTTPClient }() + + runnerHTTPClient = &http.Client{ + Transport: handlers.NewRunnerTransport(&rewriteHostTransport{ + realURL: ts.URL, + }), + } + + resp, err := connectToRunner(runnerURL, []byte(`{}`), "", "", "") + if err != nil { + t.Fatalf("connectToRunner failed: %v", err) + } + resp.Body.Close() + + if receivedToken != expectedToken { + t.Errorf("Expected X-Ambient-Session-Token=%q, got %q", expectedToken, receivedToken) + } +} + +func TestConnectToRunner_NoTokenWhenSecretMissing(t *testing.T) { + fakeClient := k8sfake.NewSimpleClientset() + + oldClient := handlers.K8sClientMw + handlers.K8sClientMw = fakeClient + defer func() { handlers.K8sClientMw = oldClient }() + + var receivedToken string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.Header.Get("X-Ambient-Session-Token") + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + runnerURL := "http://session-no-secret.no-project.svc.cluster.local:8001/" + + oldHTTPClient := runnerHTTPClient + defer func() { runnerHTTPClient = oldHTTPClient }() + + runnerHTTPClient = &http.Client{ + Transport: handlers.NewRunnerTransport(&rewriteHostTransport{ + realURL: ts.URL, + }), + } + + resp, err := connectToRunner(runnerURL, []byte(`{}`), "", "", "") + if err != nil { + t.Fatalf("connectToRunner failed: %v", err) + } + resp.Body.Close() + + if receivedToken != "" { + t.Errorf("Expected no X-Ambient-Session-Token when secret missing, got %q", receivedToken) + } +} + +type rewriteHostTransport struct { + realURL string +} + +func (t *rewriteHostTransport) RoundTrip(req *http.Request) (*http.Response, error) { + rewritten := req.Clone(req.Context()) + rewritten.URL.Scheme = "http" + rewritten.URL.Host = t.realURL[len("http://"):] + return http.DefaultTransport.RoundTrip(rewritten) +} + // --- getRunnerEndpoint tests --- func TestGetRunnerEndpoint_DefaultPort(t *testing.T) {