diff --git a/components/manifests/base/ambient-control-plane-service.yml b/components/manifests/base/ambient-control-plane-service.yml new file mode 100644 index 000000000..6ed2783c7 --- /dev/null +++ b/components/manifests/base/ambient-control-plane-service.yml @@ -0,0 +1,55 @@ +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-control-plane + labels: + app: ambient-control-plane +spec: + replicas: 1 + selector: + matchLabels: + app: ambient-control-plane + template: + metadata: + labels: + app: ambient-control-plane + spec: + serviceAccountName: ambient-control-plane + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + containers: + - name: ambient-control-plane + image: quay.io/ambient_code/vteam_control_plane:latest + imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] + env: + - name: AMBIENT_API_TOKEN + valueFrom: + secretKeyRef: + name: ambient-control-plane-token + key: token + - name: AMBIENT_API_SERVER_URL + value: "https://ambient-api-server.ambient-code.svc:8000" + - name: AMBIENT_GRPC_SERVER_ADDR + value: "ambient-api-server.ambient-code.svc:9000" + - name: AMBIENT_GRPC_USE_TLS + value: "true" + - name: MODE + value: "kube" + - name: LOG_LEVEL + value: "info" + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 200m + memory: 256Mi + restartPolicy: Always diff --git a/components/manifests/base/kustomization.yaml b/components/manifests/base/kustomization.yaml index f9f8a242d..77f667b8d 100644 --- a/components/manifests/base/kustomization.yaml +++ b/components/manifests/base/kustomization.yaml @@ -8,6 +8,7 @@ resources: - core - rbac - platform +- ambient-control-plane-service.yml # Default images (can be overridden by overlays) images: @@ -25,3 +26,7 @@ images: newTag: latest - name: quay.io/ambient_code/vteam_api_server newTag: latest +- name: quay.io/ambient_code/vteam_control_plane + newTag: latest +- name: quay.io/ambient_code/vteam_mcp + newTag: latest diff --git a/components/manifests/base/rbac/control-plane-clusterrole.yaml b/components/manifests/base/rbac/control-plane-clusterrole.yaml new file mode 100644 index 000000000..c2cec298e --- /dev/null +++ b/components/manifests/base/rbac/control-plane-clusterrole.yaml @@ -0,0 +1,27 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: ambient-control-plane +rules: +# AgenticSession custom resources (full lifecycle management) +- apiGroups: ["vteam.ambient-code"] + resources: ["agenticsessions"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] +- apiGroups: ["vteam.ambient-code"] + resources: ["agenticsessions/status"] + verbs: ["update", "patch"] +# Namespaces (create and label per-project namespaces) +- apiGroups: [""] + resources: ["namespaces"] + verbs: ["get", "list", "watch", "create", "update", "patch"] +# RoleBindings (reconcile group access from ProjectSettings) +- apiGroups: ["rbac.authorization.k8s.io"] + resources: ["rolebindings"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] +# Session runner resources (provision/deprovision per-session workloads in project namespaces) +- apiGroups: [""] + resources: ["secrets", "serviceaccounts", "services", "pods"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete", "deletecollection"] +- apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] diff --git a/components/manifests/overlays/cluster-reader/cluster-role-binding.yaml b/components/manifests/base/rbac/control-plane-clusterrolebinding.yaml similarity index 52% rename from components/manifests/overlays/cluster-reader/cluster-role-binding.yaml rename to components/manifests/base/rbac/control-plane-clusterrolebinding.yaml index e62d0bdbc..c327e2887 100644 --- a/components/manifests/overlays/cluster-reader/cluster-role-binding.yaml +++ b/components/manifests/base/rbac/control-plane-clusterrolebinding.yaml @@ -1,12 +1,12 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: - name: readonly-admin-cluster-reader + name: ambient-control-plane roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole - name: cluster-reader + name: ambient-control-plane subjects: - - kind: ServiceAccount - name: readonly-admin - namespace: ambient-code +- kind: ServiceAccount + name: ambient-control-plane + namespace: ambient-code diff --git a/components/manifests/base/rbac/control-plane-sa.yaml b/components/manifests/base/rbac/control-plane-sa.yaml new file mode 100644 index 000000000..6f2368730 --- /dev/null +++ b/components/manifests/base/rbac/control-plane-sa.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: ambient-control-plane + namespace: ambient-code +--- +apiVersion: v1 +kind: Secret +metadata: + name: ambient-control-plane-token + namespace: ambient-code + annotations: + kubernetes.io/service-account.name: ambient-control-plane +type: kubernetes.io/service-account-token diff --git a/components/manifests/base/rbac/kustomization.yaml b/components/manifests/base/rbac/kustomization.yaml index 7f5a9572d..72b1c2b28 100644 --- a/components/manifests/base/rbac/kustomization.yaml +++ b/components/manifests/base/rbac/kustomization.yaml @@ -14,3 +14,6 @@ resources: - frontend-rbac.yaml - aggregate-agenticsessions-admin.yaml - aggregate-projectsettings-admin.yaml +- control-plane-sa.yaml +- control-plane-clusterrole.yaml +- control-plane-clusterrolebinding.yaml diff --git a/components/manifests/overlays/cluster-reader/README.md b/components/manifests/overlays/cluster-reader/README.md deleted file mode 100644 index 6ce54b03f..000000000 --- a/components/manifests/overlays/cluster-reader/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# Cluster Reader Service Account - -Read-only service account using OpenShift's built-in `cluster-reader` ClusterRole. - -## Use cases - -- CI pipelines that observe cluster state (pod status, deployment health, events) -- Development tooling and dashboards -- Any automation that needs cluster-wide visibility without write access - -## Permissions - -- **Can read**: pods, deployments, nodes, namespaces, configmaps, events, and most other resources across all namespaces -- **Cannot read**: secrets -- **Cannot write**: anything (create, update, delete, patch all denied) - -## Usage - -```bash -# Apply (OpenShift only — cluster-reader does not exist on vanilla K8s/kind) -oc apply -k components/manifests/overlays/cluster-reader/ - -# Override namespace -cd components/manifests/overlays/cluster-reader -NS=my-namespace -kustomize edit set namespace "${NS}" -oc apply -k . - -# Get a token (max 1 year) -oc create token readonly-admin -n "${NS}" --duration=8760h - -# Verify read-only access -oc auth can-i get pods --all-namespaces --as=system:serviceaccount:${NS}:readonly-admin -oc auth can-i delete pods -n "${NS}" --as=system:serviceaccount:${NS}:readonly-admin -``` diff --git a/components/manifests/overlays/cluster-reader/kustomization.yaml b/components/manifests/overlays/cluster-reader/kustomization.yaml deleted file mode 100644 index 205474e2b..000000000 --- a/components/manifests/overlays/cluster-reader/kustomization.yaml +++ /dev/null @@ -1,8 +0,0 @@ -apiVersion: kustomize.config.k8s.io/v1beta1 -kind: Kustomization - -namespace: ambient-code - -resources: - - service-account.yaml - - cluster-role-binding.yaml diff --git a/components/manifests/overlays/cluster-reader/service-account.yaml b/components/manifests/overlays/cluster-reader/service-account.yaml deleted file mode 100644 index e3989cc30..000000000 --- a/components/manifests/overlays/cluster-reader/service-account.yaml +++ /dev/null @@ -1,4 +0,0 @@ -apiVersion: v1 -kind: ServiceAccount -metadata: - name: readonly-admin diff --git a/components/manifests/overlays/kind-local/operator-env-patch.yaml b/components/manifests/overlays/kind-local/operator-env-patch.yaml index 15b203005..0b7a06811 100644 --- a/components/manifests/overlays/kind-local/operator-env-patch.yaml +++ b/components/manifests/overlays/kind-local/operator-env-patch.yaml @@ -19,3 +19,5 @@ spec: value: "IfNotPresent" - name: POD_FSGROUP value: "0" + - name: RUNNER_LOG_LEVEL + value: "debug" diff --git a/components/manifests/overlays/local-dev/ambient-api-server-route.yaml b/components/manifests/overlays/local-dev/ambient-api-server-route.yaml index 1530d558f..1b3c195a9 100644 --- a/components/manifests/overlays/local-dev/ambient-api-server-route.yaml +++ b/components/manifests/overlays/local-dev/ambient-api-server-route.yaml @@ -6,6 +6,8 @@ metadata: labels: app: ambient-api-server component: api + annotations: + haproxy.router.openshift.io/timeout: 10m spec: to: kind: Service @@ -23,6 +25,8 @@ metadata: labels: app: ambient-api-server component: grpc + annotations: + haproxy.router.openshift.io/timeout: 10m spec: to: kind: Service diff --git a/components/manifests/overlays/mpp-openshift/README.md b/components/manifests/overlays/mpp-openshift/README.md new file mode 100644 index 000000000..08eb914bc --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/README.md @@ -0,0 +1,101 @@ +# MPP OpenShift Overlay + +Kustomize overlay for the Managed Platform Plus (MPP) OpenShift environment: `ambient-code--runtime-int`. + +## Apply + +```bash +kubectl apply -k components/manifests/overlays/mpp-openshift/ +``` + +## What This Overlay Does + +- Targets namespace `ambient-code--runtime-int` +- Sets `PLATFORM_MODE=mpp` so the CP uses `MPPNamespaceProvisioner` (namespaces as `ambient-code--`) +- Configures OIDC client credentials auth (no static K8s SA token) +- Adds `--grpc-jwk-cert-url` so the api-server validates RH SSO tokens on gRPC +- Mounts `tenantaccess-ambient-control-plane-token` for the CP's project kube client +- Mounts `ambient-runner-api-token` for runner pods to authenticate as service callers on gRPC +- Adds `allow-ambient-tenant-ingress` NetworkPolicy (ports 8000/9000 from all `ambient-code` tenant namespaces) + +## ⚠️ One-Time Manual Bootstrap + +Two secrets must be created manually once per cluster. They are **not** managed by kustomize (to avoid committing secret values) and are **not** required per session — only per cluster. + +### Step A — TenantServiceAccount + +Grants the CP's service account `namespace-admin` in every current and future tenant namespace via the tenant-access-operator. + +```bash +# Apply the TenantServiceAccount CR to ambient-code--config (NOT via kustomize) +kubectl apply -f components/manifests/overlays/mpp-openshift/ambient-cp-tenant-sa.yaml +``` + +Wait ~30s for the operator to create `tenantaccess-ambient-control-plane-token` in `ambient-code--config`, then copy it to the runtime namespace: + +```bash +kubectl get secret tenantaccess-ambient-control-plane-token \ + -n ambient-code--config \ + -o json \ + | python3 -c " +import json, sys +s = json.load(sys.stdin) +del s['metadata']['namespace'] +del s['metadata']['resourceVersion'] +del s['metadata']['uid'] +del s['metadata']['creationTimestamp'] +s['metadata'].pop('ownerReferences', None) +s['metadata'].pop('annotations', None) +s['type'] = 'Opaque' +print(json.dumps(s)) +" | kubectl apply -n ambient-code--runtime-int -f - +``` + +**Effect:** The operator automatically injects a `namespace-admin` RoleBinding into every `ambient-code--*` namespace, including ones created after this step. The CP mounts this token as its `projectKube` client for all namespace-scoped operations. + +### Step B — Static Runner API Token + +The runner uses a static token to authenticate as a gRPC service caller, bypassing the per-user session ownership check on `WatchSessionMessages`. + +```bash +# Generate a random token — record this value; you will need it for Step C +STATIC_TOKEN=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))") + +kubectl create secret generic ambient-runner-api-token \ + --from-literal=token=${STATIC_TOKEN} \ + -n ambient-code--runtime-int +``` + +**Do not commit the token value.** + +### Step C — Set AMBIENT_API_TOKEN on the api-server + +The api-server must know the static token so it can recognise the runner as a service caller: + +```bash +# Patch the api-server args to include the token file +# (or set AMBIENT_API_TOKEN directly if your deployment supports it) +# The token value must match what was set in Step B +``` + +> **Note:** Step C is currently pending implementation — see the open gap `WatchSessionMessages PERMISSION_DENIED` in `docs/internal/design/control-plane.guide.md`. + +## Files in This Overlay + +| File | Purpose | +|------|---------| +| `kustomization.yaml` | Root kustomize config; sets namespace, images, patches | +| `ambient-control-plane.yaml` | CP Deployment — OIDC env, `PROJECT_KUBE_TOKEN_FILE`, project-kube volume mount | +| `ambient-api-server.yaml` | api-server Deployment base | +| `ambient-api-server-args-patch.yaml` | api-server command args — db, grpc, OIDC JWKS URL | +| `ambient-api-server-service-ca-patch.yaml` | Service CA annotation for TLS | +| `ambient-api-server-db.yaml` | PostgreSQL Deployment + Service | +| `ambient-api-server-route.yaml` | OpenShift Route for external access | +| `ambient-control-plane-sa.yaml` | ServiceAccount for the CP | +| `ambient-control-plane-rbac.yaml` | RBAC for the CP SA | +| `ambient-tenant-ingress-netpol.yaml` | NetworkPolicy allowing runner→api-server traffic | +| `ambient-cp-tenant-sa.yaml` | TenantServiceAccount CR (applied manually — see Step A) | + +## Re-Bootstrap Required? + +Only if `ambient-code--runtime-int` is destroyed, which MPP should never do to runtime/config namespaces. Session namespaces (`ambient-code--`) are created and destroyed per session with no manual action required. diff --git a/components/manifests/overlays/mpp-openshift/ambient-api-server-args-patch.yaml b/components/manifests/overlays/mpp-openshift/ambient-api-server-args-patch.yaml new file mode 100644 index 000000000..f3ba63625 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-api-server-args-patch.yaml @@ -0,0 +1,46 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-api-server +spec: + template: + spec: + containers: + - name: api-server + command: + - /usr/local/bin/ambient-api-server + - serve + - --db-host-file=/secrets/db/db.host + - --db-port-file=/secrets/db/db.port + - --db-user-file=/secrets/db/db.user + - --db-password-file=/secrets/db/db.password + - --db-name-file=/secrets/db/db.name + - --enable-authz=false + - --enable-https=false + - --api-server-bindaddress=:8000 + - --metrics-server-bindaddress=:4433 + - --health-check-server-bindaddress=:4434 + - --db-sslmode=disable + - --db-max-open-connections=50 + - --enable-db-debug=false + - --enable-metrics-https=false + - --http-read-timeout=5s + - --http-write-timeout=30s + - --cors-allowed-origins=* + - --cors-allowed-headers=X-Ambient-Project + - --jwk-cert-file=/configs/authentication/jwks.json + - --enable-grpc=true + - --grpc-server-bindaddress=:9000 + - --grpc-enable-tls=true + - --grpc-tls-cert-file=/etc/tls/tls.crt + - --grpc-tls-key-file=/etc/tls/tls.key + - --alsologtostderr + - -v=4 + volumeMounts: + - name: tls-certs + mountPath: /etc/tls + readOnly: true + volumes: + - name: tls-certs + secret: + secretName: ambient-api-server-tls diff --git a/components/manifests/overlays/mpp-openshift/ambient-api-server-db.yaml b/components/manifests/overlays/mpp-openshift/ambient-api-server-db.yaml new file mode 100644 index 000000000..93ad1cedd --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-api-server-db.yaml @@ -0,0 +1,97 @@ +apiVersion: v1 +kind: Service +metadata: + name: ambient-api-server-db + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: database +spec: + ports: + - name: postgresql + port: 5432 + protocol: TCP + targetPort: 5432 + selector: + app: ambient-api-server + component: database + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-api-server-db + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: database +spec: + replicas: 1 + selector: + matchLabels: + app: ambient-api-server + component: database + strategy: + type: Recreate + template: + metadata: + labels: + app: ambient-api-server + component: database + spec: + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + containers: + - name: postgresql + image: registry.redhat.io/rhel9/postgresql-16:latest + ports: + - containerPort: 5432 + name: postgresql + env: + - name: POSTGRESQL_USER + valueFrom: + secretKeyRef: + key: db.user + name: ambient-api-server-db + - name: POSTGRESQL_PASSWORD + valueFrom: + secretKeyRef: + key: db.password + name: ambient-api-server-db + - name: POSTGRESQL_DATABASE + valueFrom: + secretKeyRef: + key: db.name + name: ambient-api-server-db + volumeMounts: + - name: ambient-api-server-db-data + mountPath: /var/lib/pgsql/data + subPath: pgdata + readinessProbe: + exec: + command: + - /bin/sh + - -c + - pg_isready -U "$POSTGRESQL_USER" + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + livenessProbe: + exec: + command: + - /bin/sh + - -c + - pg_isready -U "$POSTGRESQL_USER" + initialDelaySeconds: 30 + periodSeconds: 30 + timeoutSeconds: 5 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL + volumes: + - name: ambient-api-server-db-data + emptyDir: {} diff --git a/components/manifests/overlays/mpp-openshift/ambient-api-server-route.yaml b/components/manifests/overlays/mpp-openshift/ambient-api-server-route.yaml new file mode 100644 index 000000000..133ed0d55 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-api-server-route.yaml @@ -0,0 +1,43 @@ +apiVersion: route.openshift.io/v1 +kind: Route +metadata: + name: ambient-api-server + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: api + shard: internal + annotations: + haproxy.router.openshift.io/timeout: 10m + haproxy.router.openshift.io/timeout-tunnel: 10m +spec: + to: + kind: Service + name: ambient-api-server + port: + targetPort: api + tls: + termination: edge + insecureEdgeTerminationPolicy: Redirect +--- +apiVersion: route.openshift.io/v1 +kind: Route +metadata: + name: ambient-api-server-grpc + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: grpc + shard: internal + annotations: + haproxy.router.openshift.io/timeout: 10m + haproxy.router.openshift.io/timeout-tunnel: 10m +spec: + to: + kind: Service + name: ambient-api-server + port: + targetPort: grpc + tls: + termination: reencrypt + insecureEdgeTerminationPolicy: Redirect diff --git a/components/manifests/overlays/mpp-openshift/ambient-api-server-service-ca-patch.yaml b/components/manifests/overlays/mpp-openshift/ambient-api-server-service-ca-patch.yaml new file mode 100644 index 000000000..2ef884562 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-api-server-service-ca-patch.yaml @@ -0,0 +1,6 @@ +apiVersion: v1 +kind: Service +metadata: + name: ambient-api-server + annotations: + service.beta.openshift.io/serving-cert-secret-name: ambient-api-server-tls diff --git a/components/manifests/overlays/mpp-openshift/ambient-api-server.yaml b/components/manifests/overlays/mpp-openshift/ambient-api-server.yaml new file mode 100644 index 000000000..780fa56ca --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-api-server.yaml @@ -0,0 +1,182 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: ambient-api-server-auth + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: auth +data: + jwks.json: | + {"keys":[{"use":"sig","kty":"RSA","kid":"GWqMSQVJjDoKaU8TnH_LmZeII2wGWYez6x_Oa63hXMM","alg":"RS256","n":"5NBKTJAC7kLcQBWWT0eBuaAI-4lqO2hl3S2Oc37kwXqHowA-2XSGT5g5oW-Y3jtg5m9XUdnTdEoyIEfbcs9mkmDm-IT3fvCWgiDWvopTd9C5WxhcM0XcjqdVSshFzDK2V1ZLmic2pLZS743hfGb1FDezF9A-KNycE41_2IwisPDNJbjsxH6oabOkva4QtA_K9ivREq6gBQtZSIr_hoQLcafL6paVAuPW1wVreBENBqiYkM69iSq3pU6Svqb51WhMADCIcxUsEINTW-0hg91WOYdSJ0r1UpEc6nGxb56Jlw-5h_nFInNUorXeTezgSXcpaHz1EpQQe4vo68EWhf3I6w","e":"AQAB"},{"use":"sig","kty":"RSA","kid":"1milNqdanuBP4v4UolwNIJwbHgxj1BrgmGLdBDWpQDc","alg":"RS256","n":"lvJPPx7OqsIDUnQQtOHUw26qqvL-XjhgSxYWvONhPgIqc5f-dvkBqH9mo_5WkUZcEcvC12FuUvJlYs1mHB4Zy7FwHY00HgD2v3Qa7AuhnnX6EIhGsqL1bxEae5OeRKe5mcEpBBIaXsbbWhrxTxksZqOeYGwJfI9FK8TFFD8C9LJTAAT_CpvU9ieKvYj0rvvvELEk8-DzsjnHabd7extSRUwqtb7xMx4DcMwRi1Axt_dp7g3EyOV1aUZXeNjncE5ot1m3r0t6LtnDk9Sb94EN1YfaVtE5LzK7zD46e05nQIUguURNC8xMUzIFkkoKNv7-wEDw5AhmnbWw9960ObUcAw","e":"AQAB"},{"use":"sig","kty":"RSA","kid":"jtx9LVV86TSy7P5AsXEGe6yAWUCIdnAVEsK1S3PRE90","alg":"RS256","n":"32_0wd-rJldZn63xz7rHHrgjo-Y7A6GYN-hlBGF5EPlheR18A_jQmjHHxSzFKWx1Kgm0hV8nGNCjvXsuQ2hzDDLHYnXe1w7S9JhEQTxIV87FWod9OuGefddfCXUarI14_AvtjgrQG_0BTCpSG0IS5rojvxjvr5NeJuPu9msIbMl5xeYST63r1U6F46KGYcdAMYw21z59rT-s4d0c7FJIIu2llrlPj1m4N8FUEmf9GBCjXA_ys7ZmYLkue35WtzSSRYXZZy3czYtffsW1yeRVlWthIZ182qEzt6T00gZPlHjKNPrgPNQ9b5hA_ZC3SEWE2KU-Y_4QH4aTSsbAoRTtJbVdfb7k5Osvq2Vuu6TjDElZuZXAYu3gu5EtXp-xBWIX-Lvs_wW_5qL2h7zcv127vl4NocUz0kSl3m-t53u1JMrcxBsucQRn1CEzsph9oUABVBEP8ugviA8BbRIFfvx9cX-mSk6DYxn-deX4IOrLJqoekvoIIL0Z9wxVnp681xgLZVXG2JvOIc46ZXORGqol4m69OPbmxdrXdMNY8Hbnf4IycS99axN0rG3ZmnVLBR17b2Rl7cIS-E-1vQ8XKcH89SX8Mj9kwnmr4P6biK3T6Iyhv9CY2sZFpy6XrXGrL9eGRR_lRildgq6wCjcGAAYdTzUHgKAC3f3KT1_aTEBw9Ks","e":"AQAB"}]} + acl.yml: | + - claim: email + pattern: ^.*@(redhat\.com|ambient\.code)$ +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: ambient-api-server + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-api-server + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: api +spec: + replicas: 1 + selector: + matchLabels: + app: ambient-api-server + component: api + strategy: + rollingUpdate: + maxSurge: 25% + maxUnavailable: 25% + type: RollingUpdate + template: + metadata: + labels: + app: ambient-api-server + component: api + spec: + serviceAccountName: ambient-api-server + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + initContainers: + - name: migration + image: quay.io/ambient_code/vteam_api_server:latest + imagePullPolicy: Always + command: + - /usr/local/bin/ambient-api-server + - migrate + - --db-host-file=/secrets/db/db.host + - --db-port-file=/secrets/db/db.port + - --db-user-file=/secrets/db/db.user + - --db-password-file=/secrets/db/db.password + - --db-name-file=/secrets/db/db.name + - --alsologtostderr + - -v=4 + volumeMounts: + - name: db-secrets + mountPath: /secrets/db + resources: + requests: + cpu: 50m + memory: 128Mi + limits: + cpu: 500m + memory: 512Mi + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: false + capabilities: + drop: + - ALL + containers: + - name: api-server + image: quay.io/ambient_code/vteam_api_server:latest + imagePullPolicy: Always + env: + - name: AMBIENT_ENV + value: production + - name: GRPC_SERVICE_ACCOUNT + value: "service-account-ocm-ams-service" + ports: + - name: api + containerPort: 8000 + protocol: TCP + - name: metrics + containerPort: 4433 + protocol: TCP + - name: health + containerPort: 4434 + protocol: TCP + - name: grpc + containerPort: 9000 + protocol: TCP + volumeMounts: + - name: db-secrets + mountPath: /secrets/db + - name: app-secrets + mountPath: /secrets/service + - name: auth-config + mountPath: /configs/authentication + resources: + requests: + cpu: 200m + memory: 512Mi + limits: + cpu: 1 + memory: 1Gi + livenessProbe: + httpGet: + path: /api/ambient + port: 8000 + scheme: HTTP + initialDelaySeconds: 15 + periodSeconds: 5 + readinessProbe: + httpGet: + path: /healthcheck + port: 4434 + scheme: HTTP + httpHeaders: + - name: User-Agent + value: Probe + initialDelaySeconds: 20 + periodSeconds: 10 + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: false + capabilities: + drop: + - ALL + volumes: + - name: db-secrets + secret: + secretName: ambient-api-server-db + - name: app-secrets + secret: + secretName: ambient-api-server + - name: auth-config + configMap: + name: ambient-api-server-auth +--- +apiVersion: v1 +kind: Service +metadata: + name: ambient-api-server + namespace: ambient-code--runtime-int + labels: + app: ambient-api-server + component: api +spec: + selector: + app: ambient-api-server + component: api + ports: + - name: api + port: 8000 + targetPort: 8000 + protocol: TCP + - name: grpc + port: 9000 + targetPort: 9000 + protocol: TCP + - name: metrics + port: 4433 + targetPort: 4433 + protocol: TCP + - name: health + port: 4434 + targetPort: 4434 + protocol: TCP diff --git a/components/manifests/overlays/mpp-openshift/ambient-control-plane-sa.yaml b/components/manifests/overlays/mpp-openshift/ambient-control-plane-sa.yaml new file mode 100644 index 000000000..8a8946c8a --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-control-plane-sa.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: ambient-control-plane + namespace: ambient-code--runtime-int + labels: + app: ambient-control-plane +--- +apiVersion: v1 +kind: Secret +metadata: + name: ambient-control-plane-token + namespace: ambient-code--runtime-int + annotations: + kubernetes.io/service-account.name: ambient-control-plane +type: kubernetes.io/service-account-token diff --git a/components/manifests/overlays/mpp-openshift/ambient-control-plane-svc.yaml b/components/manifests/overlays/mpp-openshift/ambient-control-plane-svc.yaml new file mode 100644 index 000000000..f4beba4a2 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-control-plane-svc.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: ambient-control-plane + namespace: ambient-code--runtime-int + labels: + app: ambient-control-plane +spec: + selector: + app: ambient-control-plane + ports: + - name: token + port: 8080 + targetPort: 8080 + protocol: TCP diff --git a/components/manifests/overlays/mpp-openshift/ambient-control-plane.yaml b/components/manifests/overlays/mpp-openshift/ambient-control-plane.yaml new file mode 100644 index 000000000..17ebaada4 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-control-plane.yaml @@ -0,0 +1,108 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-control-plane + namespace: ambient-code--runtime-int + labels: + app: ambient-control-plane +spec: + replicas: 1 + selector: + matchLabels: + app: ambient-control-plane + template: + metadata: + labels: + app: ambient-control-plane + spec: + serviceAccountName: ambient-control-plane + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + containers: + - name: ambient-control-plane + image: quay.io/ambient_code/vteam_control_plane:latest + imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: + - ALL + env: + - name: AMBIENT_API_TOKEN + valueFrom: + secretKeyRef: + name: ambient-control-plane-token + key: token + - name: AMBIENT_API_SERVER_URL + value: "http://ambient-api-server.ambient-code--runtime-int.svc:8000" + - name: AMBIENT_GRPC_SERVER_ADDR + value: "ambient-api-server.ambient-code--runtime-int.svc:9000" + - name: AMBIENT_GRPC_USE_TLS + value: "false" + - name: MODE + value: "kube" + - name: PLATFORM_MODE + value: "mpp" + - name: MPP_CONFIG_NAMESPACE + value: "ambient-code--config" + - name: LOG_LEVEL + value: "info" + - name: RUNNER_IMAGE + value: "quay.io/ambient_code/vteam_claude_runner:latest" + - name: OIDC_CLIENT_ID + valueFrom: + secretKeyRef: + name: ambient-api-server + key: clientId + - name: OIDC_CLIENT_SECRET + valueFrom: + secretKeyRef: + name: ambient-api-server + key: clientSecret + - name: PROJECT_KUBE_TOKEN_FILE + value: "/var/run/secrets/project-kube/token" + - name: USE_VERTEX + value: "1" + - name: ANTHROPIC_VERTEX_PROJECT_ID + value: "ambient-code-platform" + - name: CLOUD_ML_REGION + value: "us-east5" + - name: GOOGLE_APPLICATION_CREDENTIALS + value: "/app/vertex/ambient-code-key.json" + - name: VERTEX_SECRET_NAME + value: "ambient-vertex" + - name: VERTEX_SECRET_NAMESPACE + value: "ambient-code--runtime-int" + - name: CP_RUNTIME_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + - name: CP_TOKEN_URL + value: "http://ambient-control-plane.ambient-code--ambient-s0.svc:8080/token" + - name: MCP_IMAGE + value: "quay.io/ambient_code/vteam_mcp:latest" + volumeMounts: + - name: project-kube-token + mountPath: /var/run/secrets/project-kube + readOnly: true + - name: vertex-credentials + mountPath: /app/vertex + readOnly: true + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 200m + memory: 256Mi + volumes: + - name: project-kube-token + secret: + secretName: ambient-control-plane-token + - name: vertex-credentials + secret: + secretName: ambient-vertex + restartPolicy: Always diff --git a/components/manifests/overlays/mpp-openshift/ambient-cp-tenant-sa.yaml b/components/manifests/overlays/mpp-openshift/ambient-cp-tenant-sa.yaml new file mode 100644 index 000000000..af4d808f7 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-cp-tenant-sa.yaml @@ -0,0 +1,9 @@ +apiVersion: tenantaccess.paas.redhat.com/v1beta1 +kind: TenantServiceAccount +metadata: + name: ambient-control-plane + namespace: ambient-code--config +spec: + create-permanent-token: true + roles: + - namespace-admin diff --git a/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml b/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml new file mode 100644 index 000000000..aa11c728d --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml @@ -0,0 +1,21 @@ +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-runner-token-fetch + namespace: ambient-code--runtime-int + labels: + app: ambient-control-plane +spec: + podSelector: + matchLabels: + app: ambient-control-plane + ingress: + - from: + - namespaceSelector: + matchLabels: + tenant.paas.redhat.com/tenant: ambient-code + ports: + - protocol: TCP + port: 8080 + policyTypes: + - Ingress diff --git a/components/manifests/overlays/mpp-openshift/ambient-tenant-ingress-netpol.yaml b/components/manifests/overlays/mpp-openshift/ambient-tenant-ingress-netpol.yaml new file mode 100644 index 000000000..0564431dd --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-tenant-ingress-netpol.yaml @@ -0,0 +1,19 @@ +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-ambient-tenant-ingress + namespace: ambient-code--runtime-int +spec: + podSelector: {} + policyTypes: + - Ingress + ingress: + - from: + - namespaceSelector: + matchLabels: + tenant.paas.redhat.com/tenant: ambient-code + ports: + - port: 8000 + protocol: TCP + - port: 9000 + protocol: TCP diff --git a/components/manifests/overlays/mpp-openshift/kustomization.yaml b/components/manifests/overlays/mpp-openshift/kustomization.yaml new file mode 100644 index 000000000..802157e58 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/kustomization.yaml @@ -0,0 +1,43 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +metadata: + name: ambient-mpp-openshift + +resources: +- ambient-api-server-db.yaml +- ambient-api-server.yaml +- ambient-control-plane.yaml +- ambient-control-plane-svc.yaml +- ambient-cp-token-netpol.yaml +- ambient-api-server-route.yaml +- ambient-control-plane-sa.yaml +- tenant-rbac/ +- ambient-tenant-ingress-netpol.yaml + +patches: +- path: ambient-api-server-args-patch.yaml + target: + group: apps + kind: Deployment + name: ambient-api-server + version: v1 +- path: ambient-api-server-service-ca-patch.yaml + target: + kind: Service + name: ambient-api-server + version: v1 + +images: +- name: quay.io/ambient_code/vteam_api_server + newTag: latest +- name: quay.io/ambient_code/vteam_api_server:latest + newName: quay.io/ambient_code/vteam_api_server + newTag: latest +- name: quay.io/ambient_code/vteam_control_plane + newTag: latest +- name: quay.io/ambient_code/vteam_control_plane:latest + newName: quay.io/ambient_code/vteam_control_plane + newTag: latest +- name: quay.io/ambient_code/vteam_mcp + newTag: latest diff --git a/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-runtime-int.yaml b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-runtime-int.yaml new file mode 100644 index 000000000..cc1907829 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-runtime-int.yaml @@ -0,0 +1,13 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: ambient-control-plane-tenant-namespaces-runtime-int + namespace: ambient-code--config +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: ambient-control-plane-tenant-namespaces +subjects: + - kind: ServiceAccount + name: ambient-control-plane + namespace: ambient-code--runtime-int diff --git a/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-s0.yaml b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-s0.yaml new file mode 100644 index 000000000..33b5e84fb --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac-s0.yaml @@ -0,0 +1,13 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: ambient-control-plane-tenant-namespaces-s0 + namespace: ambient-code--config +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: ambient-control-plane-tenant-namespaces +subjects: + - kind: ServiceAccount + name: ambient-control-plane + namespace: ambient-code--ambient-s0 diff --git a/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac.yaml b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac.yaml new file mode 100644 index 000000000..af30202cc --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/tenant-rbac/ambient-control-plane-rbac.yaml @@ -0,0 +1,9 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: ambient-control-plane-tenant-namespaces + namespace: ambient-code--config +rules: + - apiGroups: ["tenant.paas.redhat.com"] + resources: ["tenantnamespaces"] + verbs: ["get", "list", "watch", "create", "delete"] diff --git a/components/manifests/overlays/mpp-openshift/tenant-rbac/kustomization.yaml b/components/manifests/overlays/mpp-openshift/tenant-rbac/kustomization.yaml new file mode 100644 index 000000000..fe14cc3d7 --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/tenant-rbac/kustomization.yaml @@ -0,0 +1,7 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +resources: +- ambient-control-plane-rbac.yaml +- ambient-control-plane-rbac-runtime-int.yaml +- ambient-control-plane-rbac-s0.yaml diff --git a/components/manifests/overlays/openshift-dev/ambient-api-server-args-patch.yaml b/components/manifests/overlays/openshift-dev/ambient-api-server-args-patch.yaml new file mode 100644 index 000000000..e21634fdb --- /dev/null +++ b/components/manifests/overlays/openshift-dev/ambient-api-server-args-patch.yaml @@ -0,0 +1,72 @@ +# openshift-dev: TLS via OpenShift service-ca. JWT disabled; bearer token auth +# for service-to-service (control-plane) is handled via AMBIENT_API_TOKEN env var. +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-api-server +spec: + template: + spec: + containers: + - name: api-server + command: + - /usr/local/bin/ambient-api-server + - serve + - --db-host-file=/secrets/db/db.host + - --db-port-file=/secrets/db/db.port + - --db-user-file=/secrets/db/db.user + - --db-password-file=/secrets/db/db.password + - --db-name-file=/secrets/db/db.name + - --enable-jwt=false + - --enable-authz=false + - --enable-https=true + - --https-cert-file=/etc/tls/tls.crt + - --https-key-file=/etc/tls/tls.key + - --enable-tls=true + - --tls-cert-file=/etc/tls/tls.crt + - --tls-key-file=/etc/tls/tls.key + - --tls-auto-detect-kubernetes=false + - --api-server-bindaddress=:8000 + - --metrics-server-bindaddress=:4433 + - --health-check-server-bindaddress=:4434 + - --enable-health-check-https=true + - --db-sslmode=disable + - --db-max-open-connections=50 + - --enable-db-debug=false + - --enable-metrics-https=false + - --http-read-timeout=5s + - --http-write-timeout=30s + - --cors-allowed-origins=* + - --cors-allowed-headers=X-Ambient-Project + - --enable-grpc=true + - --grpc-server-bindaddress=:9000 + - --grpc-enable-tls=true + - --grpc-tls-cert-file=/etc/tls/tls.crt + - --grpc-tls-key-file=/etc/tls/tls.key + - --alsologtostderr + - -v=4 + volumeMounts: + - name: tls-certs + mountPath: /etc/tls + readOnly: true + livenessProbe: + httpGet: + path: /api/ambient + port: 8000 + scheme: HTTPS + initialDelaySeconds: 15 + periodSeconds: 5 + readinessProbe: + httpGet: + path: /healthcheck + port: 4434 + scheme: HTTPS + httpHeaders: + - name: User-Agent + value: Probe + initialDelaySeconds: 20 + periodSeconds: 10 + volumes: + - name: tls-certs + secret: + secretName: ambient-api-server-tls diff --git a/components/manifests/overlays/openshift-dev/ambient-api-server-env-patch.yaml b/components/manifests/overlays/openshift-dev/ambient-api-server-env-patch.yaml new file mode 100644 index 000000000..de0572a20 --- /dev/null +++ b/components/manifests/overlays/openshift-dev/ambient-api-server-env-patch.yaml @@ -0,0 +1,17 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ambient-api-server +spec: + template: + spec: + containers: + - name: api-server + env: + - name: AMBIENT_ENV + value: openshift-dev + - name: AMBIENT_API_TOKEN + valueFrom: + secretKeyRef: + name: ambient-control-plane-token + key: token diff --git a/components/manifests/overlays/openshift-dev/kustomization.yaml b/components/manifests/overlays/openshift-dev/kustomization.yaml new file mode 100644 index 000000000..14058cd42 --- /dev/null +++ b/components/manifests/overlays/openshift-dev/kustomization.yaml @@ -0,0 +1,24 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +metadata: + name: vteam-openshift-dev + +namespace: ambient-code + +resources: +- ../production + +patches: +- path: ambient-api-server-env-patch.yaml + target: + group: apps + kind: Deployment + name: ambient-api-server + version: v1 +- path: ambient-api-server-args-patch.yaml + target: + group: apps + kind: Deployment + name: ambient-api-server + version: v1 diff --git a/components/manifests/overlays/production/ambient-api-server-route.yaml b/components/manifests/overlays/production/ambient-api-server-route.yaml index 1530d558f..1b3c195a9 100644 --- a/components/manifests/overlays/production/ambient-api-server-route.yaml +++ b/components/manifests/overlays/production/ambient-api-server-route.yaml @@ -6,6 +6,8 @@ metadata: labels: app: ambient-api-server component: api + annotations: + haproxy.router.openshift.io/timeout: 10m spec: to: kind: Service @@ -23,6 +25,8 @@ metadata: labels: app: ambient-api-server component: grpc + annotations: + haproxy.router.openshift.io/timeout: 10m spec: to: kind: Service diff --git a/components/manifests/overlays/production/kustomization.yaml b/components/manifests/overlays/production/kustomization.yaml index e22d55f24..0b08e868d 100644 --- a/components/manifests/overlays/production/kustomization.yaml +++ b/components/manifests/overlays/production/kustomization.yaml @@ -95,6 +95,9 @@ images: - name: quay.io/ambient_code/vteam_state_sync:latest newName: quay.io/ambient_code/vteam_state_sync newTag: latest +- name: quay.io/ambient_code/vteam_control_plane:latest + newName: quay.io/ambient_code/vteam_control_plane + newTag: latest - name: ghcr.io/ambient-code/observability newName: ghcr.io/ambient-code/observability newTag: latest diff --git a/components/runners/ambient-runner/.mcp.json b/components/runners/ambient-runner/.mcp.json index 7a569690e..399975c9e 100644 --- a/components/runners/ambient-runner/.mcp.json +++ b/components/runners/ambient-runner/.mcp.json @@ -23,6 +23,14 @@ "READ_ONLY_MODE": "${JIRA_READ_ONLY_MODE:-true}" } }, + "openshift": { + "command": "uvx", + "args": [ + "kubernetes-mcp-server@latest", + "--kubeconfig", "/tmp/.ambient_kubeconfig", + "--disable-multi-cluster" + ] + }, "google-workspace": { "command": "uvx", "args": [ diff --git a/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py b/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py index 21cc5ac16..6eb633abd 100644 --- a/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py +++ b/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py @@ -605,6 +605,7 @@ def _emit_task_event(self, message: Any) -> "CustomEvent": TaskProgressMessage, TaskNotificationMessage, ) + if isinstance(message, TaskStartedMessage): return self._emit_task_started(message) elif isinstance(message, TaskProgressMessage): @@ -631,6 +632,7 @@ def drain_hook_events(self) -> list: sid = val.get("session_id", "") if sid: from pathlib import Path + base = Path.home() / ".claude" / "projects" if base.exists(): expected = f"agent-{agent_id}.jsonl" @@ -668,7 +670,9 @@ def _emit_task_progress(self, message: Any) -> "CustomEvent": existing = self._task_registry.get(message.task_id, {}) existing.update(progress_value) self._task_registry[message.task_id] = existing - return CustomEvent(type=EventType.CUSTOM, name="task:progress", value=progress_value) + return CustomEvent( + type=EventType.CUSTOM, name="task:progress", value=progress_value + ) def _emit_task_notification(self, message: Any) -> "CustomEvent": usage = getattr(message, "usage", None) @@ -685,7 +689,9 @@ def _emit_task_notification(self, message: Any) -> "CustomEvent": self._task_registry[message.task_id] = existing if output_file: self._task_outputs[message.task_id] = output_file - return CustomEvent(type=EventType.CUSTOM, name="task:completed", value=notification_value) + return CustomEvent( + type=EventType.CUSTOM, name="task:completed", value=notification_value + ) async def _stream_claude_sdk( self, @@ -1160,7 +1166,10 @@ def flush_pending_msg(): ): yield event - elif isinstance(message, (TaskStartedMessage, TaskProgressMessage, TaskNotificationMessage)): + elif isinstance( + message, + (TaskStartedMessage, TaskProgressMessage, TaskNotificationMessage), + ): yield self._emit_task_event(message) elif isinstance(message, SystemMessage): @@ -1361,4 +1370,3 @@ def flush_pending_msg(): # Re-raise to let run() emit RunErrorEvent if stream_error is not None: raise stream_error - diff --git a/components/runners/ambient-runner/ag_ui_claude_sdk/handlers.py b/components/runners/ambient-runner/ag_ui_claude_sdk/handlers.py index 0f61c53e6..062f851b8 100644 --- a/components/runners/ambient-runner/ag_ui_claude_sdk/handlers.py +++ b/components/runners/ambient-runner/ag_ui_claude_sdk/handlers.py @@ -232,9 +232,7 @@ async def handle_thinking_block( if thinking_text: ts = now_ms() yield ReasoningStartEvent(threadId=thread_id, runId=run_id, timestamp=ts) - yield ReasoningMessageStartEvent( - threadId=thread_id, runId=run_id, timestamp=ts - ) + yield ReasoningMessageStartEvent(threadId=thread_id, runId=run_id, timestamp=ts) yield ReasoningMessageContentEvent( threadId=thread_id, runId=run_id, delta=thinking_text ) diff --git a/components/runners/ambient-runner/ag_ui_claude_sdk/hooks.py b/components/runners/ambient-runner/ag_ui_claude_sdk/hooks.py index 15fce20f7..c9cf58232 100644 --- a/components/runners/ambient-runner/ag_ui_claude_sdk/hooks.py +++ b/components/runners/ambient-runner/ag_ui_claude_sdk/hooks.py @@ -22,12 +22,14 @@ logger = logging.getLogger(__name__) # Default hook event names to register (only what the UI consumes). -_DEFAULT_HOOKS = frozenset({ - "SubagentStart", - "SubagentStop", - "Notification", - "Stop", -}) +_DEFAULT_HOOKS = frozenset( + { + "SubagentStart", + "SubagentStop", + "Notification", + "Stop", + } +) # Keys stripped from payloads (internal paths the frontend should not see). _SANITIZE_KEYS = frozenset({"transcript_path", "cwd"}) @@ -51,7 +53,9 @@ async def _forward_hook_as_custom_event( event_name = hook_input.get("hook_event_name", "unknown") payload = {k: v for k, v in hook_input.items() if k not in _SANITIZE_KEYS} - logger.debug("[Hook] %s fired (agent_id=%s)", event_name, hook_input.get("agent_id", "n/a")) + logger.debug( + "[Hook] %s fired (agent_id=%s)", event_name, hook_input.get("agent_id", "n/a") + ) await queue.put( CustomEvent( diff --git a/components/runners/ambient-runner/ag_ui_claude_sdk/reasoning_events.py b/components/runners/ambient-runner/ag_ui_claude_sdk/reasoning_events.py index 0c384ec92..bfc2ee3e2 100644 --- a/components/runners/ambient-runner/ag_ui_claude_sdk/reasoning_events.py +++ b/components/runners/ambient-runner/ag_ui_claude_sdk/reasoning_events.py @@ -21,6 +21,7 @@ class _ReasoningBase(BaseModel): """Base with camelCase serialization to match AG-UI wire format.""" + model_config = ConfigDict(populate_by_name=True) def model_dump(self, **kwargs): diff --git a/components/runners/ambient-runner/ag_ui_gemini_cli/types.py b/components/runners/ambient-runner/ag_ui_gemini_cli/types.py old mode 100755 new mode 100644 diff --git a/components/runners/ambient-runner/ambient_runner/_grpc_client.py b/components/runners/ambient-runner/ambient_runner/_grpc_client.py new file mode 100644 index 000000000..6ceb8c43c --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/_grpc_client.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import base64 +import json +import logging +import os +import time +import urllib.error +import urllib.parse +import urllib.request +from pathlib import Path +from typing import Optional + +import grpc +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding + +from ambient_runner.platform.utils import set_bot_token + +logger = logging.getLogger(__name__) + +_ENV_GRPC_URL = "AMBIENT_GRPC_URL" +_ENV_TOKEN = "BOT_TOKEN" +_ENV_CP_TOKEN_URL = "AMBIENT_CP_TOKEN_URL" +_ENV_CP_TOKEN_PUBLIC_KEY = "AMBIENT_CP_TOKEN_PUBLIC_KEY" +_ENV_SESSION_ID = "SESSION_ID" +_ENV_USE_TLS = "AMBIENT_GRPC_USE_TLS" +_ENV_CA_CERT = "AMBIENT_GRPC_CA_CERT_FILE" +_DEFAULT_GRPC_URL = "ambient-api-server:9000" +_SERVICE_CA_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/service-ca.crt" +_SA_TOKEN_FILE = Path("/var/run/secrets/kubernetes.io/serviceaccount/token") + + +_CP_TOKEN_FETCH_ATTEMPTS = 3 +_CP_TOKEN_FETCH_TIMEOUT = 10 + + +def _encrypt_session_id(public_key_pem: str, session_id: str) -> str: + """RSA-OAEP encrypt session_id with the CP public key, return base64-encoded ciphertext.""" + public_key = serialization.load_pem_public_key(public_key_pem.encode()) + ciphertext = public_key.encrypt( + session_id.encode(), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + return base64.b64encode(ciphertext).decode() + + +def _validate_cp_token_url(url: str) -> None: + """Reject non-http(s) or credential-bearing URLs to prevent exfiltration.""" + parsed = urllib.parse.urlparse(url) + if ( + parsed.scheme not in {"http", "https"} + or not parsed.netloc + or parsed.username is not None + or parsed.password is not None + ): + raise RuntimeError( + f"invalid CP token URL (must be http/https with no credentials): {url!r}" + ) + + +def _fetch_token_from_cp( + cp_token_url: str, public_key_pem: str, session_id: str +) -> str: + """Fetch a fresh API token from the CP /token endpoint. + + Encrypts the session ID with the CP public key and sends it as a Bearer token. + Retries up to _CP_TOKEN_FETCH_ATTEMPTS times with exponential backoff. + """ + _validate_cp_token_url(cp_token_url) + + bearer = _encrypt_session_id(public_key_pem, session_id) + + last_err: Exception = RuntimeError("no attempts made") + for attempt in range(_CP_TOKEN_FETCH_ATTEMPTS): + if attempt > 0: + backoff = 2 ** (attempt - 1) + logger.warning( + "[GRPC CLIENT] CP token fetch attempt %d/%d failed, retrying in %ds: %s", + attempt, + _CP_TOKEN_FETCH_ATTEMPTS, + backoff, + last_err, + ) + time.sleep(backoff) + try: + req = urllib.request.Request( + cp_token_url, + headers={"Authorization": f"Bearer {bearer}"}, + ) + with urllib.request.urlopen(req, timeout=_CP_TOKEN_FETCH_TIMEOUT) as resp: + body = json.loads(resp.read()) + token = body.get("token", "") + if not token: + raise RuntimeError("CP /token response missing 'token' field") + logger.info("[GRPC CLIENT] Fetched fresh API token from CP token endpoint") + set_bot_token(token) + return token + except urllib.error.HTTPError as e: + resp_body = "" + try: + resp_body = e.read().decode(errors="replace") + except Exception: + pass + last_err = RuntimeError(f"CP /token HTTP {e.code}: {resp_body}") + except Exception as e: + last_err = e + + raise RuntimeError( + f"CP token endpoint unreachable after {_CP_TOKEN_FETCH_ATTEMPTS} attempts: {last_err}" + ) from last_err + + +def _load_ca_cert(ca_cert_file: Optional[str]) -> Optional[bytes]: + """Load CA cert from explicit path, then service-ca fallback, then None.""" + candidates = [ca_cert_file, _SERVICE_CA_PATH] + for path in candidates: + if path and os.path.exists(path): + try: + with open(path, "rb") as f: + return f.read() + except OSError: + pass + return None + + +def _build_channel( + grpc_url: str, token: str, use_tls: bool = False, ca_cert_file: Optional[str] = None +) -> grpc.Channel: + """Build a gRPC channel with optional TLS and bearer token call credentials.""" + logger.info( + "[GRPC CHANNEL] Building channel: url=%s tls=%s token_present=%s ca_cert=%s", + grpc_url, + use_tls, + bool(token), + ca_cert_file, + ) + if use_tls: + call_creds = grpc.access_token_call_credentials(token) if token else None + ca_cert = _load_ca_cert(ca_cert_file) + channel_creds = grpc.ssl_channel_credentials(root_certificates=ca_cert) + if call_creds: + logger.info("[GRPC CHANNEL] Using TLS + bearer token credentials") + return grpc.secure_channel( + grpc_url, grpc.composite_channel_credentials(channel_creds, call_creds) + ) + logger.info("[GRPC CHANNEL] Using TLS-only credentials (no token)") + return grpc.secure_channel(grpc_url, channel_creds) + logger.info("[GRPC CHANNEL] Using insecure channel (no TLS)") + return grpc.insecure_channel(grpc_url) + + +class AmbientGRPCClient: + """gRPC client for the Ambient Platform internal API. + + Intended for use inside runner Job pods where BOT_TOKEN and + AMBIENT_GRPC_URL are injected by the operator. + """ + + def __init__( + self, + grpc_url: str, + token: str, + use_tls: bool = False, + ca_cert_file: Optional[str] = None, + cp_token_url: str = "", + ) -> None: + self._grpc_url = grpc_url + self._token = token + self._use_tls = use_tls + self._ca_cert_file = ca_cert_file + self._cp_token_url = cp_token_url + self._channel: Optional[grpc.Channel] = None + self._session_messages: Optional["SessionMessagesAPI"] = None # noqa: F821 + + @classmethod + def from_env(cls) -> AmbientGRPCClient: + """Create client from environment variables.""" + grpc_url = os.environ.get(_ENV_GRPC_URL, _DEFAULT_GRPC_URL) + cp_token_url = os.environ.get(_ENV_CP_TOKEN_URL, "") + use_tls = os.environ.get(_ENV_USE_TLS, "").lower() in ("true", "1", "yes") + ca_cert_file = os.environ.get(_ENV_CA_CERT) + if cp_token_url: + public_key_pem = os.environ.get(_ENV_CP_TOKEN_PUBLIC_KEY, "") + session_id = os.environ.get(_ENV_SESSION_ID, "") + if not public_key_pem: + raise RuntimeError( + "AMBIENT_CP_TOKEN_PUBLIC_KEY env var is required when AMBIENT_CP_TOKEN_URL is set" + ) + if not session_id: + raise RuntimeError( + "SESSION_ID env var is required when AMBIENT_CP_TOKEN_URL is set" + ) + logger.info( + "[GRPC CLIENT] Fetching token from CP endpoint: url=%s", cp_token_url + ) + token = _fetch_token_from_cp(cp_token_url, public_key_pem, session_id) + else: + token = os.environ.get(_ENV_TOKEN, "") + logger.info("[GRPC CLIENT] Using BOT_TOKEN env var (local dev mode)") + logger.info( + "[GRPC CLIENT] Initializing from env: url=%s tls=%s token_len=%d", + grpc_url, + use_tls, + len(token), + ) + return cls( + grpc_url=grpc_url, + token=token, + use_tls=use_tls, + ca_cert_file=ca_cert_file, + cp_token_url=cp_token_url, + ) + + def reconnect(self) -> None: + """Close the existing channel and rebuild with a fresh token from the CP endpoint.""" + if self._cp_token_url: + public_key_pem = os.environ.get(_ENV_CP_TOKEN_PUBLIC_KEY, "") + session_id = os.environ.get(_ENV_SESSION_ID, "") + fresh_token = _fetch_token_from_cp( + self._cp_token_url, public_key_pem, session_id + ) + else: + fresh_token = os.environ.get(_ENV_TOKEN, "") + logger.info( + "[GRPC CLIENT] Reconnecting with fresh token (len=%d)", len(fresh_token) + ) + self.close() + self._token = fresh_token + + def _get_channel(self) -> grpc.Channel: + if self._channel is None: + logger.info("[GRPC CHANNEL] Creating new channel to %s", self._grpc_url) + self._channel = _build_channel( + self._grpc_url, self._token, self._use_tls, self._ca_cert_file + ) + logger.info("[GRPC CHANNEL] Channel created successfully") + return self._channel + + @property + def session_messages(self) -> "SessionMessagesAPI": # noqa: F821 + if self._session_messages is None: + logger.info("[GRPC CLIENT] Creating SessionMessagesAPI stub") + from ._session_messages_api import SessionMessagesAPI + + self._session_messages = SessionMessagesAPI( + self._get_channel(), token=self._token, grpc_client=self + ) + logger.info("[GRPC CLIENT] SessionMessagesAPI ready") + return self._session_messages + + def close(self) -> None: + if self._channel is not None: + self._channel.close() + self._channel = None + self._session_messages = None + + def __enter__(self) -> AmbientGRPCClient: + return self + + def __exit__(self, *args: object) -> None: + self.close() diff --git a/components/runners/ambient-runner/ambient_runner/_inbox_messages_api.py b/components/runners/ambient-runner/ambient_runner/_inbox_messages_api.py new file mode 100644 index 000000000..8df778172 --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/_inbox_messages_api.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Iterator, Optional + +import grpc + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class InboxMessage: + id: str + agent_id: str + from_agent_id: Optional[str] + from_name: Optional[str] + body: str + read: Optional[bool] + created_at: Optional[datetime] + updated_at: Optional[datetime] + + @classmethod + def _from_proto(cls, pb: object) -> InboxMessage: + def _ts(ts: object) -> Optional[datetime]: + if ts is None: + return None + try: + return datetime.fromtimestamp( + ts.seconds + ts.nanos / 1e9, tz=timezone.utc + ) + except Exception: + return None + + return cls( + id=getattr(pb, "id", ""), + agent_id=getattr(pb, "agent_id", ""), + from_agent_id=getattr(pb, "from_agent_id", None) or None, + from_name=getattr(pb, "from_name", None) or None, + body=getattr(pb, "body", ""), + read=getattr(pb, "read", None), + created_at=_ts(getattr(pb, "created_at", None)), + updated_at=_ts(getattr(pb, "updated_at", None)), + ) + + +class InboxMessagesAPI: + """gRPC client wrapper for InboxService.WatchInboxMessages (server-streaming, watch-only).""" + + _WATCH_METHOD = "/ambient.v1.InboxService/WatchInboxMessages" + + def __init__(self, channel: grpc.Channel, token: str = "") -> None: + self._metadata = [("authorization", f"Bearer {token}")] if token else [] + self._watch_rpc = channel.unary_stream( + self._WATCH_METHOD, + request_serializer=_WatchInboxRequest.SerializeToString, + response_deserializer=_InboxMessageProto.FromString, + ) + + def watch( + self, + agent_id: str, + *, + timeout: Optional[float] = None, + ) -> Iterator[InboxMessage]: + """Stream live inbox messages for an agent. + + The server delivers only messages created AFTER the subscription + begins — there is no replay cursor (unlike WatchSessionMessages). + """ + logger.info( + "[GRPC INBOX WATCH←] Starting WatchInboxMessages: agent_id=%s", + agent_id, + ) + req = _WatchInboxRequest() + req.agent_id = agent_id + stream = self._watch_rpc(req, timeout=timeout, metadata=self._metadata) + msg_count = 0 + for pb in stream: + msg = InboxMessage._from_proto(pb) + msg_count += 1 + logger.info( + "[GRPC INBOX WATCH←] Message #%d received: agent_id=%s inbox_id=%s from=%s body_len=%d", + msg_count, + agent_id, + msg.id, + msg.from_name or msg.from_agent_id or "system", + len(msg.body), + ) + yield msg + logger.info( + "[GRPC INBOX WATCH←] Stream ended: agent_id=%s total_messages=%d", + agent_id, + msg_count, + ) + + +# --------------------------------------------------------------------------- +# Minimal inline proto message classes (no generated _pb2 dependency). +# Mirrors the hand-rolled encoding in _session_messages_api.py. +# --------------------------------------------------------------------------- + + +def _encode_string(field_number: int, value: str) -> bytes: + encoded = value.encode("utf-8") + tag = (field_number << 3) | 2 + return _varint(tag) + _varint(len(encoded)) + encoded + + +def _varint(value: int) -> bytes: + bits = value & 0x7F + value >>= 7 + result = b"" + while value: + result += bytes([0x80 | bits]) + bits = value & 0x7F + value >>= 7 + result += bytes([bits]) + return result + + +def _decode_varint(data: bytes, pos: int) -> tuple[int, int]: + result = 0 + shift = 0 + while True: + b = data[pos] + pos += 1 + result |= (b & 0x7F) << shift + if not (b & 0x80): + return result, pos + shift += 7 + + +def _decode_string(data: bytes, pos: int) -> tuple[str, int]: + length, pos = _decode_varint(data, pos) + return data[pos : pos + length].decode("utf-8", errors="replace"), pos + length + + +class _WatchInboxRequest: + def __init__(self) -> None: + self.agent_id: str = "" + + def SerializeToString(self) -> bytes: + out = b"" + if self.agent_id: + out += _encode_string(1, self.agent_id) + return out + + +class _TimestampLike: + __slots__ = ("seconds", "nanos") + + def __init__(self, seconds: int, nanos: int) -> None: + self.seconds = seconds + self.nanos = nanos + + +def _parse_timestamp(data: bytes) -> Optional[_TimestampLike]: + seconds = 0 + nanos = 0 + pos = 0 + while pos < len(data): + tag_varint, pos = _decode_varint(data, pos) + field_number = tag_varint >> 3 + wire_type = tag_varint & 0x7 + if wire_type == 0: + value, pos = _decode_varint(data, pos) + if field_number == 1: + seconds = value + elif field_number == 2: + nanos = value + else: + break + return _TimestampLike(seconds, nanos) + + +class _InboxMessageProto: + """Minimal hand-rolled protobuf decoder for InboxMessage. + + Proto field mapping (from ambient/v1/inbox.proto): + 1: id (string, wire 2) + 2: agent_id (string, wire 2) + 3: from_agent_id (optional string, wire 2) + 4: from_name (optional string, wire 2) + 5: body (string, wire 2) + 6: read (optional bool, wire 0) + 7: created_at (Timestamp, wire 2) + 8: updated_at (Timestamp, wire 2) + """ + + __slots__ = ( + "id", + "agent_id", + "from_agent_id", + "from_name", + "body", + "read", + "created_at", + "updated_at", + ) + + def __init__(self) -> None: + self.id: str = "" + self.agent_id: str = "" + self.from_agent_id: Optional[str] = None + self.from_name: Optional[str] = None + self.body: str = "" + self.read: Optional[bool] = None + self.created_at: Optional[_TimestampLike] = None + self.updated_at: Optional[_TimestampLike] = None + + @classmethod + def FromString(cls, data: bytes) -> _InboxMessageProto: + msg = cls() + pos = 0 + while pos < len(data): + tag_varint, pos = _decode_varint(data, pos) + field_number = tag_varint >> 3 + wire_type = tag_varint & 0x7 + if wire_type == 2: + length, pos = _decode_varint(data, pos) + value_bytes = data[pos : pos + length] + pos += length + if field_number == 1: + msg.id = value_bytes.decode("utf-8", errors="replace") + elif field_number == 2: + msg.agent_id = value_bytes.decode("utf-8", errors="replace") + elif field_number == 3: + msg.from_agent_id = value_bytes.decode("utf-8", errors="replace") + elif field_number == 4: + msg.from_name = value_bytes.decode("utf-8", errors="replace") + elif field_number == 5: + msg.body = value_bytes.decode("utf-8", errors="replace") + elif field_number == 7: + msg.created_at = _parse_timestamp(value_bytes) + elif field_number == 8: + msg.updated_at = _parse_timestamp(value_bytes) + elif wire_type == 0: + value, pos = _decode_varint(data, pos) + if field_number == 6: + msg.read = bool(value) + else: + break + return msg diff --git a/components/runners/ambient-runner/ambient_runner/_session_messages_api.py b/components/runners/ambient-runner/ambient_runner/_session_messages_api.py new file mode 100644 index 000000000..67bcea53f --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/_session_messages_api.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Iterator, Optional + +if TYPE_CHECKING: + from ._grpc_client import AmbientGRPCClient + +import grpc + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SessionMessage: + id: str + session_id: str + seq: int + event_type: str + payload: str + created_at: Optional[datetime] + + @classmethod + def _from_proto(cls, pb: object) -> SessionMessage: + ts = getattr(pb, "created_at", None) + created_at: Optional[datetime] = None + if ts is not None: + try: + created_at = datetime.fromtimestamp( + ts.seconds + ts.nanos / 1e9, tz=timezone.utc + ) + except Exception: + pass + return cls( + id=getattr(pb, "id", ""), + session_id=getattr(pb, "session_id", ""), + seq=getattr(pb, "seq", 0), + event_type=getattr(pb, "event_type", ""), + payload=getattr(pb, "payload", ""), + created_at=created_at, + ) + + +class SessionMessagesAPI: + """gRPC client wrapper for SessionService message RPCs.""" + + _PUSH_METHOD = "/ambient.v1.SessionService/PushSessionMessage" + _WATCH_METHOD = "/ambient.v1.SessionService/WatchSessionMessages" + + def __init__( + self, + channel: grpc.Channel, + token: str = "", + grpc_client: Optional[AmbientGRPCClient] = None, + ) -> None: + self._grpc_client = grpc_client + self._metadata = [("authorization", f"Bearer {token}")] if token else [] + self._push_rpc = channel.unary_unary( + self._PUSH_METHOD, + request_serializer=_PushRequest.SerializeToString, + response_deserializer=_SessionMessageProto.FromString, + ) + self._watch_rpc = channel.unary_stream( + self._WATCH_METHOD, + request_serializer=_WatchRequest.SerializeToString, + response_deserializer=_SessionMessageProto.FromString, + ) + + def push( + self, + session_id: str, + event_type: str, + payload: str, + *, + timeout: float = 5.0, + ) -> Optional[SessionMessage]: + """Push a single message for a session. Fire-and-forget safe — logs and + returns None on any transport error rather than raising.""" + logger.info( + "[GRPC PUSH→] session=%s event_type=%s payload_len=%d", + session_id, + event_type, + len(payload), + ) + req = _PushRequest() + req.session_id = session_id + req.event_type = event_type + req.payload = payload + + for attempt in range(2): + try: + pb = self._push_rpc(req, timeout=timeout, metadata=self._metadata) + result = SessionMessage._from_proto(pb) + logger.info( + "[GRPC PUSH→] OK session=%s event_type=%s seq=%d", + session_id, + event_type, + result.seq, + ) + return result + except grpc.RpcError as exc: + if ( + attempt == 0 + and exc.code() == grpc.StatusCode.UNAUTHENTICATED + and self._grpc_client is not None + ): + logger.warning( + "[GRPC PUSH→] UNAUTHENTICATED — reconnecting with fresh token (session=%s)", + session_id, + ) + self._grpc_client.reconnect() + new_api = self._grpc_client.session_messages + self._push_rpc = new_api._push_rpc + self._metadata = new_api._metadata + continue + logger.warning( + "[GRPC PUSH→] FAILED PushSessionMessage RPC (session=%s event=%s): %s", + session_id, + event_type, + exc, + ) + return None + except Exception as exc: + logger.warning( + "[GRPC PUSH→] FAILED PushSessionMessage unexpected error (session=%s): %s", + session_id, + exc, + ) + return None + return None + + def watch( + self, + session_id: str, + *, + after_seq: int = 0, + timeout: Optional[float] = None, + ) -> Iterator[SessionMessage]: + """Stream messages for a session starting after after_seq.""" + logger.info( + "[GRPC WATCH←] Starting WatchSessionMessages: session=%s after_seq=%d", + session_id, + after_seq, + ) + req = _WatchRequest() + req.session_id = session_id + req.after_seq = after_seq + stream = self._watch_rpc(req, timeout=timeout, metadata=self._metadata) + msg_count = 0 + for pb in stream: + msg = SessionMessage._from_proto(pb) + msg_count += 1 + logger.info( + "[GRPC WATCH←] Message #%d received: session=%s seq=%d event_type=%s payload_len=%d", + msg_count, + msg.session_id, + msg.seq, + msg.event_type, + len(msg.payload), + ) + yield msg + logger.info( + "[GRPC WATCH←] Stream ended: session=%s total_messages=%d", + session_id, + msg_count, + ) + + +# --------------------------------------------------------------------------- +# Minimal inline proto message classes (no generated _pb2 dependency). +# These use the protobuf runtime's message factory directly. +# --------------------------------------------------------------------------- + + +def _encode_string(field_number: int, value: str) -> bytes: + encoded = value.encode("utf-8") + tag = (field_number << 3) | 2 + return _varint(tag) + _varint(len(encoded)) + encoded + + +def _encode_int64(field_number: int, value: int) -> bytes: + if value == 0: + return b"" + tag = (field_number << 3) | 0 + return _varint(tag) + _varint(value) + + +def _varint(value: int) -> bytes: + bits = value & 0x7F + value >>= 7 + result = b"" + while value: + result += bytes([0x80 | bits]) + bits = value & 0x7F + value >>= 7 + result += bytes([bits]) + return result + + +def _decode_string(data: bytes, pos: int) -> tuple[str, int]: + length, pos = _decode_varint(data, pos) + return data[pos : pos + length].decode("utf-8", errors="replace"), pos + length + + +def _decode_varint(data: bytes, pos: int) -> tuple[int, int]: + result = 0 + shift = 0 + while True: + b = data[pos] + pos += 1 + result |= (b & 0x7F) << shift + if not (b & 0x80): + return result, pos + shift += 7 + + +class _PushRequest: + def __init__(self) -> None: + self.session_id: str = "" + self.event_type: str = "" + self.payload: str = "" + + def SerializeToString(self) -> bytes: + out = b"" + if self.session_id: + out += _encode_string(1, self.session_id) + if self.event_type: + out += _encode_string(2, self.event_type) + if self.payload: + out += _encode_string(3, self.payload) + return out + + +class _WatchRequest: + def __init__(self) -> None: + self.session_id: str = "" + self.after_seq: int = 0 + + def SerializeToString(self) -> bytes: + out = b"" + if self.session_id: + out += _encode_string(1, self.session_id) + if self.after_seq: + out += _encode_int64(2, self.after_seq) + return out + + +class _SessionMessageProto: + __slots__ = ("id", "session_id", "seq", "event_type", "payload", "created_at") + + def __init__(self) -> None: + self.id: str = "" + self.session_id: str = "" + self.seq: int = 0 + self.event_type: str = "" + self.payload: str = "" + self.created_at: Optional[object] = None + + @classmethod + def FromString(cls, data: bytes) -> _SessionMessageProto: + msg = cls() + pos = 0 + while pos < len(data): + tag_varint, pos = _decode_varint(data, pos) + field_number = tag_varint >> 3 + wire_type = tag_varint & 0x7 + if wire_type == 2: + length, pos = _decode_varint(data, pos) + value_bytes = data[pos : pos + length] + pos += length + if field_number == 1: + msg.id = value_bytes.decode("utf-8", errors="replace") + elif field_number == 2: + msg.session_id = value_bytes.decode("utf-8", errors="replace") + elif field_number == 4: + msg.event_type = value_bytes.decode("utf-8", errors="replace") + elif field_number == 5: + msg.payload = value_bytes.decode("utf-8", errors="replace") + elif field_number == 6: + msg.created_at = _parse_timestamp(value_bytes) + elif wire_type == 0: + value, pos = _decode_varint(data, pos) + if field_number == 3: + msg.seq = value + elif wire_type == 1: + pos += 8 + elif wire_type == 5: + pos += 4 + else: + break + return msg + + +class _TimestampLike: + __slots__ = ("seconds", "nanos") + + def __init__(self, seconds: int, nanos: int) -> None: + self.seconds = seconds + self.nanos = nanos + + +def _parse_timestamp(data: bytes) -> Optional[_TimestampLike]: + seconds = 0 + nanos = 0 + pos = 0 + while pos < len(data): + tag_varint, pos = _decode_varint(data, pos) + field_number = tag_varint >> 3 + wire_type = tag_varint & 0x7 + if wire_type == 0: + value, pos = _decode_varint(data, pos) + if field_number == 1: + seconds = value + elif field_number == 2: + nanos = value + elif wire_type == 2: + length, pos = _decode_varint(data, pos) + pos += length + elif wire_type == 1: + pos += 8 + elif wire_type == 5: + pos += 4 + else: + break + return _TimestampLike(seconds, nanos) diff --git a/components/runners/ambient-runner/ambient_runner/app.py b/components/runners/ambient-runner/ambient_runner/app.py index 6b15d06aa..4ec918a80 100755 --- a/components/runners/ambient-runner/ambient_runner/app.py +++ b/components/runners/ambient-runner/ambient_runner/app.py @@ -120,6 +120,42 @@ async def lifespan(app: FastAPI): if is_resume: logger.info("IS_RESUME=true — this is a resumed session") + # Eager gRPC listener setup (duck-typed: any bridge that exposes + # start_grpc_listener + _active_streams qualifies). + # Requires both AMBIENT_GRPC_ENABLED=true and AMBIENT_GRPC_URL to be set. + # Must complete before INITIAL_PROMPT is dispatched so the listener + # is subscribed before PushSessionMessage fires. + # + # OPERATOR COMPATIBILITY: The existing Operator never injects AMBIENT_GRPC_ENABLED + # or AMBIENT_GRPC_URL into Job pods. This entire block is a strict no-op for + # operator-created sessions. No existing Operator/Runner behavior is changed. + grpc_enabled = os.getenv("AMBIENT_GRPC_ENABLED", "").strip().lower() == "true" + grpc_url = os.getenv("AMBIENT_GRPC_URL", "").strip() + grpc_active = False + if grpc_enabled and grpc_url and hasattr(bridge, "start_grpc_listener"): + await bridge.start_grpc_listener(grpc_url) + listener = getattr(bridge, "_grpc_listener", None) + if listener is not None: + try: + await asyncio.wait_for(listener.ready.wait(), timeout=10.0) + grpc_active = True + except asyncio.TimeoutError: + logger.warning( + "gRPC listener did not become ready within 10s: session=%s", + session_id, + ) + logger.info( + "gRPC listener ready for session %s — proceeding to INITIAL_PROMPT", + session_id, + ) + # Pre-register the SSE queue for session_id so the queue exists + # in active_streams before PushSessionMessage fires the first turn. + # This closes the race between listener.ready and the first event fan-out. + active_streams = getattr(bridge, "_active_streams", None) + if active_streams is not None and session_id not in active_streams: + active_streams[session_id] = asyncio.Queue() + logger.info("Pre-registered SSE queue for session=%s", session_id) + # Auto-execute prompts when present (skipped only for resumes, # where the conversation is continued rather than re-started). if not is_resume: @@ -152,7 +188,9 @@ async def lifespan(app: FastAPI): f"Auto-executing combined prompt ({len(combined_prompt)} chars)" ) task = asyncio.create_task( - _auto_execute_initial_prompt(combined_prompt, session_id) + _auto_execute_initial_prompt( + combined_prompt, session_id, grpc_url if grpc_active else "" + ) ) task.add_done_callback(_log_auto_exec_failure) else: @@ -231,6 +269,7 @@ async def _require_session_token(request: Request, call_next): logger.info("AG-UI token authentication enabled") # Core endpoints (always registered) + from ambient_runner.endpoints.events import router as events_router from ambient_runner.endpoints.health import router as health_router from ambient_runner.endpoints.interrupt import router as interrupt_router from ambient_runner.endpoints.run import router as run_router @@ -238,6 +277,7 @@ async def _require_session_token(request: Request, call_next): app.include_router(run_router) app.include_router(interrupt_router) app.include_router(health_router) + app.include_router(events_router) from ambient_runner.endpoints.model import router as model_router @@ -344,17 +384,103 @@ def _get_workflow_startup_prompt() -> str: _AUTO_PROMPT_MAX_DELAY = 30.0 -async def _auto_execute_initial_prompt(prompt: str, session_id: str) -> None: - """Auto-execute INITIAL_PROMPT on session startup with retry backoff. +async def _auto_execute_initial_prompt( + prompt: str, session_id: str, grpc_url: str = "" +) -> None: + """Auto-execute INITIAL_PROMPT on session startup. + + When AMBIENT_GRPC_URL is set, pushes the initial prompt as a DB Message + via PushSessionMessage so the GRPCSessionListener picks it up and triggers + the run directly. The prompt is then observable to API consumers and + visible in the frontend session history. - The runner pod may be ready before the K8s Service DNS propagates, - so the first few attempts can fail with "runner not available". - Retries with exponential backoff until the backend accepts the request. + When AMBIENT_GRPC_URL is not set, falls back to the original HTTP POST + path with exponential-backoff retry (for DNS propagation races). """ delay_seconds = float(os.getenv("INITIAL_PROMPT_DELAY_SECONDS", "2")) logger.info(f"Waiting {delay_seconds}s before auto-executing INITIAL_PROMPT...") await asyncio.sleep(delay_seconds) + if grpc_url: + # gRPC mode: the initial prompt was already stored in the DB when the session + # was created via the HTTP API (acpctl create session). The GRPCSessionListener's + # WatchSessionMessages stream will deliver it to the runner automatically. + # Pushing here would use the SA token which cannot push event_type=user, + # causing a harmless but noisy PERMISSION_DENIED warning. Skip it. + logger.debug( + "gRPC mode: skipping INITIAL_PROMPT push — message already in DB via session creation: session=%s", + session_id, + ) + else: + await _push_initial_prompt_via_http(prompt, session_id) + + +async def _push_initial_prompt_via_grpc(prompt: str, session_id: str) -> None: + """Push INITIAL_PROMPT as a PushSessionMessage so it is durable in DB. + + The gRPC push is synchronous (blocking I/O) and is offloaded to a thread + pool so it does not block the asyncio event loop. + """ + import json as _json + + from ambient_runner._grpc_client import AmbientGRPCClient + + def _do_push() -> None: + client = AmbientGRPCClient.from_env() + try: + payload = { + "threadId": session_id, + "runId": str(uuid.uuid4()), + "messages": [ + { + "id": str(uuid.uuid4()), + "role": "user", + "content": prompt, + "metadata": { + "hidden": True, + "autoSent": True, + "source": "runner_initial_prompt", + }, + } + ], + } + result = client.session_messages.push( + session_id, + event_type="user", + payload=_json.dumps(payload), + ) + if result is not None: + logger.info( + "INITIAL_PROMPT pushed via gRPC: session=%s seq=%d", + session_id, + result.seq, + ) + else: + logger.warning( + "INITIAL_PROMPT gRPC push returned None (push may have failed): session=%s", + session_id, + ) + finally: + client.close() + + try: + await asyncio.get_running_loop().run_in_executor(None, _do_push) + except Exception as exc: + logger.error( + "INITIAL_PROMPT gRPC push failed: session=%s error=%s", + session_id, + exc, + exc_info=True, + ) + + +async def _push_initial_prompt_via_http(prompt: str, session_id: str) -> None: + """POST INITIAL_PROMPT to the backend AG-UI run endpoint with retry backoff. + + The runner pod may be ready before K8s Service DNS propagates, so the + first few attempts can fail with "runner not available". Retries with + exponential backoff until the backend accepts the request. + """ backend_url = os.getenv("BACKEND_API_URL", "").rstrip("/") project_name = ( os.getenv("PROJECT_NAME", "").strip() diff --git a/components/runners/ambient-runner/ambient_runner/bridge.py b/components/runners/ambient-runner/ambient_runner/bridge.py index 2a4c5006f..76f2034e0 100755 --- a/components/runners/ambient-runner/ambient_runner/bridge.py +++ b/components/runners/ambient-runner/ambient_runner/bridge.py @@ -227,6 +227,30 @@ def get_error_context(self) -> str: """ return "" + async def inject_message( + self, session_id: str, event_type: str, payload: str + ) -> None: + """Inject an inbound session message into the active run. + + Called by the run endpoint for each ``SessionMessage`` received via + ``WatchSessionMessages`` gRPC stream while a run is in progress. + + Override in bridge subclasses to handle inbound messages — e.g. to + interrupt the current Claude turn and inject a new user message. + + Default: no-op (inbound messages are silently dropped). + + Args: + session_id: The session this message belongs to. + event_type: The message event type string. + payload: The raw JSON payload string. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support inject_message " + f"(session_id={session_id!r}, event_type={event_type!r}). " + "Override inject_message() in your bridge subclass to handle inbound messages." + ) + # ------------------------------------------------------------------ # Properties (override to expose state to endpoints) # ------------------------------------------------------------------ diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py old mode 100755 new mode 100644 index 0aeedd7bb..893e2348c --- a/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py @@ -9,6 +9,7 @@ - Interrupt and graceful shutdown """ +import asyncio import json import logging import os @@ -152,6 +153,9 @@ def __init__(self) -> None: self._saved_session_ids: dict[str, str] = {} # Per-thread halt tracking to avoid race conditions on shared adapter self._halted_by_thread: dict[str, bool] = {} + # gRPC transport — started lazily in _setup_platform + self._grpc_listener: Any = None + self._active_streams: dict[str, asyncio.Queue] = {} # ------------------------------------------------------------------ # PlatformBridge interface @@ -208,18 +212,11 @@ async def _initialize_run( await self._ensure_ready() - # Fresh credentials for this user on every run. - # On first run, _setup_platform() already populated credentials and - # built MCP servers with the correct env vars — skip the redundant - # clear-then-repopulate cycle to avoid briefly removing env vars - # (like USER_GOOGLE_EMAIL) that MCP servers depend on. - if self._first_run: - logger.info("First run: using credentials from _setup_platform()") - else: - clear_runtime_credentials() - await populate_runtime_credentials(self._context) - await populate_mcp_server_credentials(self._context) - self._last_creds_refresh = time.monotonic() + # Fresh credentials for this user on every run + clear_runtime_credentials() + await populate_runtime_credentials(self._context) + await populate_mcp_server_credentials(self._context) + self._last_creds_refresh = time.monotonic() # If the caller changed, destroy the worker and rebuild MCP servers + # adapter so the new ClaudeSDKClient gets fresh mcp_servers config. @@ -482,8 +479,38 @@ def task_outputs(self) -> dict: # Lifecycle methods # ------------------------------------------------------------------ + async def start_grpc_listener(self, grpc_url: str) -> None: + """Start the gRPC session listener for this bridge. + + Separated from _setup_platform so it can be called after platform + setup completes, with a bounded timeout for readiness. Only valid + when AMBIENT_GRPC_ENABLED=true and AMBIENT_GRPC_URL are both set. + """ + if self._context is None: + raise RuntimeError("Cannot start gRPC listener: context not set") + if self._grpc_listener is not None: + logger.warning("gRPC listener already started — skipping duplicate start") + return + + from ambient_runner.bridges.claude.grpc_transport import GRPCSessionListener + + session_id = self._context.session_id + self._grpc_listener = GRPCSessionListener( + bridge=self, + session_id=session_id, + grpc_url=grpc_url, + ) + self._grpc_listener.start() + logger.info( + "gRPC listener started: session=%s url=%s", + session_id, + grpc_url, + ) + async def shutdown(self) -> None: """Graceful shutdown: persist sessions, finalise tracing.""" + if self._grpc_listener is not None: + await self._grpc_listener.stop() if self._session_manager: await self._session_manager.shutdown() if self._obs: diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/grpc_transport.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/grpc_transport.py new file mode 100644 index 000000000..9bb5e8053 --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/grpc_transport.py @@ -0,0 +1,463 @@ +""" +gRPC transport for ClaudeBridge (additive — only active when AMBIENT_GRPC_ENABLED=true). + +GRPCSessionListener — pod-lifetime WatchSessionMessages subscriber. + Active alongside the existing HTTP/SSE path when AMBIENT_GRPC_ENABLED=true. + Calls bridge.run() directly for each inbound user message (no HTTP round-trip). + Fans out each event to: + (a) bridge._active_streams[thread_id] queue — feeds the /events SSE tap + (b) GRPCMessageWriter — assembles and writes the durable DB record + +GRPCMessageWriter — per-turn event consumer. + Accumulates MESSAGES_SNAPSHOT content. + Pushes one PushSessionMessage(event_type="assistant") on RUN_FINISHED / RUN_ERROR. + +When AMBIENT_GRPC_ENABLED is not set, none of this code is instantiated or called. +""" + +import asyncio +import logging +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Optional + +import grpc + +from ag_ui.core import BaseEvent + +if TYPE_CHECKING: + from ambient_runner._grpc_client import AmbientGRPCClient + from ambient_runner.bridge import PlatformBridge + +logger = logging.getLogger(__name__) + +_BACKOFF_INITIAL = 1.0 +_BACKOFF_MAX = 30.0 + + +def _synthesize_run_error( + thread_id: str, + error_message: str, + active_streams: dict[str, asyncio.Queue], + writer: "GRPCMessageWriter", +) -> None: + """Synthesize a terminal RUN_ERROR event when bridge.run() raises. + + Feeds the error event into the SSE tap queue (if registered) and + schedules the writer to persist an 'error' status record so neither + the SSE consumer nor the DB writer is left hanging. + """ + from ag_ui.core import RunErrorEvent + + try: + error_event = RunErrorEvent(message=error_message, code="RUNNER_ERROR") + except Exception: + error_event = None + + stream_queue = active_streams.get(thread_id) + if stream_queue is not None and error_event is not None: + try: + stream_queue.put_nowait(error_event) + except asyncio.QueueFull: + logger.warning( + "[GRPC LISTENER] SSE tap queue full while synthesising RUN_ERROR: thread=%s", + thread_id, + ) + + task = asyncio.ensure_future(writer._write_message(status="error")) + + def _log_write_error(f: asyncio.Future) -> None: + if not f.cancelled() and f.exception() is not None: + logger.warning( + "[GRPC LISTENER] _write_message(error) failed: %s", f.exception() + ) + + task.add_done_callback(_log_write_error) + + +class GRPCSessionListener: + """Pod-lifetime gRPC session listener for ClaudeBridge. + + Subscribes to WatchSessionMessages for this session. For each inbound + message with event_type=="user", parses the payload as RunnerInput and + calls bridge.run() directly. + + ready: asyncio.Event — set once the WatchSessionMessages stream is open. + Callers should await self.ready.wait() before sending the first message. + """ + + def __init__( + self, + bridge: "PlatformBridge", + session_id: str, + grpc_url: str, + ) -> None: + self._bridge = bridge + self._session_id = session_id + self._grpc_url = grpc_url + self._grpc_client: Optional["AmbientGRPCClient"] = None + self.ready = asyncio.Event() + self._task: Optional[asyncio.Task] = None + + def start(self) -> None: + from ambient_runner._grpc_client import AmbientGRPCClient + + self._grpc_client = AmbientGRPCClient.from_env() + self._task = asyncio.create_task( + self._listen_loop(), name="grpc-session-listener" + ) + logger.info( + "[GRPC LISTENER] Started: session=%s url=%s", + self._session_id, + self._grpc_url, + ) + + async def stop(self) -> None: + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + if self._grpc_client: + self._grpc_client.close() + logger.info("[GRPC LISTENER] Stopped: session=%s", self._session_id) + + def _watch_in_thread( + self, + msg_queue: asyncio.Queue, + loop: asyncio.AbstractEventLoop, + stop_event: asyncio.Event, + last_seq: int, + ) -> None: + """Blocking gRPC watch — runs in a ThreadPoolExecutor. + + Sets self.ready after watch() returns the stream iterator (stream open, + server will deliver messages from this point). Puts each received + SessionMessage onto msg_queue via run_coroutine_threadsafe. + """ + if self._grpc_client is None: + return + try: + stream = self._grpc_client.session_messages.watch( + self._session_id, after_seq=last_seq + ) + loop.call_soon_threadsafe(self.ready.set) + logger.info( + "[GRPC LISTENER] WatchSessionMessages stream open: session=%s after_seq=%d", + self._session_id, + last_seq, + ) + for msg in stream: + if loop.is_closed() or stop_event.is_set(): + break + logger.info( + "[GRPC LISTENER] Received: session=%s seq=%d event_type=%s", + self._session_id, + msg.seq, + msg.event_type, + ) + asyncio.run_coroutine_threadsafe(msg_queue.put(msg), loop) + except grpc.RpcError as exc: + logger.warning( + "[GRPC LISTENER] gRPC stream error: session=%s code=%s details=%s", + self._session_id, + exc.code(), + exc.details(), + ) + if ( + exc.code() == grpc.StatusCode.UNAUTHENTICATED + and self._grpc_client is not None + ): + logger.warning( + "[GRPC LISTENER] UNAUTHENTICATED — reconnecting with fresh token: session=%s", + self._session_id, + ) + self._grpc_client.reconnect() + except Exception as exc: + logger.error( + "[GRPC LISTENER] Unexpected watch error: session=%s error=%s", + self._session_id, + exc, + exc_info=True, + ) + + async def _listen_loop(self) -> None: + last_seq = 0 + backoff = _BACKOFF_INITIAL + + while True: + msg_queue: asyncio.Queue = asyncio.Queue() + stop_event = asyncio.Event() + loop = asyncio.get_running_loop() + executor = ThreadPoolExecutor(max_workers=1) + + watch_future = loop.run_in_executor( + executor, + self._watch_in_thread, + msg_queue, + loop, + stop_event, + last_seq, + ) + + try: + while True: + try: + msg = await asyncio.wait_for(msg_queue.get(), timeout=30.0) + except asyncio.TimeoutError: + if watch_future.done(): + break + continue + + last_seq = max(last_seq, msg.seq) + + if msg.event_type != "user": + logger.debug( + "[GRPC LISTENER] Skipping event_type=%s seq=%d", + msg.event_type, + msg.seq, + ) + continue + + logger.info( + "[GRPC LISTENER] User message seq=%d — triggering run: session=%s", + msg.seq, + self._session_id, + ) + await self._handle_user_message(msg) + + except asyncio.CancelledError: + stop_event.set() + executor.shutdown(wait=False) + logger.info("[GRPC LISTENER] Cancelled: session=%s", self._session_id) + raise + except Exception as exc: + stop_event.set() + executor.shutdown(wait=False) + logger.warning( + "[GRPC LISTENER] Error, reconnecting in %.1fs: session=%s error=%s", + backoff, + self._session_id, + exc, + ) + await asyncio.sleep(backoff) + backoff = min(backoff * 2, _BACKOFF_MAX) + continue + + stop_event.set() + executor.shutdown(wait=False) + backoff = _BACKOFF_INITIAL + logger.info( + "[GRPC LISTENER] Stream ended cleanly, reconnecting: session=%s last_seq=%d", + self._session_id, + last_seq, + ) + + async def _handle_user_message(self, msg: Any) -> None: + """Parse a user message payload and drive a full bridge.run() turn.""" + from ambient_runner.endpoints.run import RunnerInput + + try: + runner_input = RunnerInput.model_validate_json(msg.payload) + except Exception: + runner_input = RunnerInput( + messages=[ + {"id": str(uuid.uuid4()), "role": "user", "content": msg.payload} + ], + thread_id=self._session_id, + ) + + try: + input_data = runner_input.to_run_agent_input() + except Exception as exc: + logger.warning( + "[GRPC LISTENER] Failed to build run agent input: seq=%d error=%s", + msg.seq, + exc, + ) + return + + thread_id = input_data.thread_id or self._session_id + run_id = str(input_data.run_id) if input_data.run_id else str(uuid.uuid4()) + + writer = GRPCMessageWriter( + session_id=self._session_id, + run_id=run_id, + grpc_client=self._grpc_client, + ) + + logger.info( + "[GRPC LISTENER] bridge.run() starting: session=%s thread=%s run=%s", + self._session_id, + thread_id, + run_id, + ) + + active_streams: dict[str, asyncio.Queue] = getattr( + self._bridge, "_active_streams", {} + ) + run_queue = active_streams.get(thread_id) + + async def _run_once(): + async for event in self._bridge.run(input_data): + stream_queue = active_streams.get(thread_id) + if stream_queue is not None: + try: + stream_queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning( + "[GRPC LISTENER] SSE tap queue full, dropping event: thread=%s", + thread_id, + ) + await writer.consume(event) + + try: + await _run_once() + except PermissionError as exc: + logger.warning( + "[GRPC LISTENER] Credential auth failure, refreshing token and retrying: session=%s error=%s", + self._session_id, + exc, + ) + try: + from ambient_runner.platform.utils import refresh_bot_token + + await asyncio.get_running_loop().run_in_executor( + None, refresh_bot_token + ) + except Exception as refresh_exc: + logger.warning( + "[GRPC LISTENER] Token refresh failed: session=%s error=%s", + self._session_id, + refresh_exc, + ) + try: + writer = GRPCMessageWriter( + session_id=self._session_id, + run_id=run_id, + grpc_client=self._grpc_client, + ) + await _run_once() + except Exception as retry_exc: + logger.error( + "[GRPC LISTENER] bridge.run() failed after token refresh: session=%s error=%s", + self._session_id, + retry_exc, + exc_info=True, + ) + _synthesize_run_error(thread_id, str(retry_exc), active_streams, writer) + except Exception as exc: + logger.error( + "[GRPC LISTENER] bridge.run() failed: session=%s error=%s", + self._session_id, + exc, + exc_info=True, + ) + _synthesize_run_error(thread_id, str(exc), active_streams, writer) + finally: + if run_queue is not None and active_streams.get(thread_id) is run_queue: + active_streams.pop(thread_id, None) + logger.info( + "[GRPC LISTENER] Turn complete: session=%s thread=%s", + self._session_id, + thread_id, + ) + + +class GRPCMessageWriter: + """Per-turn event consumer. Writes one PushSessionMessage on turn end. + + Accumulates messages from MESSAGES_SNAPSHOT events (storing only the + latest snapshot — each MESSAGES_SNAPSHOT is a complete replacement). + On RUN_FINISHED or RUN_ERROR, pushes the assembled payload as a single + durable DB record with event_type="assistant". + """ + + def __init__( + self, + session_id: str, + run_id: str, + grpc_client: Optional["AmbientGRPCClient"], + ) -> None: + self._session_id = session_id + self._run_id = run_id + self._grpc_client = grpc_client + self._accumulated_messages: list = [] + + async def consume(self, event: BaseEvent) -> None: + """Process one event from bridge.run(). Called by the listener fan-out loop.""" + raw_type = getattr(event, "type", None) + if raw_type is None: + return + event_type_str = raw_type.value if hasattr(raw_type, "value") else str(raw_type) + + if event_type_str == "MESSAGES_SNAPSHOT": + messages = getattr(event, "messages", None) or [] + self._accumulated_messages = [ + m.model_dump() if hasattr(m, "model_dump") else m for m in messages + ] + logger.debug( + "[GRPC WRITER] MESSAGES_SNAPSHOT accumulated: session=%s count=%d", + self._session_id, + len(self._accumulated_messages), + ) + + elif event_type_str == "RUN_FINISHED": + await self._write_message(status="completed") + + elif event_type_str == "RUN_ERROR": + await self._write_message(status="error") + + async def _write_message(self, status: str) -> None: + if self._grpc_client is None: + logger.warning( + "[GRPC WRITER] No gRPC client — cannot push assembled message: session=%s", + self._session_id, + ) + return + + assistant_text = next( + ( + m.get("content") or "" + for m in self._accumulated_messages + if m.get("role") == "assistant" + ), + "", + ) + + if not assistant_text: + logger.warning( + "[GRPC WRITER] No assistant message in snapshot: session=%s run=%s messages=%d", + self._session_id, + self._run_id, + len(self._accumulated_messages), + ) + + logger.info( + "[GRPC WRITER] PushSessionMessage: session=%s run=%s status=%s text_len=%d", + self._session_id, + self._run_id, + status, + len(assistant_text), + ) + + client = self._grpc_client + session_id = self._session_id + + def _do_push() -> None: + client.session_messages.push( + session_id, + event_type="assistant", + payload=assistant_text, + ) + + try: + await asyncio.get_running_loop().run_in_executor(None, _do_push) + except Exception as exc: + logger.warning( + "[GRPC WRITER] Push failed: session=%s status=%s error=%s", + self._session_id, + status, + exc, + ) diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/mcp.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/mcp.py index 4ac06a8b7..a3674bc1d 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/claude/mcp.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/mcp.py @@ -63,6 +63,15 @@ def build_mcp_servers( mcp_servers = load_mcp_config(context, cwd_path) or {} + # Ambient MCP sidecar (SSE transport, injected when annotation ambient-code.io/mcp-sidecar=true) + ambient_mcp_url = os.getenv("AMBIENT_MCP_URL", "").strip() + if ambient_mcp_url: + mcp_servers["ambient"] = { + "type": "sse", + "url": f"{ambient_mcp_url.rstrip('/')}/sse", + } + logger.info("Added ambient MCP sidecar server (SSE): %s", ambient_mcp_url) + # Session control tools refresh_creds_tool = create_refresh_credentials_tool(context, sdk_tool) session_server = create_sdk_mcp_server( diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/prompts.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/prompts.py index 0bac85f82..601d5581a 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/claude/prompts.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/prompts.py @@ -5,16 +5,23 @@ preset format (``type: "preset", preset: "claude_code"``). """ -from ambient_runner.platform.prompts import resolve_workspace_prompt +from ambient_runner.platform.prompts import ( + DEFAULT_AGENT_PREAMBLE, + resolve_workspace_prompt, +) def build_sdk_system_prompt(workspace_path: str, cwd_path: str) -> dict: """Build the full system prompt config dict for the Claude SDK. Wraps the platform workspace context prompt in the Claude Code preset. + The DEFAULT_AGENT_PREAMBLE (overridable via AGENT_PREAMBLE env var) is + prepended so it applies to every session regardless of workflow or prompt. """ + workspace_context = resolve_workspace_prompt(workspace_path, cwd_path) + append_content = f"{DEFAULT_AGENT_PREAMBLE}\n\n{workspace_context}" return { "type": "preset", "preset": "claude_code", - "append": resolve_workspace_prompt(workspace_path, cwd_path), + "append": append_content, } diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py index 36b1236bc..0e277a967 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py @@ -170,9 +170,11 @@ async def _run(self) -> None: # Wait for reader to signal ResultMessage received, # but also bail if the reader task dies mid-turn. - reader_done = asyncio.ensure_future( - asyncio.shield(self._reader_task) - ) if self._reader_task else None + reader_done = ( + asyncio.ensure_future(asyncio.shield(self._reader_task)) + if self._reader_task + else None + ) turn_wait = asyncio.ensure_future(self._turn_done.wait()) waiters = [turn_wait] @@ -190,13 +192,12 @@ async def _run(self) -> None: if reader_done and reader_done in done: logger.error( - "[SessionWorker] Reader died mid-turn for " - "thread=%s", + "[SessionWorker] Reader died mid-turn for thread=%s", self.thread_id, ) - await output_queue.put(WorkerError( - RuntimeError("SDK message reader died") - )) + await output_queue.put( + WorkerError(RuntimeError("SDK message reader died")) + ) break except Exception as exc: @@ -247,10 +248,14 @@ async def _read_messages_forever(self, client: Any) -> None: async for msg in client.receive_messages(): msg_type = type(msg).__name__ subtype = getattr(msg, "subtype", "") - route = "run" if self._active_output_queue is not None else "between-run" + route = ( + "run" if self._active_output_queue is not None else "between-run" + ) logger.debug( "[Reader] %s (subtype=%s) → %s queue", - msg_type, subtype, route, + msg_type, + subtype, + route, ) # Capture session_id from init message (for resume) diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/tools.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/tools.py old mode 100755 new mode 100644 diff --git a/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/bridge.py b/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/bridge.py index f8a48dd65..053c23113 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/bridge.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/bridge.py @@ -74,7 +74,9 @@ def capabilities(self) -> FrameworkCapabilities: tracing="langfuse" if has_tracing else None, ) - async def run(self, input_data: RunAgentInput, **kwargs) -> AsyncIterator[BaseEvent]: + async def run( + self, input_data: RunAgentInput, **kwargs + ) -> AsyncIterator[BaseEvent]: """Full run lifecycle: lazy setup -> session worker -> tracing.""" # 1. Lazy platform setup await self._ensure_ready() @@ -127,7 +129,9 @@ async def _line_stream_with_capture(): wrapped_stream = tracing_middleware( secret_redaction_middleware( - self._adapter.run(input_data, line_stream=_line_stream_with_capture()), + self._adapter.run( + input_data, line_stream=_line_stream_with_capture() + ), ), obs=self._obs, model=self._configured_model, diff --git a/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/session.py b/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/session.py index 6b5a45cca..c8ee67534 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/session.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/gemini_cli/session.py @@ -152,9 +152,6 @@ async def query( stderr=asyncio.subprocess.PIPE, cwd=self._cwd, env=env, - limit=10 - * 1024 - * 1024, # 10 MB — default 64 KB is too small for large MCP tool responses ) # Start concurrent stderr streaming diff --git a/components/runners/ambient-runner/ambient_runner/bridges/langgraph/bridge.py b/components/runners/ambient-runner/ambient_runner/bridges/langgraph/bridge.py index 1e02053f1..aab590c01 100644 --- a/components/runners/ambient-runner/ambient_runner/bridges/langgraph/bridge.py +++ b/components/runners/ambient-runner/ambient_runner/bridges/langgraph/bridge.py @@ -71,7 +71,9 @@ def capabilities(self) -> FrameworkCapabilities: def set_context(self, context: RunnerContext) -> None: self._context = context - async def run(self, input_data: RunAgentInput, **kwargs) -> AsyncIterator[BaseEvent]: + async def run( + self, input_data: RunAgentInput, **kwargs + ) -> AsyncIterator[BaseEvent]: """Run the LangGraph adapter and yield AG-UI events. Lazily creates the adapter on first run. diff --git a/components/runners/ambient-runner/ambient_runner/endpoints/events.py b/components/runners/ambient-runner/ambient_runner/endpoints/events.py index 9f79949c3..faac54909 100644 --- a/components/runners/ambient-runner/ambient_runner/endpoints/events.py +++ b/components/runners/ambient-runner/ambient_runner/endpoints/events.py @@ -1,58 +1,201 @@ -"""GET /events — persistent SSE for between-run AG-UI events.""" +"""GET /events/{thread_id} — real-time SSE tap for an in-progress bridge.run() turn. + +The backend opens this endpoint BEFORE calling PushSessionMessage (see §2.1 of +the gRPC message transport design). Opening first registers the queue in +bridge._active_streams[thread_id] so the gRPC fan-out cannot fire before a +subscriber exists — eliminating the race condition by ordering, not polling. + +The GRPCSessionListener's fan-out loop feeds events into the queue. +This endpoint reads from the queue and yields them as SSE, filtering out +MESSAGES_SNAPSHOT (internal only — used by GRPCMessageWriter for the DB write). +The stream closes when RUN_FINISHED or RUN_ERROR is received. + +Defensive fallback: if the queue is not yet registered when a client connects +(edge case: very slow lifespan startup), the endpoint polls _active_streams +with 100ms sleep intervals up to EVENTS_TAP_TIMEOUT_SEC (default 2s) before +returning 404. +""" import asyncio import logging +import os +from typing import AsyncIterator -from ag_ui.encoder import EventEncoder -from fastapi import APIRouter, Request +from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse logger = logging.getLogger(__name__) router = APIRouter() -# Heartbeat interval to keep the SSE connection alive (seconds). -_HEARTBEAT_INTERVAL = 15 +_POLL_INTERVAL = 0.1 +_TAP_TIMEOUT_SEC = float(os.getenv("EVENTS_TAP_TIMEOUT_SEC", "2")) + +_CLOSE_TYPES = frozenset(["RUN_FINISHED", "RUN_ERROR"]) +_FILTER_TYPES = frozenset(["MESSAGES_SNAPSHOT"]) + + +def _event_type_str(event) -> str: + raw = getattr(event, "type", None) + if raw is None: + return "" + return raw.value if hasattr(raw, "value") else str(raw) -@router.get("/events") -async def stream_events(request: Request): - """Persistent SSE endpoint for between-run events. +@router.get("/events/{thread_id}") +async def get_events(thread_id: str, request: Request): + """SSE tap for an in-progress bridge.run() turn. - Streams AG-UI events that arrive outside of user-initiated runs - (background task completions, hook notifications, agent responses - to task results). + Creates a bounded asyncio.Queue and registers it in + bridge._active_streams[thread_id] before returning the SSE response. + The GRPCSessionListener fan-out loop feeds events into the queue. """ bridge = request.app.state.bridge - ctx = getattr(bridge, "_context", None) - thread_id = ctx.session_id if ctx else "" + active_streams: dict[str, asyncio.Queue] | None = getattr( + bridge, "_active_streams", None + ) - encoder = EventEncoder(accept="text/event-stream") + if active_streams is None: + raise HTTPException( + status_code=503, detail="Bridge does not support active streams" + ) - async def event_stream(): - try: - event_iter = bridge.stream_between_run_events(thread_id) + existing = active_streams.get(thread_id) + if existing is not None: + logger.info( + "[SSE TAP] Reusing existing queue for thread=%s (active_streams count=%d)", + thread_id, + len(active_streams), + ) + queue: asyncio.Queue = existing + else: + queue = asyncio.Queue() + active_streams[thread_id] = queue + logger.info( + "[SSE TAP] Queue registered: thread=%s (active_streams count=%d)", + thread_id, + len(active_streams), + ) + async def event_stream() -> AsyncIterator[str]: + try: while True: if await request.is_disconnected(): + logger.info("[SSE TAP] Client disconnected: thread=%s", thread_id) break try: - event = await asyncio.wait_for( - event_iter.__anext__(), - timeout=_HEARTBEAT_INTERVAL, + event = await asyncio.wait_for(queue.get(), timeout=30.0) + except asyncio.TimeoutError: + yield ": heartbeat\n\n" + continue + + et = _event_type_str(event) + + if et in _FILTER_TYPES: + logger.debug("[SSE TAP] Filtered %s: thread=%s", et, thread_id) + continue + + try: + from ag_ui.encoder import EventEncoder + + encoder = EventEncoder(accept="text/event-stream") + encoded = encoder.encode(event) + logger.debug( + "[SSE TAP] Yielding event: thread=%s type=%s", thread_id, et ) - yield encoder.encode(event) - except StopAsyncIteration: + yield encoded + except Exception as enc_err: + logger.warning( + "[SSE TAP] Encode error: thread=%s type=%s error=%s", + thread_id, + et, + enc_err, + ) + + if et in _CLOSE_TYPES: + logger.info("[SSE TAP] Turn ended (%s): thread=%s", et, thread_id) + break + finally: + if active_streams.get(thread_id) is queue: + active_streams.pop(thread_id, None) + logger.info("[SSE TAP] Queue removed: thread=%s", thread_id) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + +@router.get("/events/{thread_id}/wait") +async def wait_for_events(thread_id: str, request: Request): + """Defensive fallback variant: polls _active_streams until the queue is + registered (for edge cases where the listener hasn't started yet). + Returns 404 after EVENTS_TAP_TIMEOUT_SEC. + + The primary path is GET /events/{thread_id} which registers the queue + immediately on connection (before PushSessionMessage is called). + """ + bridge = request.app.state.bridge + active_streams: dict[str, asyncio.Queue] | None = getattr( + bridge, "_active_streams", None + ) + + if active_streams is None: + raise HTTPException( + status_code=503, detail="Bridge does not support active streams" + ) + + elapsed = 0.0 + while elapsed < _TAP_TIMEOUT_SEC: + if thread_id in active_streams: + break + await asyncio.sleep(_POLL_INTERVAL) + elapsed += _POLL_INTERVAL + + if thread_id not in active_streams: + logger.warning( + "[SSE TAP WAIT] Timeout after %.1fs: thread=%s", elapsed, thread_id + ) + raise HTTPException( + status_code=404, detail=f"No active stream for thread {thread_id!r}" + ) + + queue = active_streams[thread_id] + logger.info("[SSE TAP WAIT] Queue found after %.1fs: thread=%s", elapsed, thread_id) + + async def event_stream() -> AsyncIterator[str]: + try: + while True: + if await request.is_disconnected(): break + try: + event = await asyncio.wait_for(queue.get(), timeout=30.0) except asyncio.TimeoutError: - # No event within heartbeat interval — send keepalive yield ": heartbeat\n\n" - except Exception as e: - logger.error(f"Error in between-run event stream: {e}") + continue + + et = _event_type_str(event) + if et in _FILTER_TYPES: + continue + + try: + from ag_ui.encoder import EventEncoder + + encoder = EventEncoder(accept="text/event-stream") + yield encoder.encode(event) + except Exception as enc_err: + logger.warning("[SSE TAP WAIT] Encode error: %s", enc_err) + + if et in _CLOSE_TYPES: break - except Exception as e: - logger.error(f"Fatal error in /events stream: {e}", exc_info=True) + finally: + if active_streams.get(thread_id) is queue: + active_streams.pop(thread_id, None) return StreamingResponse( event_stream(), diff --git a/components/runners/ambient-runner/ambient_runner/endpoints/run.py b/components/runners/ambient-runner/ambient_runner/endpoints/run.py index 729311d13..aba850bae 100644 --- a/components/runners/ambient-runner/ambient_runner/endpoints/run.py +++ b/components/runners/ambient-runner/ambient_runner/endpoints/run.py @@ -11,6 +11,8 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel +from ambient_runner.middleware import grpc_push_middleware + logger = logging.getLogger(__name__) router = APIRouter() @@ -80,13 +82,18 @@ async def run_agent(input_data: RunnerInput, request: Request): f"Run: thread_id={run_agent_input.thread_id}, run_id={run_agent_input.run_id}" ) + session_id = run_agent_input.thread_id or "" + async def event_stream(): try: - async for event in bridge.run( - run_agent_input, - current_user_id=current_user_id, - current_user_name=current_user_name, - caller_token=caller_token, + async for event in grpc_push_middleware( + bridge.run( + run_agent_input, + current_user_id=current_user_id, + current_user_name=current_user_name, + caller_token=caller_token, + ), + session_id=session_id, ): try: yield encoder.encode(event) diff --git a/components/runners/ambient-runner/ambient_runner/endpoints/tasks.py b/components/runners/ambient-runner/ambient_runner/endpoints/tasks.py index 7f5cfa575..af843712d 100644 --- a/components/runners/ambient-runner/ambient_runner/endpoints/tasks.py +++ b/components/runners/ambient-runner/ambient_runner/endpoints/tasks.py @@ -55,7 +55,11 @@ async def stop_task(task_id: str, request: Request): completed_event = CustomEvent( type=EventType.CUSTOM, name="task:completed", - value={"task_id": task_id, "status": "stopped", "summary": "Task stopped by user"}, + value={ + "task_id": task_id, + "status": "stopped", + "summary": "Task stopped by user", + }, ) sm = getattr(bridge, "_session_manager", None) @@ -126,9 +130,7 @@ async def get_task_output(task_id: str, request: Request): ) if resolved.stat().st_size > _MAX_OUTPUT_BYTES: - raise HTTPException( - status_code=413, detail="Transcript too large" - ) + raise HTTPException(status_code=413, detail="Transcript too large") try: entries = [] diff --git a/components/runners/ambient-runner/ambient_runner/middleware/__init__.py b/components/runners/ambient-runner/ambient_runner/middleware/__init__.py index f35aed0f0..2dec18d35 100644 --- a/components/runners/ambient-runner/ambient_runner/middleware/__init__.py +++ b/components/runners/ambient-runner/ambient_runner/middleware/__init__.py @@ -6,7 +6,13 @@ """ from ambient_runner.middleware.developer_events import emit_developer_message +from ambient_runner.middleware.grpc_push import grpc_push_middleware from ambient_runner.middleware.secret_redaction import secret_redaction_middleware from ambient_runner.middleware.tracing import tracing_middleware -__all__ = ["tracing_middleware", "secret_redaction_middleware", "emit_developer_message"] +__all__ = [ + "tracing_middleware", + "secret_redaction_middleware", + "grpc_push_middleware", + "emit_developer_message", +] diff --git a/components/runners/ambient-runner/ambient_runner/middleware/grpc_push.py b/components/runners/ambient-runner/ambient_runner/middleware/grpc_push.py new file mode 100644 index 000000000..256833325 --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/middleware/grpc_push.py @@ -0,0 +1,127 @@ +""" +AG-UI gRPC Push Middleware — forwards events to ambient-api-server via gRPC. + +Wraps an AG-UI event stream and pushes each event as a ``SessionMessage`` +to the ``PushSessionMessage`` RPC on the ambient-api-server. The push is +fire-and-forget: failures are logged but never propagate to the caller. + +Usage:: + + from ambient_runner.middleware import grpc_push_middleware + + async for event in grpc_push_middleware( + bridge.run(input_data), + session_id=session_id, + ): + yield encoder.encode(event) + +When ``AMBIENT_GRPC_URL`` is unset the middleware is a transparent no-op +with zero overhead. +""" + +from __future__ import annotations + +import json +import logging +import os +from typing import AsyncIterator, Optional + +from ag_ui.core import BaseEvent + +logger = logging.getLogger(__name__) + +_ENV_GRPC_URL = "AMBIENT_GRPC_URL" +_ENV_SESSION_ID = "SESSION_ID" + + +def _event_to_payload(event: BaseEvent) -> str: + """Serialise an AG-UI event to a JSON string for the gRPC payload.""" + try: + if hasattr(event, "model_dump"): + return json.dumps(event.model_dump()) + if hasattr(event, "dict"): + return json.dumps(event.dict()) + return json.dumps({"type": str(getattr(event, "type", "unknown"))}) + except Exception: + return json.dumps({"type": str(getattr(event, "type", "unknown"))}) + + +def _event_type_str(event: BaseEvent) -> str: + raw = getattr(event, "type", None) + if raw is None: + return "unknown" + return str(raw.value) if hasattr(raw, "value") else str(raw) + + +async def grpc_push_middleware( + event_stream: AsyncIterator[BaseEvent], + *, + session_id: Optional[str] = None, +) -> AsyncIterator[BaseEvent]: + """Wrap an AG-UI event stream with gRPC push to ambient-api-server. + + Args: + event_stream: The upstream event stream. + session_id: Session ID to push messages under. Falls back to the + ``SESSION_ID`` environment variable. + + Yields: + The original events unchanged. + """ + grpc_url = os.environ.get(_ENV_GRPC_URL, "").strip() + if not grpc_url: + async for event in event_stream: + yield event + return + + sid = session_id or os.environ.get(_ENV_SESSION_ID, "").strip() + if not sid: + logger.warning( + "grpc_push_middleware: AMBIENT_GRPC_URL set but SESSION_ID missing — push disabled" + ) + async for event in event_stream: + yield event + return + + grpc_client: Optional[object] = None + try: + from ambient_platform._grpc_client import AmbientGRPCClient + + grpc_client = AmbientGRPCClient.from_env() + logger.info("grpc_push_middleware: connected to %s (session=%s)", grpc_url, sid) + except Exception as exc: + logger.warning( + "grpc_push_middleware: failed to create gRPC client (%s) — push disabled", + exc, + ) + async for event in event_stream: + yield event + return + + try: + async for event in event_stream: + yield event + _push_event(grpc_client, sid, event) + finally: + try: + grpc_client.close() + except Exception: + pass + + +def _push_event(grpc_client: object, session_id: str, event: BaseEvent) -> None: + """Fire-and-forget push of a single AG-UI event via gRPC.""" + try: + event_type = _event_type_str(event) + payload = _event_to_payload(event) + grpc_client.session_messages.push( + session_id=session_id, + event_type=event_type, + payload=payload, + ) + except Exception as exc: + logger.debug( + "grpc_push_middleware: push failed (event=%s): %s", + _event_type_str(event), + exc, + ) diff --git a/components/runners/ambient-runner/ambient_runner/middleware/secret_redaction.py b/components/runners/ambient-runner/ambient_runner/middleware/secret_redaction.py index 7879f4e16..ace93ed68 100644 --- a/components/runners/ambient-runner/ambient_runner/middleware/secret_redaction.py +++ b/components/runners/ambient-runner/ambient_runner/middleware/secret_redaction.py @@ -87,7 +87,15 @@ def _redact_event(event: BaseEvent, secret_values: list[tuple[str, str]]) -> Bas Only processes event types that carry user-visible text. All other events pass through unchanged (zero cost). """ - if isinstance(event, (TextMessageContentEvent, TextMessageChunkEvent, ToolCallArgsEvent, ToolCallChunkEvent)): + if isinstance( + event, + ( + TextMessageContentEvent, + TextMessageChunkEvent, + ToolCallArgsEvent, + ToolCallChunkEvent, + ), + ): redacted = _redact_text(event.delta, secret_values) if redacted != event.delta: return event.model_copy(update={"delta": redacted}) diff --git a/components/runners/ambient-runner/ambient_runner/platform/auth.py b/components/runners/ambient-runner/ambient_runner/platform/auth.py index 5009b185d..438f5c77c 100755 --- a/components/runners/ambient-runner/ambient_runner/platform/auth.py +++ b/components/runners/ambient-runner/ambient_runner/platform/auth.py @@ -17,7 +17,7 @@ from urllib.parse import urlparse from ambient_runner.platform.context import RunnerContext -from ambient_runner.platform.utils import get_bot_token +from ambient_runner.platform.utils import get_bot_token, refresh_bot_token logger = logging.getLogger(__name__) @@ -41,6 +41,7 @@ # time), so updating os.environ mid-run would not reach it without these files. _GITHUB_TOKEN_FILE = Path("/tmp/.ambient_github_token") _GITLAB_TOKEN_FILE = Path("/tmp/.ambient_gitlab_token") +_KUBECONFIG_FILE = Path("/tmp/.ambient_kubeconfig") # --------------------------------------------------------------------------- @@ -103,23 +104,33 @@ def sanitize_user_context(user_id: str, user_name: str) -> tuple[str, str]: async def _fetch_credential(context: RunnerContext, credential_type: str) -> dict: """Fetch credentials from backend API at runtime.""" base = os.getenv("BACKEND_API_URL", "").rstrip("/") - project = os.getenv("PROJECT_NAME") or os.getenv("AGENTIC_SESSION_NAMESPACE", "") - project = project.strip() - session_id = context.session_id - if not base or not project or not session_id: + if not base: logger.warning( - f"Cannot fetch {credential_type} credentials: missing environment " - f"variables (base={base}, project={project}, session={session_id})" + f"Cannot fetch {credential_type} credentials: BACKEND_API_URL not set" ) return {} - url = f"{base}/projects/{project}/agentic-sessions/{session_id}/credentials/{credential_type}" + credential_ids = _json.loads(os.getenv("CREDENTIAL_IDS", "{}")) + credential_id = credential_ids.get(credential_type) + if not credential_id: + logger.debug(f"No credential_id for provider {credential_type}; skipping fetch") + return {} + + project_id = os.getenv("PROJECT_NAME", "") + if not project_id: + logger.warning("Cannot fetch credentials: PROJECT_NAME not set") + return {} + + url = ( + f"{base}/api/ambient/v1/projects/{project_id}/credentials/{credential_id}/token" + ) # Reject non-cluster URLs to prevent token exfiltration via user-overridden env vars parsed = urlparse(base) if parsed.hostname and not ( parsed.hostname.endswith(".svc.cluster.local") + or parsed.hostname.endswith(".svc") or parsed.hostname == "localhost" or parsed.hostname == "127.0.0.1" ): @@ -144,6 +155,7 @@ async def _fetch_credential(context: RunnerContext, credential_type: str) -> dic bot = get_bot_token() if bot: req.add_header("Authorization", f"Bearer {bot}") + logger.debug(f"Using CP OIDC token for {credential_type} credentials") loop = asyncio.get_running_loop() @@ -179,18 +191,49 @@ def _do_req(): f"and BOT_TOKEN fallback also failed" ) from fallback_err if e.code in (401, 403): - logger.warning( - f"{credential_type} credential fetch failed with HTTP {e.code}: {e}" - ) - raise PermissionError( - f"{credential_type} authentication failed with HTTP {e.code}" - ) from e + # BOT_TOKEN may have expired — refresh from CP endpoint and retry once. + return _retry_with_fresh_bot_token(e.code) logger.warning(f"{credential_type} credential fetch failed: {e}") return "" except Exception as e: logger.warning(f"{credential_type} credential fetch failed: {e}") return "" + def _retry_with_fresh_bot_token(original_code: int): + logger.info( + f"{credential_type} got {original_code} with cached BOT_TOKEN — refreshing from CP endpoint and retrying" + ) + try: + fresh_bot = refresh_bot_token() + except Exception as refresh_err: + logger.warning(f"{credential_type} CP token refresh failed: {refresh_err}") + raise PermissionError( + f"{credential_type} authentication failed with HTTP {original_code}" + ) from refresh_err + retry_req = _urllib_request.Request(url, method="GET") + if fresh_bot: + retry_req.add_header("Authorization", f"Bearer {fresh_bot}") + if context.current_user_id: + retry_req.add_header("X-Runner-Current-User", context.current_user_id) + try: + with _urllib_request.urlopen(retry_req, timeout=10) as resp: + logger.info(f"{credential_type} retry with fresh BOT_TOKEN succeeded") + return resp.read().decode("utf-8", errors="replace") + except _urllib_request.HTTPError as retry_err: + logger.warning( + f"{credential_type} retry with fresh BOT_TOKEN failed: {retry_err}" + ) + raise PermissionError( + f"{credential_type} authentication failed with HTTP {retry_err.code}" + ) from retry_err + except Exception as retry_err: + logger.warning( + f"{credential_type} retry with fresh BOT_TOKEN failed: {retry_err}" + ) + raise PermissionError( + f"{credential_type} authentication failed with HTTP {original_code}" + ) from retry_err + resp_text = await loop.run_in_executor(None, _do_req) if not resp_text: return {} @@ -313,6 +356,10 @@ async def fetch_coderabbit_credentials(context: RunnerContext) -> dict: return data +async def fetch_kubeconfig_credential(context: RunnerContext) -> dict: + return await _fetch_credential(context, "kubeconfig") + + async def fetch_token_for_url(context: RunnerContext, url: str) -> str: """Fetch appropriate token based on repository URL host.""" try: @@ -342,6 +389,7 @@ async def populate_runtime_credentials(context: RunnerContext) -> None: github_creds, coderabbit_creds, gerrit_creds, + kubeconfig_creds, ) = await asyncio.gather( fetch_google_credentials(context), fetch_jira_credentials(context), @@ -349,6 +397,7 @@ async def populate_runtime_credentials(context: RunnerContext) -> None: fetch_github_credentials(context), fetch_coderabbit_credentials(context), fetch_gerrit_credentials(context), + fetch_kubeconfig_credential(context), return_exceptions=True, ) @@ -362,28 +411,18 @@ async def populate_runtime_credentials(context: RunnerContext) -> None: logger.warning(f"Failed to refresh Google credentials: {google_creds}") if isinstance(google_creds, PermissionError): auth_failures.append(str(google_creds)) - elif google_creds.get("accessToken"): + elif google_creds.get("token"): try: - creds_dir = _GOOGLE_WORKSPACE_CREDS_FILE.parent - creds_dir.mkdir(parents=True, exist_ok=True) - - # The refresh token is written to disk because workspace-mcp - # runs as a child process and cannot call back to the platform - # backend to obtain fresh access tokens on its own. - creds_data = { - "token": google_creds.get("accessToken"), - "refresh_token": google_creds.get("refreshToken", ""), - "token_uri": "https://oauth2.googleapis.com/token", - "client_id": os.getenv("GOOGLE_OAUTH_CLIENT_ID", ""), - "client_secret": os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", ""), - "scopes": google_creds.get("scopes", []), - "expiry": google_creds.get("expiresAt", ""), - } - - with open(_GOOGLE_WORKSPACE_CREDS_FILE, "w") as f: - _json.dump(creds_data, f, indent=2) - _GOOGLE_WORKSPACE_CREDS_FILE.chmod(0o600) - logger.info("Updated Google credentials file for workspace-mcp") + sa_json = google_creds["token"] + gac_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") + if gac_path: + creds_path = Path(gac_path) + else: + creds_path = _GOOGLE_WORKSPACE_CREDS_FILE + creds_path.parent.mkdir(parents=True, exist_ok=True) + creds_path.write_text(sa_json) + creds_path.chmod(0o600) + logger.info(f"Updated Google service account credentials at {creds_path}") user_email = google_creds.get("email", "") if user_email and user_email != _PLACEHOLDER_EMAIL: @@ -396,9 +435,9 @@ async def populate_runtime_credentials(context: RunnerContext) -> None: logger.warning(f"Failed to refresh Jira credentials: {jira_creds}") if isinstance(jira_creds, PermissionError): auth_failures.append(str(jira_creds)) - elif jira_creds.get("apiToken"): + elif jira_creds.get("token"): os.environ["JIRA_URL"] = jira_creds.get("url", "") - os.environ["JIRA_API_TOKEN"] = jira_creds.get("apiToken", "") + os.environ["JIRA_API_TOKEN"] = jira_creds.get("token", "") os.environ["JIRA_EMAIL"] = jira_creds.get("email", "") logger.info("Updated Jira credentials in environment") @@ -456,16 +495,27 @@ async def populate_runtime_credentials(context: RunnerContext) -> None: logger.warning(f"Failed to fetch Gerrit credentials: {gerrit_creds}") if isinstance(gerrit_creds, PermissionError): auth_failures.append(str(gerrit_creds)) - # Clear config on auth failure from ambient_runner.bridges.claude.mcp import generate_gerrit_config generate_gerrit_config([]) - # On network error, preserve existing config (don't clear) else: from ambient_runner.bridges.claude.mcp import generate_gerrit_config generate_gerrit_config(gerrit_creds) + if isinstance(kubeconfig_creds, Exception): + logger.warning(f"Failed to refresh kubeconfig credentials: {kubeconfig_creds}") + if isinstance(kubeconfig_creds, PermissionError): + auth_failures.append(str(kubeconfig_creds)) + elif kubeconfig_creds.get("token"): + try: + _KUBECONFIG_FILE.write_text(kubeconfig_creds["token"]) + _KUBECONFIG_FILE.chmod(0o600) + os.environ["KUBECONFIG"] = str(_KUBECONFIG_FILE) + logger.info(f"Written kubeconfig to {_KUBECONFIG_FILE}") + except OSError as e: + logger.warning(f"Failed to write kubeconfig file: {e}") + # Configure git identity, credential helper, and gh CLI wrapper await configure_git_identity(git_user_name, git_user_email) install_git_credential_helper() @@ -484,7 +534,7 @@ def clear_runtime_credentials() -> None: """Remove sensitive credentials from environment after turn completes. Clears fixed credential keys, dynamically-injected MCP_* env vars, - and Google Workspace credential files. + and token files. Google credential files are preserved (issue #1222). """ cleared = [] for key in [ @@ -495,6 +545,7 @@ def clear_runtime_credentials() -> None: "JIRA_EMAIL", "USER_GOOGLE_EMAIL", "CODERABBIT_API_KEY", + "KUBECONFIG", ]: if os.environ.pop(key, None) is not None: cleared.append(key) @@ -511,20 +562,18 @@ def clear_runtime_credentials() -> None: cleared.append(key) # Remove token files used by the git credential helper. - for token_file in (_GITHUB_TOKEN_FILE, _GITLAB_TOKEN_FILE): + for token_file in (_GITHUB_TOKEN_FILE, _GITLAB_TOKEN_FILE, _KUBECONFIG_FILE): try: token_file.unlink(missing_ok=True) cleared.append(token_file.name) except OSError as e: logger.warning(f"Failed to remove token file {token_file}: {e}") - # NOTE: Google Workspace credential file is intentionally NOT deleted here. - # The workspace-mcp process runs as a long-lived child process of the Claude - # CLI and reads credentials from this file. Deleting it between turns causes - # workspace-mcp to lose its credentials and fall back to initiating a new - # OAuth flow (with an inaccessible localhost:8000 callback URL). - # The file is overwritten with fresh credentials at the start of each run - # by populate_runtime_credentials(), so staleness is not a concern. + # NOTE: Google credential files (_GOOGLE_WORKSPACE_CREDS_FILE and + # GOOGLE_APPLICATION_CREDENTIALS) are intentionally NOT deleted here. + # The workspace-mcp process reads credentials from these files; deleting + # them between turns causes it to fall back to an inaccessible localhost + # OAuth flow (issue #1222). if cleared: logger.info(f"Cleared credentials: {', '.join(cleared)}") diff --git a/components/runners/ambient-runner/ambient_runner/platform/context.py b/components/runners/ambient-runner/ambient_runner/platform/context.py index f9b40a2b7..79e6c8fa5 100644 --- a/components/runners/ambient-runner/ambient_runner/platform/context.py +++ b/components/runners/ambient-runner/ambient_runner/platform/context.py @@ -52,7 +52,9 @@ def get_metadata(self, key: str, default: Any = None) -> Any: """Get a metadata value.""" return self.metadata.get(key, default) - def set_current_user(self, user_id: str, user_name: str = "", token: str = "") -> None: + def set_current_user( + self, user_id: str, user_name: str = "", token: str = "" + ) -> None: """Set the current user for per-message credential scoping.""" self.current_user_id = user_id self.current_user_name = user_name diff --git a/components/runners/ambient-runner/ambient_runner/platform/prompts.py b/components/runners/ambient-runner/ambient_runner/platform/prompts.py index 2e30097b4..7de1f8433 100644 --- a/components/runners/ambient-runner/ambient_runner/platform/prompts.py +++ b/components/runners/ambient-runner/ambient_runner/platform/prompts.py @@ -18,6 +18,10 @@ # Prompt constants # --------------------------------------------------------------------------- +DEFAULT_AGENT_PREAMBLE = os.getenv( + "AGENT_PREAMBLE", "You are a helpful AI agent. Be kind." +) + WORKSPACE_STRUCTURE_HEADER = "# Workspace Structure\n\n" WORKSPACE_FIXED_PATHS_PROMPT = ( @@ -68,15 +72,6 @@ "the feature branch (`{branch}`). If push fails, do NOT fall back to main.\n\n" ) -GIT_SAFETY_INSTRUCTIONS = ( - "## Git Safety\n\n" - "**NEVER embed tokens or credentials in commands** — use environment " - "variables (`$GITHUB_TOKEN`, `$GITLAB_TOKEN`) instead of inline PATs.\n\n" - "**When a git operation fails**: stop, diagnose, report the error to the " - "user, and wait. Do NOT autonomously escalate to force pushes, API " - "workarounds, or more aggressive retry variants.\n\n" -) - RUBRIC_EVALUATION_HEADER = "## Rubric Evaluation\n\n" RUBRIC_EVALUATION_INTRO = ( @@ -224,9 +219,6 @@ def build_workspace_context_prompt( prompt += f"- **repos/{repo_name}/**\n" prompt += GIT_PUSH_STEPS.format(branch=push_branch) - if repos_cfg: - prompt += GIT_SAFETY_INSTRUCTIONS - # Human-in-the-loop instructions prompt += HUMAN_INPUT_INSTRUCTIONS diff --git a/components/runners/ambient-runner/ambient_runner/platform/utils.py b/components/runners/ambient-runner/ambient_runner/platform/utils.py index 2ebaec247..898e54a3a 100644 --- a/components/runners/ambient-runner/ambient_runner/platform/utils.py +++ b/components/runners/ambient-runner/ambient_runner/platform/utils.py @@ -23,14 +23,46 @@ # Kubelet automatically refreshes this file when the Secret is updated. _BOT_TOKEN_FILE = Path("/var/run/secrets/ambient/bot-token") +# K8s SA token mounted in every pod by the kubelet. +_SA_TOKEN_FILE = Path("/var/run/secrets/kubernetes.io/serviceaccount/token") + +# In-process cache for the token fetched from the CP token endpoint. +# Set once at startup by _grpc_client.py after a successful CP token fetch. +_cp_fetched_token: str = "" + + +def get_sa_token() -> str: + """Return the Kubernetes ServiceAccount token mounted in the pod. + + This is a long-lived K8s-managed token that authenticates to the K8s API + as system:serviceaccount::. The backend's + enforceCredentialRBAC classifies this as isBotToken=true, which grants + access to the session owner's credentials without an owner-match check. + """ + try: + if _SA_TOKEN_FILE.exists(): + return _SA_TOKEN_FILE.read_text().strip() + except OSError: + pass + return "" + + +def set_bot_token(token: str) -> None: + """Store a token fetched from the CP token endpoint for use by get_bot_token().""" + global _cp_fetched_token + _cp_fetched_token = token.strip() + def get_bot_token() -> str: - """Return the current BOT_TOKEN, preferring the file mount over env var. + """Return the current BOT_TOKEN. - The operator mounts the runner-token Secret as a file so kubelet refreshes - it automatically when the token is rotated. Falls back to the BOT_TOKEN - env var for backward-compatibility with local / non-Kubernetes runs. + Priority: + 1. Token fetched from CP token endpoint (set via set_bot_token()). + 2. File mount at _BOT_TOKEN_FILE (kubelet-refreshed Secret). + 3. BOT_TOKEN env var (local / non-Kubernetes fallback). """ + if _cp_fetched_token: + return _cp_fetched_token try: if _BOT_TOKEN_FILE.exists(): return _BOT_TOKEN_FILE.read_text().strip() @@ -39,6 +71,27 @@ def get_bot_token() -> str: return (os.getenv("BOT_TOKEN") or "").strip() +def refresh_bot_token() -> str: + """Fetch a fresh token from the CP token endpoint and update the in-process cache. + + Returns the new token, or the current cached token if the CP endpoint is not + configured (local dev mode). Raises RuntimeError if the CP fetch fails. + """ + cp_token_url = os.getenv("AMBIENT_CP_TOKEN_URL", "") + if not cp_token_url: + return get_bot_token() + + public_key_pem = os.getenv("AMBIENT_CP_TOKEN_PUBLIC_KEY", "") + session_id = os.getenv("SESSION_ID", "") + if not public_key_pem or not session_id: + logger.warning("refresh_bot_token: CP env vars incomplete, skipping refresh") + return get_bot_token() + + from ambient_runner._grpc_client import _fetch_token_from_cp + + return _fetch_token_from_cp(cp_token_url, public_key_pem, session_id) + + def is_env_truthy(value: str) -> bool: """Return True for "1", "true", or "yes" (case-insensitive).""" return value.strip().lower() in _TRUTHY_VALUES diff --git a/components/runners/ambient-runner/ambient_runner/tools/backend_api.py b/components/runners/ambient-runner/ambient_runner/tools/backend_api.py index 9121ba605..13242931c 100644 --- a/components/runners/ambient-runner/ambient_runner/tools/backend_api.py +++ b/components/runners/ambient-runner/ambient_runner/tools/backend_api.py @@ -46,7 +46,9 @@ def __init__( # when the Secret is rotated, but env vars are frozen at pod start). self._bot_token_override = bot_token # Expose self.bot_token for backward-compatibility with existing callers. - self.bot_token = (bot_token if bot_token is not None else get_bot_token()).strip() + self.bot_token = ( + bot_token if bot_token is not None else get_bot_token() + ).strip() if not self.backend_url: raise ValueError("BACKEND_API_URL environment variable is required") diff --git a/components/runners/ambient-runner/architecture.md b/components/runners/ambient-runner/architecture.md new file mode 100644 index 000000000..4c36b06d7 --- /dev/null +++ b/components/runners/ambient-runner/architecture.md @@ -0,0 +1,438 @@ +# Ambient Runner: Architecture + +## Overview + +The runner is a FastAPI server running in a Kubernetes Job pod (one pod per session). It implements the [AG-UI protocol](https://github.com/ag-ui-protocol/ag-ui) — a Server-Sent Events (SSE) streaming protocol for AI agents. The runner bridges between the platform backend and the underlying AI model (Claude Agent SDK). + +There are two delivery modes. The **HTTP path** is the original design: the backend POSTs to `/agui/run` and streams AG-UI events back over SSE. The **gRPC path** is an additive overlay that replaces the HTTP round-trip with a persistent bidirectional gRPC channel to the Ambient control plane. Both paths share the same `bridge.run()` execution primitive — only the delivery mechanism differs. + +``` +┌─────────────────────────────────────────────────────────────────────────────────────┐ +│ Kubernetes Job Pod (one per session) │ +│ │ +│ ENV: SESSION_ID, WORKSPACE_PATH, INITIAL_PROMPT │ +│ ENV: AMBIENT_GRPC_ENABLED=true, AMBIENT_GRPC_URL=... ← only in gRPC mode │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────────────┐ │ +│ │ FastAPI app (create_ambient_app) │ │ +│ │ │ │ +│ │ lifespan startup: │ │ +│ │ 1. build RunnerContext │ │ +│ │ 2. bridge.set_context(ctx) │ │ +│ │ 3. if GRPC_ENABLED → bridge.start_grpc_listener(url) ← gRPC only │ │ +│ │ └── await listener.ready (10s timeout) │ │ +│ │ 4. asyncio.create_task(_auto_execute_initial_prompt) │ │ +│ │ └── if grpc_url → _push_initial_prompt_via_grpc ← gRPC only │ │ +│ │ else → _push_initial_prompt_via_http ← HTTP path │ │ +│ │ │ │ +│ │ ┌───────────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ ClaudeBridge │ │ │ +│ │ │ │ │ │ +│ │ │ _active_streams: dict[thread_id → asyncio.Queue] ← gRPC only │ │ │ +│ │ │ _grpc_listener: GRPCSessionListener | None ← gRPC only │ │ │ +│ │ │ │ │ │ +│ │ │ run(input_data) → AsyncIterator[BaseEvent] ← shared by both │ │ │ +│ │ └───────────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ HTTP endpoints (existing, always active): │ │ +│ │ POST /run → bridge.run() → SSE to caller │ │ +│ │ POST /interrupt → bridge.interrupt() │ │ +│ │ GET /capabilities, /mcp-status, /repos, /workflow, ... │ │ +│ │ │ │ +│ │ SSE tap endpoints (new, always mounted, only useful in gRPC mode): │ │ +│ │ GET /events/{thread_id} → SSE tap (real-time) │ │ +│ │ GET /events/{thread_id}/wait → SSE tap (polling fallback) │ │ +│ └─────────────────────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Startup and Lifecycle (`app.py`, `main.py`) + +1. **`main.py`** reads `RUNNER_TYPE` (e.g. `claude-agent-sdk`) and instantiates the bridge. +2. **`create_ambient_app(bridge)`** creates the FastAPI app with a lifespan context manager: + - Builds `RunnerContext` from `SESSION_ID` / `WORKSPACE_PATH` env vars + - Calls `bridge.set_context(context)` + - If `AMBIENT_GRPC_ENABLED=true` and `AMBIENT_GRPC_URL` are set: calls `bridge.start_grpc_listener(url)` and awaits `listener.ready` (10s timeout) before proceeding — ensures the watch stream is open before the initial prompt fires + - If `IS_RESUME` is not set and a prompt exists: fires `_auto_execute_initial_prompt()` as a background `asyncio.Task` + - On shutdown: calls `bridge.shutdown()` +3. **Auto-prompt** (`_auto_execute_initial_prompt`): + - **gRPC mode**: calls `_push_initial_prompt_via_grpc()` — pushes a `PushSessionMessage(event_type="user")` to the control plane; the listener picks it up and drives `bridge.run()` directly + - **HTTP mode**: calls `_push_initial_prompt_via_http()` — POSTs to `BACKEND_API_URL/projects/{project}/agentic-sessions/{session}/agui/run` with exponential backoff (8 retries, 2s→30s) because K8s DNS may not propagate before the pod is ready + +--- + +## The Bridge Pattern (`bridge.py`) + +`PlatformBridge` is an abstract base class. All framework implementations must provide: + +- `capabilities()` — declares features to the frontend +- `run(input_data)` — async generator yielding AG-UI `BaseEvent` objects +- `interrupt(thread_id)` — stops the current run + +Key lifecycle hooks (override as needed): + +- `set_context()` — stores `RunnerContext` at startup +- `_ensure_ready()` / `_setup_platform()` — lazy one-time init on first `run()` +- `_refresh_credentials_if_stale()` — refreshes tokens every 60s or when GitHub token is expiring +- `shutdown()` — called on pod termination +- `mark_dirty()` — called by repos/workflow endpoints when workspace changes; rebuilds adapter on next `run()` +- `inject_message()` — raises `NotImplementedError` on base class; must be overridden by any bridge that handles inbound messages + +--- + +## The Two Delivery Paths + +### HTTP Path (original) + +The backend owns the entire request lifecycle. It POSTs to the runner and receives AG-UI events back over SSE. The runner never initiates contact. + +``` + Frontend + │ HTTP + ▼ + Backend + │ POST /projects/{proj}/agentic-sessions/{sess}/agui/run + ▼ + POST /run endpoint (runner) + │ bridge.run(input_data) + ▼ + ClaudeBridge.run() + │ yields AG-UI events + ▼ + SSE stream ──────────────────────────────────────────► Backend + │ + └── writes result to DB +``` + +### gRPC Path (additive overlay) + +The control plane owns message delivery. The runner maintains a persistent outbound watch stream, and the control plane pushes messages into it. The runner calls `bridge.run()` internally, then fans events out through two parallel channels: an SSE tap for the backend to observe, and a `PushSessionMessage` to persist the assembled result. + +``` + ┌─────────────────────┐ ┌────────────────────────────────────────────────┐ + │ Ambient Control │ │ Pod │ + │ Plane (gRPC) │ │ │ + │ │ │ ┌──────────────────────────────────────────┐ │ + │ WatchSessionMsgs │◄───────┤ │ GRPCSessionListener (background thread) │ │ + │ stream │ watch │ │ ThreadPoolExecutor │ │ + │ │ │ │ blocks on gRPC stream │ │ + │ │ │ │ sets listener.ready on stream open │ │ + │ │ │ └──────────────────────────────────────────┘ │ + │ │ │ │ │ + │ PushSessionMessage │ │ event_type=="user" received │ + │ (user message) │───────►│ parse payload → RunnerInput │ + │ │ │ build RunAgentInput │ + │ │ │ │ │ + │ │ │ bridge.run(input_data) │ + │ │ │ │ │ + │ │ │ ├──► active_streams[thread_id] │ + │ │ │ │ asyncio.Queue.put_nowait() │ + │ │ │ │ │ │ + │ │ │ │ ▼ │ + │ │ │ │ GET /events/{thread_id} │ + │ │ │ │ SSE ──────────────► Backend │ + │ │ │ │ │ + │ │ │ └──► GRPCMessageWriter.consume() │ + │ │ │ accumulates MESSAGES_SNAPSHOT │ + │ │ │ on RUN_FINISHED / RUN_ERROR: │ + │ │ │ │ + │ PushSessionMessage │◄───────┤ PushSessionMessage(event_type="assistant") │ + │ (assistant result) │ │ run_in_executor (non-blocking) │ + │ │ │ payload: {run_id, status, messages} │ + └─────────────────────┘ └────────────────────────────────────────────────┘ +``` + +--- + +## SSE Queue Lifecycle and the Ordering Contract + +The key design decision in the gRPC path is the **ordering contract**: the backend must open `GET /events/{thread_id}` *before* it sends the user message via `PushSessionMessage`. Pre-registration eliminates the race — the queue exists in `active_streams` before the first event can arrive. + +``` + Backend GET /events endpoint GRPCSessionListener + │ │ │ + │ GET /events/{thread_id} │ │ + │──────────────────────────────►│ │ + │ │ queue = existing or new │ + │ │ active_streams[id] = q │ + │ │ │ + │ (control plane delivers user message) │ + │ │ bridge.run() starts + │ │ events → q.put_nowait() + │◄── SSE chunk ─────────────────│◄── q.get() ───────────────│ + │◄── SSE chunk ─────────────────│◄── q.get() ───────────────│ + │◄── RUN_FINISHED ──────────────│◄── q.get() ───────────────│ + │ │ break (stream closes) │ + │ │ if q is active_streams[id]: pop + │ │ │ finally: + │ │ │ if registered_q + │ │ │ is active_q: + │ │ │ pop +``` + +**Identity-safe cleanup:** both the SSE endpoint and `GRPCSessionListener` capture the queue reference at the start of their respective lifetimes and only remove it from `active_streams` if the map still points to the same object. This prevents a reconnecting client or a new turn from having its queue silently removed by an older cleanup. + +**Duplicate connect:** if a client connects to `/events/{thread_id}` when a queue is already registered (e.g. reconnect), the endpoint reuses the existing queue rather than replacing it. This prevents buffered events from being dropped. + +--- + +## ClaudeBridge: The Full Claude Lifecycle (`bridges/claude/bridge.py`) + +`ClaudeBridge` is the complete bridge implementation. Its `run()` method: + +1. **`_ensure_ready()`** — on first call, runs `_setup_platform()`: + - Auth setup (Anthropic API key or Vertex AI credentials) + - `populate_runtime_credentials()` / `populate_mcp_server_credentials()` — fetches GitHub tokens, Google OAuth, Jira tokens from the backend + - `resolve_workspace_paths()` — determines cwd and additional dirs + - `build_mcp_servers()` — assembles full MCP server config (external + platform tools) + - `build_sdk_system_prompt()` — builds the system prompt + - Initializes `ObservabilityManager` (Langfuse) + - Creates `SessionManager` +2. **`_ensure_adapter()`** — builds `ClaudeAgentAdapter` with all options (cwd, permission mode, allowed tools, MCP servers, system prompt). Adapter is cached and reused. A ring buffer of 50 stderr lines is maintained for error reporting. +3. **Worker selection** — gets or creates a `SessionWorker` for the thread, optionally resuming from a previously saved CLI session ID (for pod restarts). +4. **Event streaming** — acquires a per-thread `asyncio.Lock` (prevents concurrent requests to the same thread from mixing), calls `worker.query(user_msg)`, wraps the stream through `tracing_middleware`, and yields events. +5. **Halt detection** — after the stream ends, checks `adapter.halted`. If the adapter halted (because Claude called a frontend HITL tool like `AskUserQuestion`), calls `worker.interrupt()` to prevent the SDK from auto-approving the tool call. +6. **Session persistence** — after each turn, saves the CLI session ID to disk (`claude_session_ids.json`) so `--resume` works after pod restart. + +**gRPC listener** (`start_grpc_listener`): a dedicated startup hook (separate from `_setup_platform`) that instantiates and starts `GRPCSessionListener`. Only called when both `AMBIENT_GRPC_ENABLED=true` and `AMBIENT_GRPC_URL` are set. The listener is started before the initial prompt fires so the watch stream is open before the first message arrives. Duplicate calls are idempotent. + +--- + +## SessionWorker and Queue Architecture (`bridges/claude/session.py`) + +This is the mechanism that lets the long-lived Claude CLI process work inside FastAPI's async event loop: + +``` + Request Handler (async context A) Background Task (async context B) + │ │ + worker.query(prompt) worker._run() loop + │ │ + puts (prompt, session_id, ◄── input_queue.get() + output_queue) on input_queue │ + │ client.query(prompt) + output_queue.get() in loop async for msg in client.receive_response() + │ output_queue.put(msg) + ▼ ... + yields messages output_queue.put(None) ← sentinel +``` + +**Why this exists:** the Claude Agent SDK uses `anyio` task groups internally. Using a persistent `ClaudeSDKClient` inside a FastAPI SSE handler (a different async context) hits anyio's task group context mismatch. The worker pattern sidesteps this by running the SDK client entirely inside one stable background `asyncio.Task`. + +Queue protocol: +- Input queue items: `(prompt, session_id, output_queue)` or `_SHUTDOWN` sentinel +- Output queue items: SDK `Message` objects, `WorkerError(exception)` wrapper, or `None` sentinel (end of turn) +- `WorkerError` is a typed wrapper to avoid ambiguous `isinstance(item, Exception)` checks + +Worker lifecycle: +- `start()` — spawns `asyncio.create_task(self._run())` +- `_run()` loop — connects SDK client, then: get from input queue → query client → stream responses to output queue → put `None` sentinel +- On any error during a query: puts `WorkerError` then `None`, then breaks (worker dies; `SessionManager` recreates it) +- `stop()` — puts `_SHUTDOWN`, waits up to 15s, then cancels task + +**Graceful disconnect:** closes stdin of the Claude CLI subprocess so the CLI saves its session state to `.claude/` before terminating. Enables `--resume` on pod restart. + +`SessionManager`: one worker per `thread_id`. Maintains a per-thread `asyncio.Lock` to serialize concurrent requests. Session IDs are persisted to `claude_session_ids.json` and restored on startup. + +--- + +## AG-UI Protocol Translation (`ag_ui_claude_sdk/adapter.py`) + +`ClaudeAgentAdapter._stream_claude_sdk()` consumes Claude SDK messages and emits AG-UI events: + +| Claude SDK message | AG-UI event(s) emitted | +|---|---| +| `StreamEvent(type=message_start)` | (starts tracking `current_message_id`) | +| `StreamEvent(type=content_block_start, block_type=thinking)` | `ReasoningStartEvent`, `ReasoningMessageStartEvent` | +| `StreamEvent(type=content_block_delta, delta_type=thinking_delta)` | `ReasoningMessageContentEvent` | +| `StreamEvent(type=content_block_start, block_type=tool_use)` | `ToolCallStartEvent` | +| `StreamEvent(type=content_block_delta, delta_type=input_json_delta)` | `ToolCallArgsEvent` | +| `StreamEvent(type=content_block_stop)` for tool | `ToolCallEndEvent` (or halt if frontend tool) | +| `StreamEvent(type=content_block_delta, delta_type=text_delta)` | `TextMessageStartEvent` (first chunk), `TextMessageContentEvent` | +| `StreamEvent(type=message_stop)` | `TextMessageEndEvent` | +| `AssistantMessage` (non-streamed fallback) | accumulated into `run_messages` | +| `ToolResultBlock` | `ToolCallEndEvent` + `ToolCallResultEvent` | +| `SystemMessage` | `TextMessageStart/Content/End` | +| `ResultMessage` | captured as `_last_result_data` for `RunFinishedEvent` | +| End of stream | `MessagesSnapshotEvent` (full conversation snapshot) | + +The entire run is wrapped: `RunStartedEvent` → ... → `RunFinishedEvent` (or `RunErrorEvent`). + +--- + +## gRPC Transport Detail (`bridges/claude/grpc_transport.py`) + +### `GRPCSessionListener` + +Pod-lifetime background component. One instance per session, started in the lifespan before the initial prompt. + +``` + start() + │ + ├── AmbientGRPCClient.from_env() + └── asyncio.create_task(_listen_loop()) + │ + └── _listen_loop() [async, event loop] + │ + ├── ThreadPoolExecutor(max_workers=1) + │ └── _watch_in_thread() [blocking, thread] + │ ├── client.session_messages.watch(session_id, after_seq=N) + │ ├── loop.call_soon_threadsafe(ready.set) + │ └── for msg in stream: + │ asyncio.run_coroutine_threadsafe(msg_queue.put(msg), loop) + │ + └── while True: + msg = await msg_queue.get() + if msg.event_type == "user": + await _handle_user_message(msg) + # reconnects with backoff on stream end or error +``` + +**Reconnect logic:** when the gRPC stream ends (server-side close or network error), `_listen_loop` reconnects with exponential backoff (1s → 30s). `after_seq=last_seq` ensures no messages are replayed. + +### `_handle_user_message` + +Drives one complete bridge turn per inbound user message: + +``` + _handle_user_message(msg) + │ + ├── parse msg.payload as RunnerInput (fallback: raw string as content) + ├── runner_input.to_run_agent_input() → RunAgentInput + ├── capture run_queue = active_streams.get(thread_id) + ├── GRPCMessageWriter(session_id, run_id, grpc_client) + │ + ├── async for event in bridge.run(input_data): + │ ├── active_streams.get(thread_id).put_nowait(event) → SSE tap + │ └── writer.consume(event) → DB writer + │ + ├── on exception: + │ _synthesize_run_error(thread_id, error, active_streams, writer) + │ ├── put RunErrorEvent into SSE queue + │ └── asyncio.ensure_future(writer._write_message(status="error")) + │ + └── finally: + if run_queue is not None and active_streams.get(thread_id) is run_queue: + active_streams.pop(thread_id) ← identity-safe cleanup +``` + +### `GRPCMessageWriter` + +Per-turn consumer. Accumulates `MESSAGES_SNAPSHOT` content (each snapshot is a complete replacement of the conversation). On `RUN_FINISHED` or `RUN_ERROR`, pushes one `PushSessionMessage(event_type="assistant")` to the control plane via `run_in_executor` (non-blocking). + +``` + consume(event) + │ + ├── MESSAGES_SNAPSHOT → self._accumulated_messages = [...] + ├── RUN_FINISHED → _write_message(status="completed") + └── RUN_ERROR → _write_message(status="error") + + _write_message(status) + │ + └── run_in_executor(None, _do_push) + └── client.session_messages.push( + session_id, + event_type="assistant", + payload={"run_id", "status", "messages"} + ) +``` + +--- + +## Interrupts (`endpoints/interrupt.py`, `bridges/claude/bridge.py`, `session.py`) + +HTTP trigger: `POST /interrupt` with optional `{ "thread_id": "..." }` body. + +Flow: +1. `interrupt_run()` endpoint → `bridge.interrupt(thread_id)` +2. `ClaudeBridge.interrupt()` → looks up `SessionWorker` → `worker.interrupt()` +3. `SessionWorker.interrupt()` → `self._client.interrupt()` on `ClaudeSDKClient` + +The SDK client's interrupt propagates to the Claude CLI subprocess (signal or stdin close), which stops generation mid-stream. The output queue drains and `None` is eventually put on it, causing `worker.query()` to return. + +**Frontend tool halt:** not triggered by HTTP — the adapter sets `self._halted = True` when Claude calls a frontend tool (e.g. `AskUserQuestion`). After the stream ends, `ClaudeBridge.run()` calls `worker.interrupt()` automatically to prevent the SDK from auto-approving the pending tool call. + +**Observability:** `bridge.interrupt()` calls `self._obs.record_interrupt()` if tracing is enabled. + +--- + +## Queue Draining + +No explicit drain operation. The queue drains through normal flow: + +1. **Normal completion:** `_run()` puts all response messages then `None`. `worker.query()` yields until `None`, then returns. +2. **Interrupt:** SDK stops generation. `async for` ends. `None` is put in the `finally` block. `worker.query()` returns. +3. **Worker error:** `WorkerError` then `None`. `worker.query()` raises, propagates through `bridge.run()` → `event_stream()` → `RunErrorEvent`. +4. **Worker death:** `SessionManager.get_or_create()` detects `worker.is_alive == False` on the next request, destroys the dead worker, creates a fresh one using `--resume`. + +Per-thread lock: `asyncio.Lock` per thread prevents a second request from being processed while the first is still draining. Lock is held for the entire duration of `worker.query()`. + +--- + +## How New Messages Are Added + +**Normal turn (HTTP path):** +1. Frontend sends `POST /agui/run` via backend proxy with `RunnerInput` JSON +2. `run_agent()` endpoint creates `RunAgentInput`, calls `bridge.run(input_data)` +3. `ClaudeBridge.run()` calls `process_messages(input_data)` to extract the last user message +4. `worker.query(user_msg)` puts `(user_msg, session_id, output_queue)` on the input queue +5. Background worker picks it up, sends to Claude CLI, streams responses back + +**Normal turn (gRPC path):** +1. Control plane pushes `PushSessionMessage(event_type="user")` to the watch stream +2. `GRPCSessionListener._handle_user_message()` parses payload, calls `bridge.run(input_data)` directly +3. Events are fanned out to the SSE tap queue and `GRPCMessageWriter` + +**Auto-prompt:** +- HTTP mode: `_push_initial_prompt_via_http()` POSTs to the backend run endpoint with `metadata.hidden=True`, `metadata.autoSent=True` +- gRPC mode: `_push_initial_prompt_via_grpc()` pushes a `PushSessionMessage(event_type="user")` directly; listener handles it identically to any other user message + +**Tool results (frontend HITL tools):** +- Claude halts; user responds; frontend sends next message containing tool result +- On next `run()`, adapter detects `previous_halted_tool_call_id` and emits `ToolCallResultEvent` before starting the new turn + +**Tool results (backend MCP tools):** +- Handled internally by Claude CLI — SDK calls MCP server in-process, gets result, continues without HTTP round-trip + +--- + +## MCP Tools (`bridges/claude/mcp.py`, `tools.py`, `corrections.py`) + +Three categories of platform-injected MCP servers: + +| Server | Tool | Purpose | +|---|---|---| +| `session` | `refresh_credentials` | Lets Claude refresh GitHub/Google/Jira tokens mid-run | +| `rubric` | `evaluate_rubric` | Scores Claude's output against a rubric; logs to Langfuse | +| `corrections` | `log_correction` | Logs human corrections to Langfuse for the feedback loop | + +Plus external MCP servers loaded from `.mcp.json` in the workspace. All passed to `ClaudeAgentOptions.mcp_servers`. Wildcard permissions (`mcp__session__*`, etc.) added to `allowed_tools`. + +--- + +## Tracing Middleware (`middleware/tracing.py`) + +A transparent async generator wrapper around the event stream. If `obs` (Langfuse `ObservabilityManager`) is present: +- `obs.track_agui_event(event)` called for each event (tracks turns, tool calls, usage) +- Once a trace ID is available (after first assistant message), emits `CustomEvent("ambient:langfuse_trace", {"traceId": ...})` — frontend uses this to link feedback to the trace +- On exception: `obs.cleanup_on_error(exc)` marks the Langfuse trace as errored +- On normal completion: `obs.finalize_event_tracking()` + +--- + +## Feedback (`endpoints/feedback.py`) + +`POST /feedback` accepts META events with `metaType: thumbs_up | thumbs_down`. Resolves the Langfuse trace ID (from payload or from `bridge.obs.last_trace_id`), creates a BOOLEAN score in Langfuse. Returns a RAW event for the backend to persist. + +--- + +## `mark_dirty()` and Adapter Rebuilds + +When repos or workflows are added at runtime (`POST /repos` or `POST /workflow`), the endpoint calls `bridge.mark_dirty()`. This: + +1. Sets `self._ready = False` (triggers `_setup_platform()` on next run) +2. Sets `self._adapter = None` (triggers `_ensure_adapter()` on next run) +3. Captures all current session IDs → `self._saved_session_ids` +4. Async-shuts down the current `SessionManager` (fire-and-forget) +5. On next `run()`: full re-init with new workspace/MCP config, existing conversations resumed via `--resume ` diff --git a/components/runners/ambient-runner/pyproject.toml b/components/runners/ambient-runner/pyproject.toml index ebccef560..352a1ded0 100644 --- a/components/runners/ambient-runner/pyproject.toml +++ b/components/runners/ambient-runner/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "aiohttp>=3.13.4", "requests>=2.33.0", "pyjwt>=2.11.0", + "cryptography>=42.0.0", + "grpcio>=1.60.0", + "protobuf>=4.25.0", ] [project.optional-dependencies] diff --git a/components/runners/ambient-runner/tests/test_app_initial_prompt.py b/components/runners/ambient-runner/tests/test_app_initial_prompt.py new file mode 100644 index 000000000..4b93b2ab6 --- /dev/null +++ b/components/runners/ambient-runner/tests/test_app_initial_prompt.py @@ -0,0 +1,528 @@ +"""Unit tests for app.py initial prompt dispatch functions. + +Coverage targets: +- _push_initial_prompt_via_grpc: happy path, push raises (client still closed), + None result, from_env error, offloaded to executor (non-blocking) +- _push_initial_prompt_via_http: happy path, missing env vars bail, bot token, + no token, retry-on-failure (8 attempts), non-transient error early return +- _auto_execute_initial_prompt: routes to gRPC when grpc_url set, + routes to HTTP when grpc_url empty, routes to HTTP when grpc_url defaulted +- create_ambient_app lifespan: gRPC OFF path (no AMBIENT_GRPC_ENABLED env), + gRPC ON path (AMBIENT_GRPC_ENABLED=true + AMBIENT_GRPC_URL) +""" + +import asyncio +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from ambient_runner.app import ( + _auto_execute_initial_prompt, + _push_initial_prompt_via_grpc, + _push_initial_prompt_via_http, +) + + +# --------------------------------------------------------------------------- +# _push_initial_prompt_via_grpc +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestPushInitialPromptViaGRPC: + async def test_pushes_user_event_with_prompt_content(self): + mock_result = MagicMock() + mock_result.seq = 42 + + mock_client = MagicMock() + mock_client.session_messages.push.return_value = mock_result + mock_client.close = MagicMock() + + mock_cls = MagicMock() + mock_cls.from_env.return_value = mock_client + + with patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls): + await _push_initial_prompt_via_grpc("hello world", "sess-1") + + mock_client.session_messages.push.assert_called_once() + call = mock_client.session_messages.push.call_args + assert call[0][0] == "sess-1" + assert call[1]["event_type"] == "user" + payload = json.loads(call[1]["payload"]) + assert payload["threadId"] == "sess-1" + assert "runId" in payload + assert len(payload["messages"]) == 1 + assert payload["messages"][0]["role"] == "user" + assert payload["messages"][0]["content"] == "hello world" + + async def test_closes_client_after_push(self): + mock_result = MagicMock() + mock_result.seq = 1 + mock_client = MagicMock() + mock_client.session_messages.push.return_value = mock_result + mock_client.close = MagicMock() + + mock_cls = MagicMock() + mock_cls.from_env.return_value = mock_client + + with patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls): + await _push_initial_prompt_via_grpc("prompt", "sess-close") + + mock_client.close.assert_called_once() + + async def test_closes_client_even_when_push_raises(self): + """client.close() must be called in finally even if push() raises.""" + mock_client = MagicMock() + mock_client.session_messages.push.side_effect = RuntimeError("rpc failed") + mock_client.close = MagicMock() + + mock_cls = MagicMock() + mock_cls.from_env.return_value = mock_client + + with patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls): + await _push_initial_prompt_via_grpc("prompt", "sess-push-raises") + + mock_client.close.assert_called_once() + + async def test_does_not_raise_on_grpc_error(self): + mock_cls = MagicMock() + mock_cls.from_env.side_effect = RuntimeError("connection refused") + + with patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls): + await _push_initial_prompt_via_grpc("prompt", "sess-err") + + async def test_handles_none_push_result(self): + mock_client = MagicMock() + mock_client.session_messages.push.return_value = None + mock_client.close = MagicMock() + + mock_cls = MagicMock() + mock_cls.from_env.return_value = mock_client + + with patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls): + await _push_initial_prompt_via_grpc("prompt", "sess-none") + + mock_client.close.assert_called_once() + + async def test_push_offloaded_to_executor(self): + """The blocking push must be offloaded via run_in_executor, not called inline.""" + mock_client = MagicMock() + mock_client.session_messages.push.return_value = MagicMock(seq=1) + mock_client.close = MagicMock() + + mock_cls = MagicMock() + mock_cls.from_env.return_value = mock_client + + executor_calls = [] + real_loop = asyncio.get_event_loop() + + original_run_in_executor = real_loop.run_in_executor + + async def capturing_executor(executor, fn, *args): + executor_calls.append(fn) + return await original_run_in_executor(executor, fn, *args) + + with ( + patch("ambient_runner._grpc_client.AmbientGRPCClient", mock_cls), + patch.object(real_loop, "run_in_executor", side_effect=capturing_executor), + ): + await _push_initial_prompt_via_grpc("prompt", "sess-executor") + + assert len(executor_calls) == 1 + + +# --------------------------------------------------------------------------- +# _push_initial_prompt_via_http +# --------------------------------------------------------------------------- + + +def _make_aiohttp_session(status: int = 200, text: str = "ok"): + """Build a mock aiohttp.ClientSession that works with async-with on both + the session itself and session.post(...).""" + mock_resp = AsyncMock() + mock_resp.status = status + mock_resp.text = AsyncMock(return_value=text) + + post_ctx = MagicMock() + post_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + post_ctx.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + mock_session.post = MagicMock(return_value=post_ctx) + + return mock_session + + +@pytest.mark.asyncio +class TestPushInitialPromptViaHTTP: + async def test_posts_to_backend_url(self): + mock_session = _make_aiohttp_session() + + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict( + os.environ, + { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "PROJECT_NAME": "ambient-code", + }, + ), + ): + await _push_initial_prompt_via_http("hi", "sess-http") + + mock_session.post.assert_called_once() + call_url = mock_session.post.call_args[0][0] + assert "backend:8080" in call_url + assert "ambient-code" in call_url + assert "sess-http" in call_url + + async def test_bails_early_when_backend_url_missing(self): + """If BACKEND_API_URL is not set, function logs error and returns without posting.""" + mock_session = _make_aiohttp_session() + + env = { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "PROJECT_NAME": "ambient-code", + } + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict(os.environ, env, clear=True), + ): + await _push_initial_prompt_via_http("hi", "sess-no-backend") + + mock_session.post.assert_not_called() + + async def test_bails_early_when_project_name_missing(self): + """If PROJECT_NAME is not set, function logs error and returns without posting.""" + mock_session = _make_aiohttp_session() + + env = { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + } + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict(os.environ, env, clear=True), + ): + await _push_initial_prompt_via_http("hi", "sess-no-project") + + mock_session.post.assert_not_called() + + async def test_includes_bot_token_in_auth_header_when_present(self): + mock_session = _make_aiohttp_session() + + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict( + os.environ, + { + "BOT_TOKEN": "tok-abc", + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "PROJECT_NAME": "ambient-code", + }, + ), + ): + await _push_initial_prompt_via_http("hi", "sess-token") + + headers = mock_session.post.call_args[1]["headers"] + assert headers.get("Authorization") == "Bearer tok-abc" + + async def test_no_auth_header_when_bot_token_absent(self): + mock_session = _make_aiohttp_session() + + env_without_token = {k: v for k, v in os.environ.items() if k != "BOT_TOKEN"} + env_without_token["INITIAL_PROMPT_DELAY_SECONDS"] = "0" + env_without_token["BACKEND_API_URL"] = "http://backend:8080" + env_without_token["PROJECT_NAME"] = "ambient-code" + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict(os.environ, env_without_token, clear=True), + ): + await _push_initial_prompt_via_http("hi", "sess-no-token") + + headers = mock_session.post.call_args[1]["headers"] + assert "Authorization" not in headers + + async def test_returns_after_max_retries_on_failure(self): + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + mock_session.post = MagicMock(side_effect=Exception("connection refused")) + + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch("asyncio.sleep", new_callable=AsyncMock), + patch.dict( + os.environ, + { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "PROJECT_NAME": "ambient-code", + }, + ), + ): + await _push_initial_prompt_via_http("hi", "sess-retry") + + assert mock_session.post.call_count == 8 + + async def test_non_transient_error_exits_early_without_full_retries(self): + """A 400 response without 'not available' body should not exhaust all retries.""" + mock_session = _make_aiohttp_session(status=400, text="bad request") + + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch("asyncio.sleep", new_callable=AsyncMock), + patch.dict( + os.environ, + { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "PROJECT_NAME": "ambient-code", + }, + ), + ): + await _push_initial_prompt_via_http("hi", "sess-400") + + assert mock_session.post.call_count == 1 + + async def test_not_available_body_triggers_retry(self): + """'not available' in response body should retry up to max retries.""" + mock_session = _make_aiohttp_session(status=503, text="runner not available") + + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch("asyncio.sleep", new_callable=AsyncMock), + patch.dict( + os.environ, + { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "PROJECT_NAME": "ambient-code", + }, + ), + ): + await _push_initial_prompt_via_http("hi", "sess-not-available") + + assert mock_session.post.call_count == 8 + + async def test_uses_agentic_session_namespace_fallback_for_project(self): + """When PROJECT_NAME is missing but AGENTIC_SESSION_NAMESPACE is set, uses that.""" + mock_session = _make_aiohttp_session() + + env = { + "INITIAL_PROMPT_DELAY_SECONDS": "0", + "BACKEND_API_URL": "http://backend:8080", + "AGENTIC_SESSION_NAMESPACE": "ns-fallback", + } + with ( + patch("aiohttp.ClientSession", return_value=mock_session), + patch.dict(os.environ, env, clear=True), + ): + await _push_initial_prompt_via_http("hi", "sess-ns") + + mock_session.post.assert_called_once() + call_url = mock_session.post.call_args[0][0] + assert "ns-fallback" in call_url + + +# --------------------------------------------------------------------------- +# _auto_execute_initial_prompt — routing: gRPC ON vs OFF +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestAutoExecuteInitialPrompt: + async def test_skips_push_when_grpc_url_set(self): + with ( + patch( + "ambient_runner.app._push_initial_prompt_via_grpc", + new_callable=AsyncMock, + ) as mock_grpc, + patch( + "ambient_runner.app._push_initial_prompt_via_http", + new_callable=AsyncMock, + ) as mock_http, + patch.dict(os.environ, {"INITIAL_PROMPT_DELAY_SECONDS": "0"}), + ): + await _auto_execute_initial_prompt( + "hello", "sess-1", grpc_url="localhost:9000" + ) + + mock_grpc.assert_not_awaited() + mock_http.assert_not_awaited() + + async def test_routes_to_http_when_no_grpc_url(self): + with ( + patch( + "ambient_runner.app._push_initial_prompt_via_grpc", + new_callable=AsyncMock, + ) as mock_grpc, + patch( + "ambient_runner.app._push_initial_prompt_via_http", + new_callable=AsyncMock, + ) as mock_http, + patch.dict(os.environ, {"INITIAL_PROMPT_DELAY_SECONDS": "0"}), + ): + await _auto_execute_initial_prompt("hello", "sess-1", grpc_url="") + + mock_http.assert_awaited_once_with("hello", "sess-1") + mock_grpc.assert_not_awaited() + + async def test_routes_to_http_when_grpc_url_default(self): + with ( + patch( + "ambient_runner.app._push_initial_prompt_via_grpc", + new_callable=AsyncMock, + ) as mock_grpc, + patch( + "ambient_runner.app._push_initial_prompt_via_http", + new_callable=AsyncMock, + ) as mock_http, + patch.dict(os.environ, {"INITIAL_PROMPT_DELAY_SECONDS": "0"}), + ): + await _auto_execute_initial_prompt("hello", "sess-1") + + mock_http.assert_awaited_once() + mock_grpc.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# create_ambient_app lifespan — gRPC OFF path (no env vars) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCreateAmbientAppLifespanGRPCOff: + """Verify gRPC listener is NOT started when AMBIENT_GRPC_ENABLED is absent.""" + + async def test_grpc_listener_not_started_without_env(self): + from ambient_runner.app import create_ambient_app + from ambient_runner.bridges.claude.bridge import ClaudeBridge + + bridge = ClaudeBridge() + bridge._active_streams = {} + + env_overrides = {} + for key in ("AMBIENT_GRPC_ENABLED", "AMBIENT_GRPC_URL", "INITIAL_PROMPT"): + env_overrides[key] = "" + + app = create_ambient_app(bridge) + + with ( + patch.dict(os.environ, env_overrides), + patch.object( + bridge, "start_grpc_listener", new_callable=AsyncMock + ) as mock_start, + patch.object(bridge, "shutdown", new_callable=AsyncMock), + ): + async with app.router.lifespan_context(app): + pass + + mock_start.assert_not_called() + + async def test_grpc_listener_not_started_when_only_url_set(self): + """URL alone (without AMBIENT_GRPC_ENABLED=true) must not start listener.""" + from ambient_runner.app import create_ambient_app + from ambient_runner.bridges.claude.bridge import ClaudeBridge + + bridge = ClaudeBridge() + bridge._active_streams = {} + + app = create_ambient_app(bridge) + + with ( + patch.dict( + os.environ, + {"AMBIENT_GRPC_URL": "localhost:9000", "INITIAL_PROMPT": ""}, + clear=False, + ), + patch.object( + bridge, "start_grpc_listener", new_callable=AsyncMock + ) as mock_start, + patch.object(bridge, "shutdown", new_callable=AsyncMock), + ): + async with app.router.lifespan_context(app): + pass + + mock_start.assert_not_called() + + +# --------------------------------------------------------------------------- +# create_ambient_app lifespan — gRPC ON path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCreateAmbientAppLifespanGRPCOn: + """Verify gRPC listener IS started when AMBIENT_GRPC_ENABLED=true and URL set.""" + + async def test_grpc_listener_started_when_both_env_vars_set(self): + from ambient_runner.app import create_ambient_app + from ambient_runner.bridges.claude.bridge import ClaudeBridge + + bridge = ClaudeBridge() + bridge._active_streams = {} + + mock_listener = MagicMock() + mock_listener.ready = asyncio.Event() + mock_listener.ready.set() + + app = create_ambient_app(bridge) + + async def _mock_start_grpc_listener(grpc_url): + bridge._grpc_listener = mock_listener + + with ( + patch.dict( + os.environ, + { + "AMBIENT_GRPC_ENABLED": "true", + "AMBIENT_GRPC_URL": "localhost:9000", + "INITIAL_PROMPT": "", + "SESSION_ID": "sess-grpc-on", + }, + ), + patch.object( + bridge, "start_grpc_listener", side_effect=_mock_start_grpc_listener + ) as mock_start, + patch.object(bridge, "shutdown", new_callable=AsyncMock), + ): + async with app.router.lifespan_context(app): + pass + + mock_start.assert_called_once_with("localhost:9000") + + async def test_grpc_listener_not_started_when_enabled_but_url_empty(self): + """AMBIENT_GRPC_ENABLED=true but AMBIENT_GRPC_URL="" must not start listener.""" + from ambient_runner.app import create_ambient_app + from ambient_runner.bridges.claude.bridge import ClaudeBridge + + bridge = ClaudeBridge() + bridge._active_streams = {} + + app = create_ambient_app(bridge) + + with ( + patch.dict( + os.environ, + { + "AMBIENT_GRPC_ENABLED": "true", + "AMBIENT_GRPC_URL": "", + "INITIAL_PROMPT": "", + }, + ), + patch.object( + bridge, "start_grpc_listener", new_callable=AsyncMock + ) as mock_start, + patch.object(bridge, "shutdown", new_callable=AsyncMock), + ): + async with app.router.lifespan_context(app): + pass + + mock_start.assert_not_called() diff --git a/components/runners/ambient-runner/tests/test_auto_push.py b/components/runners/ambient-runner/tests/test_auto_push.py index 744fab786..7ce0225e3 100644 --- a/components/runners/ambient-runner/tests/test_auto_push.py +++ b/components/runners/ambient-runner/tests/test_auto_push.py @@ -317,38 +317,6 @@ def test_prompt_includes_multiple_autopush_repos(self): # repo3 should not be in git instructions since autoPush=false # (but it will be in the general repos list) - def test_prompt_includes_git_safety_with_repos(self): - """Git safety guardrails are included when repos are present.""" - repos_cfg = [ - { - "name": "my-repo", - "url": "https://github.com/owner/my-repo.git", - "branch": "main", - "autoPush": False, - } - ] - prompt = build_workspace_context_prompt( - repos_cfg=repos_cfg, - workflow_name=None, - artifacts_path="artifacts", - ambient_config={}, - workspace_path="/workspace", - ) - assert "Git Safety" in prompt - assert "NEVER embed tokens" in prompt - assert "Do NOT autonomously escalate" in prompt - - def test_prompt_excludes_git_safety_without_repos(self): - """Git safety instructions are excluded when no repos are present.""" - prompt = build_workspace_context_prompt( - repos_cfg=[], - workflow_name=None, - artifacts_path="artifacts", - ambient_config={}, - workspace_path="/workspace", - ) - assert "Git Safety" not in prompt - def test_prompt_without_repos(self): """Test prompt generation when no repos are configured.""" prompt = build_workspace_context_prompt( diff --git a/components/runners/ambient-runner/tests/test_bridge_claude.py b/components/runners/ambient-runner/tests/test_bridge_claude.py index 1f1a98bed..5f9377ff6 100644 --- a/components/runners/ambient-runner/tests/test_bridge_claude.py +++ b/components/runners/ambient-runner/tests/test_bridge_claude.py @@ -1,5 +1,17 @@ -"""Unit tests for PlatformBridge ABC and ClaudeBridge.""" - +"""Unit tests for PlatformBridge ABC and ClaudeBridge. + +Coverage targets: +- ClaudeBridge initial gRPC state (None listener, empty active_streams) +- shutdown stops listener / safe when None +- start_grpc_listener creates and starts GRPCSessionListener with correct args, + guards against duplicate starts, raises when no context +- inject_message raises NotImplementedError on PlatformBridge base +- PlatformBridge ABC contract +- FrameworkCapabilities dataclass defaults +- ClaudeBridge capabilities, lifecycle, run guards, shutdown, observability setup +""" + +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,6 +27,205 @@ from ambient_runner.platform.context import RunnerContext +# ------------------------------------------------------------------ +# ClaudeBridge gRPC transport tests +# ------------------------------------------------------------------ + + +class TestClaudeBridgeGRPCState: + """Verify gRPC state is initialized correctly on ClaudeBridge.""" + + def test_grpc_listener_none_by_default(self): + bridge = ClaudeBridge() + assert bridge._grpc_listener is None + + def test_active_streams_empty_dict_by_default(self): + bridge = ClaudeBridge() + assert bridge._active_streams == {} + assert isinstance(bridge._active_streams, dict) + + +@pytest.mark.asyncio +class TestClaudeBridgeShutdownGRPC: + """Test shutdown stops the gRPC listener when present.""" + + async def test_shutdown_stops_grpc_listener(self): + bridge = ClaudeBridge() + mock_listener = AsyncMock() + bridge._grpc_listener = mock_listener + await bridge.shutdown() + mock_listener.stop.assert_awaited_once() + + async def test_shutdown_without_grpc_listener_does_not_raise(self): + bridge = ClaudeBridge() + assert bridge._grpc_listener is None + await bridge.shutdown() + + +@pytest.mark.asyncio +class TestClaudeBridgeStartGRPCListener: + """Test the dedicated start_grpc_listener hook (separate from _setup_platform).""" + + async def test_start_creates_listener_with_correct_args(self): + bridge = ClaudeBridge() + ctx = RunnerContext(session_id="sess-grpc", workspace_path="/workspace") + bridge.set_context(ctx) + + mock_listener_instance = MagicMock() + mock_listener_instance.start = MagicMock() + mock_listener_cls = MagicMock(return_value=mock_listener_instance) + + with patch( + "ambient_runner.bridges.claude.grpc_transport.GRPCSessionListener", + mock_listener_cls, + ): + await bridge.start_grpc_listener("localhost:9000") + + mock_listener_instance.start.assert_called_once() + assert bridge._grpc_listener is mock_listener_instance + + async def test_start_raises_without_context(self): + bridge = ClaudeBridge() + with pytest.raises(RuntimeError, match="context not set"): + await bridge.start_grpc_listener("localhost:9000") + + async def test_duplicate_start_is_idempotent(self): + bridge = ClaudeBridge() + ctx = RunnerContext(session_id="sess-dup", workspace_path="/workspace") + bridge.set_context(ctx) + + first_listener = MagicMock() + first_listener.start = MagicMock() + bridge._grpc_listener = first_listener + + mock_listener_cls = MagicMock() + with patch( + "ambient_runner.bridges.claude.grpc_transport.GRPCSessionListener", + mock_listener_cls, + ): + await bridge.start_grpc_listener("localhost:9000") + + mock_listener_cls.assert_not_called() + assert bridge._grpc_listener is first_listener + + async def test_listener_started_and_ready_event_available(self): + bridge = ClaudeBridge() + ctx = RunnerContext(session_id="sess-ready", workspace_path="/workspace") + bridge.set_context(ctx) + + ready_event = asyncio.Event() + ready_event.set() + + mock_listener = MagicMock() + mock_listener.ready = ready_event + mock_listener.start = MagicMock() + mock_listener_cls = MagicMock(return_value=mock_listener) + + with patch( + "ambient_runner.bridges.claude.grpc_transport.GRPCSessionListener", + mock_listener_cls, + ): + await bridge.start_grpc_listener("localhost:9000") + + assert bridge._grpc_listener.ready.is_set() + + +@pytest.mark.asyncio +class TestClaudeBridgeStartGRPCListenerRealPath: + """start_grpc_listener only patches GRPCSessionListener — no _setup_platform mock.""" + + async def test_listener_class_receives_bridge_and_session_id(self): + """Verify GRPCSessionListener is constructed with the correct bridge and session_id.""" + bridge = ClaudeBridge() + ctx = RunnerContext(session_id="sess-realpath", workspace_path="/workspace") + bridge.set_context(ctx) + + captured_kwargs = {} + + def capturing_init(self_inner, *, bridge, session_id, grpc_url): + captured_kwargs["bridge"] = bridge + captured_kwargs["session_id"] = session_id + captured_kwargs["grpc_url"] = grpc_url + self_inner._bridge = bridge + self_inner._session_id = session_id + self_inner._grpc_url = grpc_url + self_inner._grpc_client = None + self_inner.ready = asyncio.Event() + self_inner._task = None + + mock_listener_cls = MagicMock() + mock_instance = MagicMock() + mock_instance.start = MagicMock() + mock_listener_cls.return_value = mock_instance + + with patch( + "ambient_runner.bridges.claude.grpc_transport.GRPCSessionListener", + mock_listener_cls, + ): + await bridge.start_grpc_listener("grpc.example.com:9000") + + call_kwargs = mock_listener_cls.call_args[1] + assert call_kwargs["bridge"] is bridge + assert call_kwargs["session_id"] == "sess-realpath" + assert call_kwargs["grpc_url"] == "grpc.example.com:9000" + mock_instance.start.assert_called_once() + + async def test_listener_not_started_without_context(self): + """start_grpc_listener raises RuntimeError when no context is set.""" + bridge = ClaudeBridge() + mock_listener_cls = MagicMock() + + with patch( + "ambient_runner.bridges.claude.grpc_transport.GRPCSessionListener", + mock_listener_cls, + ): + with pytest.raises(RuntimeError, match="context not set"): + await bridge.start_grpc_listener("grpc.example.com:9000") + + mock_listener_cls.assert_not_called() + + +# ------------------------------------------------------------------ +# inject_message — base class raises NotImplementedError +# ------------------------------------------------------------------ + + +@pytest.mark.asyncio +class TestPlatformBridgeInjectMessage: + """inject_message must raise NotImplementedError on the base class and any + subclass that doesn't override it.""" + + async def test_base_class_raises_not_implemented(self): + class MinimalBridge(PlatformBridge): + def capabilities(self): + return FrameworkCapabilities(framework="test") + + async def run(self, input_data): + yield + + async def interrupt(self, thread_id=None): + pass + + bridge = MinimalBridge() + with pytest.raises(NotImplementedError): + await bridge.inject_message("sess-1", "user", "{}") + + async def test_error_includes_bridge_class_name(self): + class MyBridge(PlatformBridge): + def capabilities(self): + return FrameworkCapabilities(framework="test") + + async def run(self, input_data): + yield + + async def interrupt(self, thread_id=None): + pass + + bridge = MyBridge() + with pytest.raises(NotImplementedError, match="MyBridge"): + await bridge.inject_message("s1", "user", "{}") + + # ------------------------------------------------------------------ # PlatformBridge ABC tests # ------------------------------------------------------------------ @@ -243,7 +454,6 @@ async def test_forwards_workflow_env_vars_to_initialize(self): "ambient_runner.observability.ObservabilityManager", return_value=mock_obs_instance, ) as mock_obs_cls: - await setup_bridge_observability(ctx, "claude-sonnet-4-5") mock_obs_cls.assert_called_once() @@ -275,7 +485,6 @@ async def test_forwards_empty_defaults_when_workflow_vars_unset(self): "ambient_runner.observability.ObservabilityManager", return_value=mock_obs_instance, ): - await setup_bridge_observability(ctx, "claude-sonnet-4-5") call_kwargs = mock_obs_instance.initialize.call_args[1] diff --git a/components/runners/ambient-runner/tests/test_e2e_api.py b/components/runners/ambient-runner/tests/test_e2e_api.py index a1b02e084..a2a7b6140 100644 --- a/components/runners/ambient-runner/tests/test_e2e_api.py +++ b/components/runners/ambient-runner/tests/test_e2e_api.py @@ -401,7 +401,6 @@ def test_interrupt_returns_structured_error(self, client): data = resp.json() assert "detail" in data - def test_run_endpoint_schema_validation(self, client): """Various payload validation checks.""" # Missing messages entirely diff --git a/components/runners/ambient-runner/tests/test_events_endpoint.py b/components/runners/ambient-runner/tests/test_events_endpoint.py new file mode 100644 index 000000000..27f96d974 --- /dev/null +++ b/components/runners/ambient-runner/tests/test_events_endpoint.py @@ -0,0 +1,323 @@ +"""Unit tests for GET /events/{thread_id} and GET /events/{thread_id}/wait. + +Coverage targets: +- Queue registration before streaming begins +- Identity-safe cleanup (only removes if queue is the same object) +- Duplicate registration warning (second connect logs warning, replaces queue) +- 503 when bridge has no _active_streams attribute +- MESSAGES_SNAPSHOT filtered from output +- Stream closes on RUN_FINISHED / RUN_ERROR +- Text events emitted +- /wait: 404 on timeout, 503 when no attr, streams when queue registered, + MESSAGES_SNAPSHOT filtered in wait path +- Real async producer: background task puts events into the actual registered + queue while the endpoint is streaming, verifying end-to-end delivery +""" + +import asyncio +from unittest.mock import MagicMock + +import httpx +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from ag_ui.core import EventType + +from ambient_runner.endpoints.events import router + +from tests.conftest import ( + make_run_finished, + make_text_content, + make_text_start, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bridge(active_streams=None): + bridge = MagicMock() + bridge._active_streams = active_streams if active_streams is not None else {} + return bridge + + +def _make_app(bridge): + app = FastAPI() + app.state.bridge = bridge + app.include_router(router) + return app + + +def _make_client(bridge): + return TestClient(_make_app(bridge), raise_server_exceptions=False) + + +# --------------------------------------------------------------------------- +# GET /events/{thread_id} — 503 guard (sync, instant) +# --------------------------------------------------------------------------- + + +class TestEventsEndpointGuards: + def test_returns_503_when_bridge_has_no_active_streams(self): + bridge = MagicMock(spec=[]) + client = _make_client(bridge) + resp = client.get("/events/t-1") + assert resp.status_code == 503 + + def test_wait_returns_503_when_bridge_has_no_active_streams(self): + bridge = MagicMock(spec=[]) + client = _make_client(bridge) + resp = client.get("/events/t-1/wait") + assert resp.status_code == 503 + + def test_wait_returns_404_when_no_active_stream(self, monkeypatch): + monkeypatch.setenv("EVENTS_TAP_TIMEOUT_SEC", "0.05") + bridge = _make_bridge(active_streams={}) + client = _make_client(bridge) + resp = client.get("/events/missing-thread/wait") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# GET /events/{thread_id} — async producer tests (real queue delivery) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestEventsEndpointAsyncDelivery: + """Use httpx.AsyncClient with ASGI transport to test real async queue delivery. + + The background producer task polls active_streams until the endpoint has + registered its own queue, then feeds events into it — exactly mimicking + how GRPCSessionListener would deliver events in production. + """ + + async def _stream_events( + self, app, path: str, active_streams: dict, events_to_put: list + ) -> str: + """Open SSE stream and concurrently feed events into the endpoint's registered queue.""" + collected = [] + + async def producer(): + deadline = asyncio.get_event_loop().time() + 3.0 + thread_id = path.split("/events/")[-1].split("/")[0] + while asyncio.get_event_loop().time() < deadline: + q = active_streams.get(thread_id) + if q is not None: + for ev in events_to_put: + await q.put(ev) + return + await asyncio.sleep(0.005) + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + producer_task = asyncio.create_task(producer()) + async with client.stream("GET", path) as resp: + assert resp.status_code == 200 + async for chunk in resp.aiter_bytes(): + collected.append(chunk.decode()) + await producer_task + + return "".join(collected) + + async def test_run_finished_closes_stream(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + body = await self._stream_events( + app, + "/events/t-async-fin", + active_streams, + [make_text_start(), make_run_finished()], + ) + assert "RUN_FINISHED" in body + + async def test_run_error_closes_stream(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + from ag_ui.core import RunErrorEvent + + run_error = RunErrorEvent(message="test error", code="TEST") + + body = await self._stream_events( + app, "/events/t-async-err", active_streams, [run_error] + ) + assert "RUN_ERROR" in body + + async def test_messages_snapshot_filtered(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + snapshot = MagicMock() + snapshot.type = EventType.MESSAGES_SNAPSHOT + + body = await self._stream_events( + app, "/events/t-async-snap", active_streams, [snapshot, make_run_finished()] + ) + assert "MESSAGES_SNAPSHOT" not in body + assert "RUN_FINISHED" in body + + async def test_text_events_delivered(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + body = await self._stream_events( + app, + "/events/t-async-text", + active_streams, + [make_text_start(), make_text_content(), make_run_finished()], + ) + assert "TEXT_MESSAGE_START" in body + assert "TEXT_MESSAGE_CONTENT" in body + + async def test_queue_removed_from_active_streams_after_stream_closes(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + await self._stream_events( + app, "/events/t-async-cleanup", active_streams, [make_run_finished()] + ) + assert "t-async-cleanup" not in active_streams + + async def test_identity_safe_cleanup_preserves_newer_queue(self): + """After stream closes, the endpoint must not remove a queue it didn't create.""" + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + app = _make_app(bridge) + + newer_queue = asyncio.Queue(maxsize=100) + + async def producer(): + deadline = asyncio.get_event_loop().time() + 3.0 + while asyncio.get_event_loop().time() < deadline: + q = active_streams.get("t-id-safe") + if q is not None: + active_streams["t-id-safe"] = newer_queue + await q.put(make_run_finished()) + return + await asyncio.sleep(0.005) + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + producer_task = asyncio.create_task(producer()) + async with client.stream("GET", "/events/t-id-safe") as resp: + async for _ in resp.aiter_bytes(): + pass + await producer_task + + assert active_streams.get("t-id-safe") is newer_queue + + +# --------------------------------------------------------------------------- +# GET /events/{thread_id}/wait — async variants +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestEventsWaitEndpointAsync: + async def test_streams_when_queue_pre_registered(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + + q: asyncio.Queue = asyncio.Queue(maxsize=100) + await q.put(make_text_start()) + await q.put(make_run_finished()) + active_streams["t-wait-async"] = q + + app = _make_app(bridge) + + collected = [] + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + async with client.stream("GET", "/events/t-wait-async/wait") as resp: + assert resp.status_code == 200 + async for chunk in resp.aiter_bytes(): + collected.append(chunk.decode()) + + body = "".join(collected) + assert "RUN_FINISHED" in body + + async def test_wait_messages_snapshot_filtered(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + + snapshot = MagicMock() + snapshot.type = EventType.MESSAGES_SNAPSHOT + + q: asyncio.Queue = asyncio.Queue(maxsize=100) + await q.put(snapshot) + await q.put(make_run_finished()) + active_streams["t-wait-filter"] = q + + app = _make_app(bridge) + + collected = [] + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + async with client.stream("GET", "/events/t-wait-filter/wait") as resp: + async for chunk in resp.aiter_bytes(): + collected.append(chunk.decode()) + + body = "".join(collected) + assert "MESSAGES_SNAPSHOT" not in body + assert "RUN_FINISHED" in body + + async def test_wait_queue_removed_after_stream(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + + q: asyncio.Queue = asyncio.Queue(maxsize=100) + await q.put(make_run_finished()) + active_streams["t-wait-cleanup"] = q + + app = _make_app(bridge) + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + async with client.stream("GET", "/events/t-wait-cleanup/wait") as resp: + async for _ in resp.aiter_bytes(): + pass + + assert "t-wait-cleanup" not in active_streams + + async def test_wait_identity_safe_cleanup(self): + active_streams = {} + bridge = _make_bridge(active_streams=active_streams) + + old_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + await old_queue.put(make_run_finished()) + active_streams["t-wait-id"] = old_queue + + newer_queue = asyncio.Queue(maxsize=100) + + app = _make_app(bridge) + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + async with client.stream("GET", "/events/t-wait-id/wait") as resp: + active_streams["t-wait-id"] = newer_queue + async for _ in resp.aiter_bytes(): + pass + + assert active_streams.get("t-wait-id") is newer_queue diff --git a/components/runners/ambient-runner/tests/test_gemini_auth.py b/components/runners/ambient-runner/tests/test_gemini_auth.py index 24ab94e4e..14c07edc5 100644 --- a/components/runners/ambient-runner/tests/test_gemini_auth.py +++ b/components/runners/ambient-runner/tests/test_gemini_auth.py @@ -1,6 +1,5 @@ """Tests for Gemini CLI authentication setup.""" - import warnings import pytest diff --git a/components/runners/ambient-runner/tests/test_gemini_cli_adapter.py b/components/runners/ambient-runner/tests/test_gemini_cli_adapter.py old mode 100755 new mode 100644 index 7e736547a..fd88f0171 --- a/components/runners/ambient-runner/tests/test_gemini_cli_adapter.py +++ b/components/runners/ambient-runner/tests/test_gemini_cli_adapter.py @@ -9,7 +9,6 @@ InitEvent, MessageEvent, ResultEvent, - ThinkingEvent, ToolResultEvent, ToolUseEvent, parse_event, @@ -159,33 +158,6 @@ def test_result_event_error(self): assert evt.status == "error" assert evt.error["type"] == "FatalAuthenticationError" - def test_thinking_event(self): - line = json.dumps( - { - "type": "thinking", - "timestamp": "2025-01-01T00:00:01Z", - "content": "Let me reason about this...", - "delta": True, - } - ) - evt = parse_event(line) - assert isinstance(evt, ThinkingEvent) - assert evt.content == "Let me reason about this..." - assert evt.delta is True - - def test_thinking_event_non_delta(self): - line = json.dumps( - { - "type": "thinking", - "timestamp": "2025-01-01T00:00:01Z", - "content": "Full thought.", - } - ) - evt = parse_event(line) - assert isinstance(evt, ThinkingEvent) - assert evt.content == "Full thought." - assert evt.delta is False - def test_invalid_json_returns_none(self): evt = parse_event("not valid json") assert evt is None @@ -329,233 +301,3 @@ async def line_stream(): assert "TOOL_CALL_START" in types assert "TOOL_CALL_ARGS" in types assert "TOOL_CALL_END" in types - - @pytest.mark.asyncio - async def test_thinking_then_text_response(self): - """thinking + assistant message → REASONING events + TEXT events.""" - from ag_ui_gemini_cli.adapter import GeminiCLIAdapter - from ag_ui.core import RunAgentInput - - lines = [ - json.dumps( - { - "type": "init", - "timestamp": "T", - "session_id": "s1", - "model": "gemini-2.5-pro", - } - ), - json.dumps( - { - "type": "thinking", - "timestamp": "T", - "content": "Let me think about this...", - "delta": True, - } - ), - json.dumps( - { - "type": "thinking", - "timestamp": "T", - "content": " I should consider X.", - "delta": True, - } - ), - json.dumps( - { - "type": "message", - "timestamp": "T", - "role": "assistant", - "content": "Here is my answer.", - "delta": True, - } - ), - json.dumps( - { - "type": "result", - "timestamp": "T", - "status": "success", - "stats": {"total_tokens": 20}, - } - ), - ] - - async def line_stream(): - for line in lines: - yield line - - input_data = RunAgentInput( - thread_id="t1", - run_id="r1", - state={}, - messages=[], - tools=[], - context=[], - forwardedProps={}, - ) - adapter = GeminiCLIAdapter() - events = [] - async for event in adapter.run(input_data, line_stream=line_stream()): - events.append(event) - - types = [e.type if isinstance(e.type, str) else e.type for e in events] - assert "RUN_STARTED" in types - assert "REASONING_START" in types - assert "REASONING_MESSAGE_START" in types - assert "REASONING_MESSAGE_CONTENT" in types - assert "REASONING_MESSAGE_END" in types - assert "REASONING_END" in types - assert "TEXT_MESSAGE_START" in types - assert "TEXT_MESSAGE_CONTENT" in types - assert "RUN_FINISHED" in types - - # Reasoning events should come before text events - reasoning_start_idx = types.index("REASONING_START") - reasoning_end_idx = types.index("REASONING_END") - text_start_idx = types.index("TEXT_MESSAGE_START") - assert reasoning_start_idx < reasoning_end_idx < text_start_idx - - # Should have two REASONING_MESSAGE_CONTENT events (two delta chunks) - reasoning_content_events = [ - e for e in events if getattr(e, "type", None) == "REASONING_MESSAGE_CONTENT" - ] - assert len(reasoning_content_events) == 2 - assert reasoning_content_events[0].delta == "Let me think about this..." - assert reasoning_content_events[1].delta == " I should consider X." - - @pytest.mark.asyncio - async def test_non_delta_thinking(self): - """Non-delta thinking event opens and closes reasoning block immediately.""" - from ag_ui_gemini_cli.adapter import GeminiCLIAdapter - from ag_ui.core import RunAgentInput - - lines = [ - json.dumps( - { - "type": "init", - "timestamp": "T", - "session_id": "s1", - "model": "gemini-2.5-pro", - } - ), - json.dumps( - { - "type": "thinking", - "timestamp": "T", - "content": "Full reasoning block.", - } - ), - json.dumps( - { - "type": "message", - "timestamp": "T", - "role": "assistant", - "content": "Answer.", - "delta": True, - } - ), - json.dumps({"type": "result", "timestamp": "T", "status": "success"}), - ] - - async def line_stream(): - for line in lines: - yield line - - input_data = RunAgentInput( - thread_id="t1", - run_id="r1", - state={}, - messages=[], - tools=[], - context=[], - forwardedProps={}, - ) - adapter = GeminiCLIAdapter() - events = [] - async for event in adapter.run(input_data, line_stream=line_stream()): - events.append(event) - - types = [e.type if isinstance(e.type, str) else e.type for e in events] - # Reasoning should be fully closed before text starts - assert "REASONING_START" in types - assert "REASONING_MESSAGE_START" in types - assert "REASONING_MESSAGE_CONTENT" in types - assert "REASONING_MESSAGE_END" in types - assert "REASONING_END" in types - assert "TEXT_MESSAGE_START" in types - - # Non-delta: reasoning block closed immediately (not by the message handler) - reasoning_end_idx = types.index("REASONING_END") - text_start_idx = types.index("TEXT_MESSAGE_START") - assert reasoning_end_idx < text_start_idx - - @pytest.mark.asyncio - async def test_thinking_before_tool_call(self): - """Reasoning block is closed before tool call events are emitted.""" - from ag_ui_gemini_cli.adapter import GeminiCLIAdapter - from ag_ui.core import RunAgentInput - - lines = [ - json.dumps( - { - "type": "init", - "timestamp": "T", - "session_id": "s1", - "model": "gemini-2.5-pro", - } - ), - json.dumps( - { - "type": "thinking", - "timestamp": "T", - "content": "I need to read a file.", - "delta": True, - } - ), - json.dumps( - { - "type": "tool_use", - "timestamp": "T", - "tool_name": "read_file", - "tool_id": "t1", - "parameters": {"path": "a.py"}, - } - ), - json.dumps( - { - "type": "tool_result", - "timestamp": "T", - "tool_id": "t1", - "status": "success", - "output": "data", - } - ), - json.dumps({"type": "result", "timestamp": "T", "status": "success"}), - ] - - async def line_stream(): - for line in lines: - yield line - - input_data = RunAgentInput( - thread_id="t1", - run_id="r1", - state={}, - messages=[], - tools=[], - context=[], - forwardedProps={}, - ) - adapter = GeminiCLIAdapter() - events = [] - async for event in adapter.run(input_data, line_stream=line_stream()): - events.append(event) - - types = [e.type if isinstance(e.type, str) else e.type for e in events] - assert "REASONING_END" in types - assert "TOOL_CALL_START" in types - - # Reasoning must be closed before tool call starts - reasoning_end_idx = types.index("REASONING_END") - tool_start_idx = types.index("TOOL_CALL_START") - assert reasoning_end_idx < tool_start_idx diff --git a/components/runners/ambient-runner/tests/test_gemini_session.py b/components/runners/ambient-runner/tests/test_gemini_session.py index b0f2b457f..94fbaf260 100644 --- a/components/runners/ambient-runner/tests/test_gemini_session.py +++ b/components/runners/ambient-runner/tests/test_gemini_session.py @@ -78,46 +78,6 @@ async def _wait(): # ------------------------------------------------------------------ -class TestWorkerSubprocessConfig: - """Verify subprocess configuration passed to create_subprocess_exec.""" - - @pytest.mark.asyncio - async def test_stream_buffer_limit_raised(self): - """The subprocess must use a 10 MB buffer limit to handle large MCP tool responses.""" - worker = GeminiSessionWorker(model="gemini-2.5-flash", api_key="key1") - proc = _make_mock_process( - stdout_lines=[b'{"type":"result"}\n'], - stderr_lines=[], - ) - - with patch("asyncio.create_subprocess_exec", return_value=proc) as mock_exec: - async for _ in worker.query("test"): - pass - - call_kwargs = mock_exec.call_args[1] - assert call_kwargs["limit"] == 10 * 1024 * 1024 - - @pytest.mark.asyncio - async def test_large_stdout_line_does_not_crash(self): - """A real subprocess outputting >64 KB on one line must not raise ValueError. - - Without the 10 MB limit the default 64 KB StreamReader would blow up with: - ValueError: Separator is found, but chunk is longer than limit - """ - payload = "x" * 100_000 # 100 KB — well above the old 64 KB default - proc = await asyncio.create_subprocess_exec( - "python3", - "-c", - f'print("{payload}")', - stdout=asyncio.subprocess.PIPE, - limit=10 * 1024 * 1024, - ) - line = await proc.stdout.readline() - await proc.wait() - assert len(line) > 64 * 1024 - assert proc.returncode == 0 - - class TestWorkerCommandConstruction: """Verify the CLI command built by query().""" diff --git a/components/runners/ambient-runner/tests/test_google_drive_e2e.py b/components/runners/ambient-runner/tests/test_google_drive_e2e.py index a908d735d..dda8fdbb1 100644 --- a/components/runners/ambient-runner/tests/test_google_drive_e2e.py +++ b/components/runners/ambient-runner/tests/test_google_drive_e2e.py @@ -155,7 +155,9 @@ def path_factory(path_str): shutil.rmtree(secret_mount_dir.parent, ignore_errors=True) -@pytest.mark.skip(reason="Tool invocation test not yet implemented - requires Claude SDK integration") +@pytest.mark.skip( + reason="Tool invocation test not yet implemented - requires Claude SDK integration" +) @pytest.mark.skipif( not os.getenv("GOOGLE_DRIVE_E2E_TEST"), reason="Requires GOOGLE_DRIVE_E2E_TEST=true and real credentials", diff --git a/components/runners/ambient-runner/tests/test_grpc_client.py b/components/runners/ambient-runner/tests/test_grpc_client.py new file mode 100644 index 000000000..207cf5d9a --- /dev/null +++ b/components/runners/ambient-runner/tests/test_grpc_client.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import base64 +import json +import os +from unittest.mock import MagicMock, patch + +import pytest +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from ambient_runner._grpc_client import ( + _encrypt_session_id, + _fetch_token_from_cp, + _validate_cp_token_url, +) + + +def generate_keypair(): + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + public_pem = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode() + ) + return private_key, private_pem, public_pem + + +class TestValidateCPTokenURL: + def test_valid_http(self): + _validate_cp_token_url("http://ambient-control-plane.svc:8080/token") + + def test_valid_https(self): + _validate_cp_token_url("https://ambient-control-plane.svc:8080/token") + + def test_rejects_ftp(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("ftp://example.com/token") + + def test_rejects_file(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("file:///etc/passwd") + + def test_rejects_credentials_in_url(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("http://user:pass@example.com/token") + + def test_rejects_empty(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("") + + def test_rejects_no_host(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("http:///token") + + +class TestEncryptSessionID: + def test_produces_base64_ciphertext(self): + _, _, public_pem = generate_keypair() + result = _encrypt_session_id(public_pem, "my-session-id") + decoded = base64.b64decode(result) + assert len(decoded) > 0 + + def test_decryptable_with_private_key(self): + private_key, _, public_pem = generate_keypair() + session_id = "3BurtLWQNFMLp61XAGFKILYiHoN" + + ciphertext_b64 = _encrypt_session_id(public_pem, session_id) + ciphertext = base64.b64decode(ciphertext_b64) + + plaintext = private_key.decrypt( + ciphertext, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + assert plaintext.decode() == session_id + + def test_different_ciphertexts_for_same_input(self): + _, _, public_pem = generate_keypair() + result1 = _encrypt_session_id(public_pem, "session-abc") + result2 = _encrypt_session_id(public_pem, "session-abc") + assert result1 != result2 + + def test_invalid_public_key_raises(self): + with pytest.raises(Exception): + _encrypt_session_id("not a pem key", "session-id") + + +class TestFetchTokenFromCP: + def _mock_successful_response(self, token: str = "api-token-xyz"): + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"token": token}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + return mock_resp + + def test_success(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response("test-api-token") + + with patch("urllib.request.urlopen", return_value=mock_resp): + token = _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert token == "test-api-token" + + def test_sends_encrypted_bearer(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response() + captured_req = {} + + def fake_urlopen(req, timeout=None): + captured_req["req"] = req + return mock_resp + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + _fetch_token_from_cp("http://cp.svc:8080/token", public_pem, "session-abc") + + auth = captured_req["req"].get_header("Authorization") + assert auth.startswith("Bearer ") + b64_part = auth[len("Bearer ") :] + decoded = base64.b64decode(b64_part) + assert len(decoded) > 0 + + def test_retries_on_failure_then_succeeds(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response() + import urllib.error + + call_count = [0] + + def fake_urlopen(req, timeout=None): + call_count[0] += 1 + if call_count[0] < 3: + raise urllib.error.URLError("connection refused") + return mock_resp + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + with patch("time.sleep"): + token = _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert token == "api-token-xyz" + assert call_count[0] == 3 + + def test_raises_after_all_attempts_fail(self): + _, _, public_pem = generate_keypair() + import urllib.error + + with patch( + "urllib.request.urlopen", side_effect=urllib.error.URLError("refused") + ): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="CP token endpoint unreachable"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + def test_includes_http_error_body_in_exception(self): + _, _, public_pem = generate_keypair() + import urllib.error + + err_body = b"unauthorized: invalid token" + http_err = urllib.error.HTTPError( + url="http://cp.svc:8080/token", + code=401, + msg="Unauthorized", + hdrs=None, + fp=MagicMock(read=MagicMock(return_value=err_body)), + ) + + with patch("urllib.request.urlopen", side_effect=http_err): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="CP /token HTTP 401"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + def test_raises_on_missing_token_field(self): + _, _, public_pem = generate_keypair() + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"other": "field"}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="missing 'token' field"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + +class TestSetBotTokenIntegration: + def test_get_bot_token_returns_cp_fetched_token_after_successful_fetch(self): + import ambient_runner.platform.utils as utils + + utils._cp_fetched_token = "" + + _, _, public_pem = generate_keypair() + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps( + {"token": "oidc-token-for-api-calls"} + ).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + assert utils.get_bot_token() == "", ( + "get_bot_token() must be empty before any CP fetch" + ) + + with patch("urllib.request.urlopen", return_value=mock_resp): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert utils.get_bot_token() == "oidc-token-for-api-calls", ( + "get_bot_token() must return the CP-fetched token so backend API credential " + "calls are authenticated — regression for HTTP 401 on credential refresh" + ) + utils._cp_fetched_token = "" + + def test_fetch_from_cp_calls_set_bot_token(self): + from cryptography.hazmat.primitives.asymmetric import rsa as _rsa + + private_key = _rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_pem = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode() + ) + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps( + {"token": "oidc-api-token-abc"} + ).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + import ambient_runner.platform.utils as utils + + utils._cp_fetched_token = "" + + with patch("urllib.request.urlopen", return_value=mock_resp): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert utils.get_bot_token() == "oidc-api-token-abc" + utils._cp_fetched_token = "" + + +class TestFromEnvIntegration: + def test_uses_encrypted_session_id_when_cp_token_url_set(self): + _, _, public_pem = generate_keypair() + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"token": "env-token"}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + env = { + "AMBIENT_GRPC_URL": "localhost:9000", + "AMBIENT_CP_TOKEN_URL": "http://cp.svc:8080/token", + "AMBIENT_CP_TOKEN_PUBLIC_KEY": public_pem, + "SESSION_ID": "session-test-1234", + "AMBIENT_GRPC_USE_TLS": "false", + } + + with patch.dict(os.environ, env, clear=False): + with patch("urllib.request.urlopen", return_value=mock_resp): + from ambient_runner._grpc_client import AmbientGRPCClient + + client = AmbientGRPCClient.from_env() + + assert client._token == "env-token" + + def test_falls_back_to_bot_token_when_no_cp_url(self): + env = { + "AMBIENT_GRPC_URL": "localhost:9000", + "BOT_TOKEN": "static-bot-token", + "AMBIENT_GRPC_USE_TLS": "false", + } + env_without_cp = {k: v for k, v in env.items()} + + with patch.dict(os.environ, env_without_cp, clear=False): + with patch.dict(os.environ, {"AMBIENT_CP_TOKEN_URL": ""}, clear=False): + from ambient_runner._grpc_client import AmbientGRPCClient + + client = AmbientGRPCClient.from_env() + + assert client._token == "static-bot-token" diff --git a/components/runners/ambient-runner/tests/test_grpc_transport.py b/components/runners/ambient-runner/tests/test_grpc_transport.py new file mode 100644 index 000000000..dd5a07baf --- /dev/null +++ b/components/runners/ambient-runner/tests/test_grpc_transport.py @@ -0,0 +1,662 @@ +"""Tests for GRPCSessionListener and GRPCMessageWriter in grpc_transport.py. + +Coverage targets: +- GRPCSessionListener: ready event lifecycle, message type filtering, + fan-out to SSE queues, stop/cancel, bridge.run() called with correct RunnerInput, + exception in bridge.run() synthesizes RUN_ERROR, invalid JSON fallback +- GRPCMessageWriter: MESSAGES_SNAPSHOT accumulation, RUN_FINISHED/RUN_ERROR push, + non-terminal events ignored, push offloaded to executor (non-blocking), + push failure logged without re-raising +- _synthesize_run_error: feeds RUN_ERROR to SSE queue, schedules writer persist +""" + +import asyncio +import json +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from ag_ui.core import EventType + +from tests.conftest import ( + async_event_stream, + make_run_finished, + make_text_content, + make_text_start, +) + +from ambient_runner.bridges.claude.grpc_transport import ( + GRPCMessageWriter, + GRPCSessionListener, + _synthesize_run_error, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_session_message(event_type: str, payload: str, seq: int = 1): + msg = MagicMock() + msg.event_type = event_type + msg.payload = payload + msg.seq = seq + msg.session_id = "sess-1" + return msg + + +def _make_runner_payload( + thread_id: str = "t-1", + run_id: str = "r-1", + content: str = "hello", +) -> str: + return json.dumps( + { + "threadId": thread_id, + "runId": run_id, + "messages": [{"id": str(uuid.uuid4()), "role": "user", "content": content}], + } + ) + + +def _make_grpc_client(messages=None): + """Return a mock AmbientGRPCClient whose watch() yields the given messages.""" + client = MagicMock() + client.session_messages.watch.return_value = iter(messages or []) + client.session_messages.push.return_value = MagicMock(seq=1) + return client + + +def _make_bridge(active_streams=None): + bridge = MagicMock() + bridge._active_streams = active_streams if active_streams is not None else {} + return bridge + + +# --------------------------------------------------------------------------- +# GRPCSessionListener — ready event +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGRPCSessionListenerReady: + async def test_ready_set_after_watch_opens(self): + client = _make_grpc_client(messages=[]) + bridge = _make_bridge() + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + assert listener.ready.is_set() + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_ready_not_set_before_watch(self): + bridge = _make_bridge() + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + assert not listener.ready.is_set() + + async def test_ready_set_on_successful_watch(self): + client = _make_grpc_client(messages=[]) + bridge = _make_bridge() + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + assert listener.ready.is_set() + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +# --------------------------------------------------------------------------- +# GRPCSessionListener — message filtering +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGRPCSessionListenerFiltering: + async def test_non_user_messages_do_not_trigger_run(self): + msgs = [ + _make_session_message("assistant", '{"foo": "bar"}', seq=1), + _make_session_message("system", "{}", seq=2), + ] + client = _make_grpc_client(messages=msgs) + bridge = _make_bridge() + bridge.run = AsyncMock(return_value=async_event_stream([])) + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.1) + bridge.run.assert_not_called() + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_user_message_triggers_bridge_run(self): + payload = _make_runner_payload( + thread_id="t-1", run_id="r-1", content="do the thing" + ) + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + bridge = _make_bridge() + + run_inputs = [] + + async def fake_run(input_data): + run_inputs.append(input_data) + yield make_text_start() + yield make_run_finished() + + bridge.run = fake_run + bridge._active_streams = {} + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + assert len(run_inputs) == 1 + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_user_message_run_called_with_correct_thread_id(self): + """bridge.run() must receive input_data with thread_id from the message payload.""" + payload = _make_runner_payload( + thread_id="t-specific", run_id="r-42", content="hello" + ) + msgs = [_make_session_message("user", payload, seq=5)] + client = _make_grpc_client(messages=msgs) + bridge = _make_bridge() + + run_inputs = [] + + async def fake_run(input_data): + run_inputs.append(input_data) + yield make_run_finished() + + bridge.run = fake_run + bridge._active_streams = {} + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + assert len(run_inputs) == 1 + assert run_inputs[0].thread_id == "t-specific" + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_invalid_json_payload_uses_raw_as_content_fallback(self): + """Invalid JSON in payload falls back to creating a message with raw payload as content.""" + msgs = [_make_session_message("user", "not-json", seq=1)] + client = _make_grpc_client(messages=msgs) + bridge = _make_bridge() + + run_inputs = [] + + async def fake_run(input_data): + run_inputs.append(input_data) + yield make_run_finished() + + bridge.run = fake_run + bridge._active_streams = {} + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + assert len(run_inputs) == 1 + msgs_in_input = run_inputs[0].messages + assert len(msgs_in_input) == 1 + msg = msgs_in_input[0] + role = msg["role"] if isinstance(msg, dict) else getattr(msg, "role", None) + content = ( + msg["content"] + if isinstance(msg, dict) + else getattr(msg, "content", None) + ) + assert role == "user" + assert content == "not-json" + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_bridge_run_exception_synthesizes_run_error_to_sse_queue(self): + """If bridge.run() raises, a RUN_ERROR event must be fed to the SSE tap queue.""" + payload = _make_runner_payload(thread_id="t-err", run_id="r-err") + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + + tap_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + active_streams = {"t-err": tap_queue} + bridge = _make_bridge(active_streams=active_streams) + + async def exploding_run(input_data): + raise RuntimeError("boom") + yield # make it a generator + + bridge.run = exploding_run + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.5) + + run_error_events = [] + while not tap_queue.empty(): + ev = tap_queue.get_nowait() + raw = getattr(ev, "type", None) + ev_str = raw.value if hasattr(raw, "value") else str(raw) + if "RUN_ERROR" in ev_str: + run_error_events.append(ev) + assert len(run_error_events) >= 1 + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +# --------------------------------------------------------------------------- +# GRPCSessionListener — fan-out +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGRPCSessionListenerFanOut: + async def test_events_fed_to_active_streams_queue(self): + payload = _make_runner_payload(thread_id="t-fanout", run_id="r-1") + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + + received_events = [] + tap_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + + bridge = _make_bridge(active_streams={"t-fanout": tap_queue}) + events = [make_text_start(), make_text_content(), make_run_finished()] + + async def fake_run(input_data): + for e in events: + yield e + + bridge.run = fake_run + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + while not tap_queue.empty(): + received_events.append(tap_queue.get_nowait()) + assert len(received_events) == len(events) + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_no_active_stream_fan_out_skipped_silently(self): + payload = _make_runner_payload(thread_id="t-1", run_id="r-1") + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + bridge = _make_bridge(active_streams={}) + + events = [make_text_start(), make_run_finished()] + + async def fake_run(input_data): + for e in events: + yield e + + bridge.run = fake_run + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_full_queue_drops_event_without_raising(self): + payload = _make_runner_payload(thread_id="t-full", run_id="r-1") + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + + full_queue: asyncio.Queue = asyncio.Queue(maxsize=1) + full_queue.put_nowait(make_text_start()) + + bridge = _make_bridge(active_streams={"t-full": full_queue}) + events = [make_text_start(), make_run_finished()] + + async def fake_run(input_data): + for e in events: + yield e + + bridge.run = fake_run + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def test_active_streams_entry_removed_after_turn(self): + payload = _make_runner_payload(thread_id="t-cleanup", run_id="r-1") + msgs = [_make_session_message("user", payload, seq=1)] + client = _make_grpc_client(messages=msgs) + + tap_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + active_streams = {"t-cleanup": tap_queue} + bridge = _make_bridge(active_streams=active_streams) + + async def fake_run(input_data): + yield make_run_finished() + + bridge.run = fake_run + + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + + task = asyncio.create_task(listener._listen_loop()) + try: + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await asyncio.sleep(0.3) + assert "t-cleanup" not in active_streams + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +# --------------------------------------------------------------------------- +# GRPCSessionListener — stop +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGRPCSessionListenerStop: + async def test_stop_cancels_task(self): + client = _make_grpc_client(messages=[]) + bridge = _make_bridge() + listener = GRPCSessionListener( + bridge=bridge, session_id="s-1", grpc_url="localhost:9000" + ) + listener._grpc_client = client + listener._task = asyncio.create_task(listener._listen_loop()) + + await asyncio.wait_for(listener.ready.wait(), timeout=2.0) + await listener.stop() + assert listener._task.done() + + +# --------------------------------------------------------------------------- +# GRPCMessageWriter — consume +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGRPCMessageWriterConsume: + def _make_messages_snapshot(self, messages): + event = MagicMock() + event.type = EventType.MESSAGES_SNAPSHOT + event.messages = messages + return event + + def _make_run_finished_event(self): + event = MagicMock() + event.type = EventType.RUN_FINISHED + return event + + def _make_run_error_event(self): + event = MagicMock() + event.type = EventType.RUN_ERROR + return event + + def _make_text_event(self): + event = MagicMock() + event.type = EventType.TEXT_MESSAGE_CONTENT + return event + + def _writer(self): + client = MagicMock() + client.session_messages.push.return_value = MagicMock(seq=1) + return GRPCMessageWriter( + session_id="s-1", run_id="r-1", grpc_client=client + ), client + + async def test_messages_snapshot_accumulated(self): + writer, _ = self._writer() + msg = MagicMock() + msg.model_dump.return_value = {"role": "assistant", "content": "hi"} + snap = self._make_messages_snapshot([msg]) + await writer.consume(snap) + assert len(writer._accumulated_messages) == 1 + + async def test_run_finished_pushes_completed(self): + writer, client = self._writer() + msg = MagicMock() + msg.model_dump.return_value = {"role": "assistant", "content": "done"} + snap = self._make_messages_snapshot([msg]) + await writer.consume(snap) + await writer.consume(self._make_run_finished_event()) + + client.session_messages.push.assert_called_once() + call = client.session_messages.push.call_args + assert call[0][0] == "s-1" + assert call[1]["event_type"] == "assistant" + assert call[1]["payload"] == "done" + + async def test_run_error_pushes_error_status(self): + writer, client = self._writer() + await writer.consume(self._make_run_error_event()) + + client.session_messages.push.assert_called_once() + assert client.session_messages.push.call_args[1]["event_type"] == "assistant" + + async def test_non_terminal_events_do_not_push(self): + writer, client = self._writer() + await writer.consume(self._make_text_event()) + client.session_messages.push.assert_not_called() + + async def test_unknown_event_type_ignored(self): + writer, client = self._writer() + event = MagicMock() + event.type = None + await writer.consume(event) + client.session_messages.push.assert_not_called() + + async def test_latest_snapshot_replaces_previous(self): + writer, client = self._writer() + msg1 = MagicMock() + msg1.model_dump.return_value = {"role": "assistant", "content": "first"} + msg2 = MagicMock() + msg2.model_dump.return_value = {"role": "assistant", "content": "second"} + + await writer.consume(self._make_messages_snapshot([msg1])) + await writer.consume(self._make_messages_snapshot([msg2])) + await writer.consume(self._make_run_finished_event()) + + assert client.session_messages.push.call_args[1]["payload"] == "second" + + async def test_no_grpc_client_write_skipped(self): + writer = GRPCMessageWriter(session_id="s-1", run_id="r-1", grpc_client=None) + event = MagicMock() + event.type = EventType.RUN_FINISHED + await writer.consume(event) + + async def test_push_includes_correct_session_id(self): + writer, client = self._writer() + await writer.consume(self._make_run_finished_event()) + assert client.session_messages.push.call_args[0][0] == "s-1" + assert client.session_messages.push.call_args[1]["event_type"] == "assistant" + + async def test_push_offloaded_to_executor_not_inline(self): + """The synchronous gRPC push must be run via run_in_executor, not inline.""" + writer, client = self._writer() + + executor_calls = [] + real_loop = asyncio.get_event_loop() + original = real_loop.run_in_executor + + async def capturing(executor, fn, *args): + executor_calls.append(fn) + return await original(executor, fn, *args) + + with patch.object(real_loop, "run_in_executor", side_effect=capturing): + await writer.consume(self._make_run_finished_event()) + + assert len(executor_calls) == 1 + + async def test_push_failure_does_not_raise(self): + """If the gRPC push in executor fails, _write_message must not re-raise.""" + writer, client = self._writer() + client.session_messages.push.side_effect = RuntimeError("rpc unavailable") + + await writer.consume(self._make_run_finished_event()) + + +# --------------------------------------------------------------------------- +# _synthesize_run_error — standalone helper +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestSynthesizeRunError: + async def test_feeds_run_error_event_to_sse_queue(self): + """_synthesize_run_error must put a RUN_ERROR event into the SSE tap queue.""" + tap_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + active_streams = {"t-synth": tap_queue} + + client = MagicMock() + client.session_messages.push.return_value = MagicMock(seq=1) + writer = GRPCMessageWriter(session_id="s-1", run_id="r-1", grpc_client=client) + + _synthesize_run_error("t-synth", "test error", active_streams, writer) + + await asyncio.sleep(0.1) + + assert not tap_queue.empty() + ev = tap_queue.get_nowait() + raw = getattr(ev, "type", None) + ev_str = raw.value if hasattr(raw, "value") else str(raw) + assert "RUN_ERROR" in ev_str + + async def test_no_sse_queue_does_not_raise(self): + """When no SSE queue is registered, _synthesize_run_error must not raise.""" + active_streams: dict = {} + + client = MagicMock() + client.session_messages.push.return_value = MagicMock(seq=1) + writer = GRPCMessageWriter(session_id="s-1", run_id="r-1", grpc_client=client) + + _synthesize_run_error("t-missing", "test error", active_streams, writer) + await asyncio.sleep(0.1) + + async def test_schedules_writer_error_persist(self): + """_synthesize_run_error must schedule writer._write_message(status='error').""" + tap_queue: asyncio.Queue = asyncio.Queue(maxsize=100) + active_streams = {"t-wr": tap_queue} + + client = MagicMock() + client.session_messages.push.return_value = MagicMock(seq=1) + writer = GRPCMessageWriter(session_id="s-1", run_id="r-1", grpc_client=client) + + write_calls = [] + original_write = writer._write_message + + async def tracking_write(status): + write_calls.append(status) + return await original_write(status) + + writer._write_message = tracking_write + + _synthesize_run_error("t-wr", "boom", active_streams, writer) + await asyncio.sleep(0.2) + + assert "error" in write_calls diff --git a/components/runners/ambient-runner/tests/test_grpc_writer.py b/components/runners/ambient-runner/tests/test_grpc_writer.py new file mode 100644 index 000000000..ab4234e19 --- /dev/null +++ b/components/runners/ambient-runner/tests/test_grpc_writer.py @@ -0,0 +1,213 @@ +""" +Tests for GRPCMessageWriter. + +Covers the event-accumulation and push logic, including edge cases that +caused production failures: + + - assistant message with content=None (tool-call-only turns where Claude + emits no text; MESSAGES_SNAPSHOT contains {"role":"assistant","content":null}) + - no assistant message in snapshot at all + - normal happy-path with text content + - RUN_ERROR triggers push with status="error" +""" + +import pytest +from unittest.mock import MagicMock + +from ag_ui.core import EventType, RunFinishedEvent, RunErrorEvent + +from ambient_runner.bridges.claude.grpc_transport import GRPCMessageWriter + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_writer(grpc_client=None): + if grpc_client is None: + grpc_client = MagicMock() + return GRPCMessageWriter( + session_id="sess-1", + run_id="run-1", + grpc_client=grpc_client, + ) + + +def make_snapshot_event(messages: list) -> MagicMock: + evt = MagicMock() + evt.type = EventType.MESSAGES_SNAPSHOT + evt.messages = [_dict_to_mock(m) for m in messages] + return evt + + +def _dict_to_mock(d: dict) -> MagicMock: + m = MagicMock() + m.model_dump.return_value = d + return m + + +def make_run_finished() -> RunFinishedEvent: + return RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id="t-1", + run_id="run-1", + ) + + +def make_run_error() -> RunErrorEvent: + return RunErrorEvent( + type=EventType.RUN_ERROR, + message="something went wrong", + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_write_message_none_content_raises_without_fix(): + """ + Regression: assistant message with content=None causes TypeError: object of + type 'NoneType' has no len(). + + Snapshot has role=assistant but content=null (tool-call-only turn). + Before the fix this crashes; after the fix it should push an empty string. + """ + client = MagicMock() + writer = make_writer(client) + + snapshot = make_snapshot_event( + [ + {"role": "user", "content": "i'm sending you a message"}, + {"role": "assistant", "content": None}, + ] + ) + await writer.consume(snapshot) + await writer.consume(make_run_finished()) + + client.session_messages.push.assert_called_once_with( + "sess-1", + event_type="assistant", + payload="", + ) + + +@pytest.mark.asyncio +async def test_write_message_no_assistant_in_snapshot(): + """No assistant message at all — push should still succeed with empty payload.""" + client = MagicMock() + writer = make_writer(client) + + snapshot = make_snapshot_event( + [ + {"role": "user", "content": "hello"}, + ] + ) + await writer.consume(snapshot) + await writer.consume(make_run_finished()) + + client.session_messages.push.assert_called_once_with( + "sess-1", + event_type="assistant", + payload="", + ) + + +@pytest.mark.asyncio +async def test_write_message_happy_path(): + """Normal turn: assistant has text content — push uses that text.""" + client = MagicMock() + writer = make_writer(client) + + snapshot = make_snapshot_event( + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello! I'm here."}, + ] + ) + await writer.consume(snapshot) + await writer.consume(make_run_finished()) + + client.session_messages.push.assert_called_once_with( + "sess-1", + event_type="assistant", + payload="Hello! I'm here.", + ) + + +@pytest.mark.asyncio +async def test_run_error_pushes_with_error_status(): + """RUN_ERROR triggers push with status='error'.""" + client = MagicMock() + writer = make_writer(client) + + snapshot = make_snapshot_event( + [ + {"role": "assistant", "content": "partial"}, + ] + ) + await writer.consume(snapshot) + await writer.consume(make_run_error()) + + client.session_messages.push.assert_called_once_with( + "sess-1", + event_type="assistant", + payload="partial", + ) + + +@pytest.mark.asyncio +async def test_latest_snapshot_wins(): + """Multiple MESSAGES_SNAPSHOT events — only the last one counts.""" + client = MagicMock() + writer = make_writer(client) + + await writer.consume( + make_snapshot_event( + [ + {"role": "assistant", "content": "stale"}, + ] + ) + ) + await writer.consume( + make_snapshot_event( + [ + {"role": "assistant", "content": "fresh"}, + ] + ) + ) + await writer.consume(make_run_finished()) + + client.session_messages.push.assert_called_once_with( + "sess-1", + event_type="assistant", + payload="fresh", + ) + + +@pytest.mark.asyncio +async def test_no_push_without_run_finished(): + """Events before RUN_FINISHED/RUN_ERROR don't trigger a push.""" + client = MagicMock() + writer = make_writer(client) + + await writer.consume( + make_snapshot_event( + [ + {"role": "assistant", "content": "something"}, + ] + ) + ) + + client.session_messages.push.assert_not_called() + + +@pytest.mark.asyncio +async def test_no_grpc_client_does_not_raise(): + """Writer with no gRPC client logs a warning and returns cleanly.""" + writer = make_writer(grpc_client=None) + await writer.consume(make_snapshot_event([{"role": "assistant", "content": "x"}])) + await writer.consume(make_run_finished()) diff --git a/components/runners/ambient-runner/tests/test_shared_session_credentials.py b/components/runners/ambient-runner/tests/test_shared_session_credentials.py index e56078236..4c8d573f6 100755 --- a/components/runners/ambient-runner/tests/test_shared_session_credentials.py +++ b/components/runners/ambient-runner/tests/test_shared_session_credentials.py @@ -325,6 +325,7 @@ async def test_sends_current_user_header_when_set(self): _CredentialHandler.response_body = {"token": "gh-token-for-userB"} _CredentialHandler.captured_headers = {} + cred_id = "cred-github-001" try: with patch.dict( os.environ, @@ -332,6 +333,7 @@ async def test_sends_current_user_header_when_set(self): "BACKEND_API_URL": f"http://127.0.0.1:{port}/api", "PROJECT_NAME": "test-project", "BOT_TOKEN": "fake-bot-token", + "CREDENTIAL_IDS": json.dumps({"github": cred_id}), }, ): ctx = _make_context( @@ -367,6 +369,7 @@ async def test_omits_current_user_header_when_not_set(self): _CredentialHandler.response_body = {"token": "owner-token"} _CredentialHandler.captured_headers = {} + cred_id = "cred-github-002" try: with patch.dict( os.environ, @@ -374,6 +377,7 @@ async def test_omits_current_user_header_when_not_set(self): "BACKEND_API_URL": f"http://127.0.0.1:{port}/api", "PROJECT_NAME": "test-project", "BOT_TOKEN": "fake-bot-token", + "CREDENTIAL_IDS": json.dumps({"github": cred_id}), }, ): ctx = _make_context() # no current_user_id @@ -394,6 +398,7 @@ async def test_returns_empty_when_backend_unavailable(self): { "BACKEND_API_URL": "http://127.0.0.1:1/api", "PROJECT_NAME": "test-project", + "CREDENTIAL_IDS": json.dumps({"github": "cred-unreachable"}), }, ): ctx = _make_context(current_user_id="user-123") @@ -414,23 +419,25 @@ async def test_credentials_populated_then_cleared(self): server = HTTPServer(("127.0.0.1", 0), _CredentialHandler) port = server.server_address[1] - # We need to handle multiple requests (github, google, jira, gitlab) + # We need to handle multiple requests (github, google, jira, gitlab, coderabbit, gerrit, kubeconfig) call_count = [0] responses = { - "/github": {"token": "gh-tok"}, - "/google": {}, - "/jira": { - "apiToken": "jira-tok", + "cred-gh": {"token": "gh-tok"}, + "cred-google": {}, + "cred-jira": { + "token": "jira-tok", "url": "https://jira.example.com", "email": "j@example.com", }, - "/gitlab": {"token": "gl-tok"}, + "cred-gl": {"token": "gl-tok"}, + "cred-coderabbit": {}, + "cred-gerrit": [], + "cred-kubeconfig": {}, } class MultiHandler(BaseHTTPRequestHandler): def do_GET(self): call_count[0] += 1 - # Extract credential type from URL path for key, resp in responses.items(): if key in self.path: self.send_response(200) @@ -447,10 +454,22 @@ def log_message(self, format, *args): server = HTTPServer(("127.0.0.1", 0), MultiHandler) port = server.server_address[1] thread = Thread( - target=lambda: [server.handle_request() for _ in range(4)], daemon=True + target=lambda: [server.handle_request() for _ in range(7)], daemon=True ) thread.start() + credential_ids = json.dumps( + { + "github": "cred-gh", + "google": "cred-google", + "jira": "cred-jira", + "gitlab": "cred-gl", + "coderabbit": "cred-coderabbit", + "gerrit": "cred-gerrit", + "kubeconfig": "cred-kubeconfig", + } + ) + try: with patch.dict( os.environ, @@ -458,6 +477,7 @@ def log_message(self, format, *args): "BACKEND_API_URL": f"http://127.0.0.1:{port}/api", "PROJECT_NAME": "test-project", "BOT_TOKEN": "fake-bot", + "CREDENTIAL_IDS": credential_ids, }, ): ctx = _make_context(current_user_id="userB") @@ -509,9 +529,9 @@ async def test_raises_permission_error_on_401_without_caller_token( monkeypatch.setenv("BACKEND_API_URL", "http://backend.svc.cluster.local/api") monkeypatch.setenv("PROJECT_NAME", "test-project") monkeypatch.setenv("BOT_TOKEN", "bot-token") + monkeypatch.setenv("CREDENTIAL_IDS", json.dumps({"github": "cred-gh-001"})) ctx = _make_context(session_id="sess-1") - # No caller token — uses BOT_TOKEN directly err = HTTPError( "http://backend.svc.cluster.local/api/...", @@ -534,6 +554,7 @@ async def test_raises_permission_error_on_403_without_caller_token( monkeypatch.setenv("BACKEND_API_URL", "http://backend.svc.cluster.local/api") monkeypatch.setenv("PROJECT_NAME", "test-project") monkeypatch.setenv("BOT_TOKEN", "bot-token") + monkeypatch.setenv("CREDENTIAL_IDS", json.dumps({"google": "cred-google-001"})) ctx = _make_context(session_id="sess-1") @@ -558,6 +579,7 @@ async def test_raises_permission_error_when_caller_and_bot_both_fail( monkeypatch.setenv("BACKEND_API_URL", "http://backend.svc.cluster.local/api") monkeypatch.setenv("PROJECT_NAME", "test-project") monkeypatch.setenv("BOT_TOKEN", "bot-token") + monkeypatch.setenv("CREDENTIAL_IDS", json.dumps({"github": "cred-gh-002"})) ctx = _make_context(session_id="sess-1", current_user_id="user@example.com") ctx.caller_token = "Bearer expired-caller-token" @@ -577,6 +599,7 @@ async def test_does_not_raise_on_non_auth_http_errors(self, monkeypatch): """_fetch_credential returns {} for non-auth HTTP errors (404, 500, etc.).""" monkeypatch.setenv("BACKEND_API_URL", "http://backend.svc.cluster.local/api") monkeypatch.setenv("PROJECT_NAME", "test-project") + monkeypatch.setenv("CREDENTIAL_IDS", json.dumps({"github": "cred-gh-003"})) ctx = _make_context(session_id="sess-1") @@ -594,6 +617,7 @@ async def test_caller_token_fallback_succeeds_when_bot_token_works( monkeypatch.setenv("BACKEND_API_URL", "http://backend.svc.cluster.local/api") monkeypatch.setenv("PROJECT_NAME", "test-project") monkeypatch.setenv("BOT_TOKEN", "valid-bot-token") + monkeypatch.setenv("CREDENTIAL_IDS", json.dumps({"github": "cred-gh-004"})) ctx = _make_context(session_id="sess-1", current_user_id="user@example.com") ctx.caller_token = "Bearer expired-caller-token" @@ -772,6 +796,85 @@ async def test_includes_mcp_diagnostics_on_auth_warning(self): assert "google-workspace" in text +# --------------------------------------------------------------------------- +# _fetch_credential — CP OIDC token used when no caller token (regression) +# --------------------------------------------------------------------------- + + +class TestFetchCredentialBotToken: + @pytest.mark.asyncio + async def test_uses_bot_token_when_no_caller_token(self): + """_fetch_credential sends the CP OIDC token when caller_token is absent.""" + server = HTTPServer(("127.0.0.1", 0), _CredentialHandler) + port = server.server_address[1] + thread = Thread(target=server.handle_request, daemon=True) + thread.start() + + _CredentialHandler.response_body = {"token": "gh-tok-via-oidc"} + _CredentialHandler.captured_headers = {} + + cp_oidc_token = "cp-oidc-jwt-token" + + try: + with ( + patch.dict( + os.environ, + { + "BACKEND_API_URL": f"http://127.0.0.1:{port}/api", + "PROJECT_NAME": "test-project", + "CREDENTIAL_IDS": json.dumps({"github": "cred-gh-bot-test"}), + }, + ), + patch( + "ambient_runner.platform.auth.get_bot_token", + return_value=cp_oidc_token, + ), + ): + ctx = _make_context() + result = await _fetch_credential(ctx, "github") + + assert result.get("token") == "gh-tok-via-oidc" + assert _CredentialHandler.captured_headers.get("Authorization") == ( + f"Bearer {cp_oidc_token}" + ) + finally: + server.server_close() + thread.join(timeout=2) + + @pytest.mark.asyncio + async def test_bot_token_used_when_no_caller_token(self): + """CP OIDC token (get_bot_token) is used when caller_token is absent.""" + called_with = {} + + def fake_urlopen(req, timeout=None): + called_with["auth"] = req.get_header("Authorization") + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"token": "ok"}).encode() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + return mock_resp + + with ( + patch.dict( + os.environ, + { + "BACKEND_API_URL": "http://backend.svc.cluster.local/api", + "PROJECT_NAME": "test-project", + "CREDENTIAL_IDS": json.dumps({"github": "cred-gh-pref"}), + }, + ), + patch("urllib.request.urlopen", side_effect=fake_urlopen), + patch( + "ambient_runner.platform.auth.get_bot_token", + return_value="cp-oidc-token", + ), + ): + ctx = _make_context() + await _fetch_credential(ctx, "github") + + assert called_with.get("auth") == "Bearer cp-oidc-token" + + # --------------------------------------------------------------------------- # gh CLI wrapper — ensures gh picks up refreshed tokens (issue #1135) # --------------------------------------------------------------------------- @@ -786,12 +889,15 @@ class TestGhWrapper: always uses the freshest token. """ - def _cleanup(self): - """Remove wrapper artifacts created during tests.""" + @staticmethod + def _get_auth_mod(): import ambient_runner.platform.auth as _auth_mod + return _auth_mod + def _cleanup(self): + """Remove wrapper artifacts created during tests.""" + _auth_mod = self._get_auth_mod() _auth_mod._gh_wrapper_installed = False - # Read the current module-level paths (not the stale import-time values) wrapper_path = _auth_mod._GH_WRAPPER_PATH wrapper_dir_path = _auth_mod._GH_WRAPPER_DIR if wrapper_path: @@ -805,10 +911,9 @@ def _cleanup(self): def test_install_creates_executable_wrapper(self): """install_gh_wrapper creates an executable script at _GH_WRAPPER_PATH.""" - import ambient_runner.platform.auth as _auth_mod - self._cleanup() try: + _auth_mod = self._get_auth_mod() install_gh_wrapper() wrapper = Path(_auth_mod._GH_WRAPPER_PATH) assert wrapper.exists(), "Wrapper script should be created" @@ -821,12 +926,10 @@ def test_install_creates_executable_wrapper(self): def test_install_prepends_to_path(self): """install_gh_wrapper prepends the wrapper dir to PATH.""" - import ambient_runner.platform.auth as _auth_mod - self._cleanup() + _auth_mod = self._get_auth_mod() original_path = os.environ.get("PATH", "") try: - # Remove wrapper dir from PATH if present (use current module value) current_dir = _auth_mod._GH_WRAPPER_DIR parts = [p for p in original_path.split(":") if p != current_dir] os.environ["PATH"] = ":".join(parts) @@ -844,9 +947,8 @@ def test_install_prepends_to_path(self): def test_install_is_idempotent(self): """Calling install_gh_wrapper twice does not duplicate PATH entries.""" - import ambient_runner.platform.auth as _auth_mod - self._cleanup() + _auth_mod = self._get_auth_mod() original_path = os.environ.get("PATH", "") try: current_dir = _auth_mod._GH_WRAPPER_DIR @@ -869,6 +971,7 @@ async def test_populate_installs_gh_wrapper(self): """populate_runtime_credentials installs the gh wrapper.""" self._cleanup() try: + _auth_mod = self._get_auth_mod() with patch("ambient_runner.platform.auth._fetch_credential") as mock_fetch: async def _creds(ctx, ctype): @@ -884,8 +987,6 @@ async def _creds(ctx, ctype): ctx = _make_context() await populate_runtime_credentials(ctx) - import ambient_runner.platform.auth as _auth_mod - wrapper = Path(_auth_mod._GH_WRAPPER_PATH) assert wrapper.exists(), ( "populate_runtime_credentials should install gh wrapper" diff --git a/docs/internal/proposals/alpha-to-main-migration.md b/docs/internal/proposals/alpha-to-main-migration.md index 453d28428..48a894c4b 100644 --- a/docs/internal/proposals/alpha-to-main-migration.md +++ b/docs/internal/proposals/alpha-to-main-migration.md @@ -35,165 +35,58 @@ removes this file. ## PR Checklist -### PR 1 — Migration Plan + Docs, Skills, and Claude Config -> Zero code risk. Safe to merge immediately. Combines the migration plan with all -> non-code documentation, skills, and config changes. +### PR 1 — Migration Plan + Docs, Skills, and Claude Config ✅ Merged +> Merged as PR #1354. - [x] Analyze alpha→main delta and component dependencies - [x] Write migration plan (`docs/internal/proposals/alpha-to-main-migration.md`) - [x] Fix alpha→main branch references in `.claude/skills/devflow/SKILL.md` -- [ ] `.claude/skills/ambient/SKILL.md` -- [ ] `.claude/skills/ambient-pr-test/SKILL.md` -- [ ] `.claude/skills/grpc-dev/SKILL.md` -- [ ] `.claude/settings.json` updates -- [ ] `CLAUDE.md` project-level updates -- [ ] `docs/internal/design/` — specs and guides: - - [ ] `README.md` - - [ ] `ambient-model.guide.md` - - [ ] `ambient-model.spec.md` - - [ ] `control-plane.guide.md` - - [ ] `control-plane.spec.md` - - [ ] `frontend-backend-migration-plan.md` - - [ ] `frontend-to-api-status.md` - - [ ] `mcp-server.guide.md` - - [ ] `mcp-server.spec.md` - - [ ] `runner.spec.md` -- [ ] `docs/internal/developer/local-development/openshift.md` -- [ ] Update this checklist -- [ ] Merge to main - -### PR 2 — ambient-api-server: OpenAPI Specs, Generated Client, New Kinds -> Foundation PR. All other components depend on its API surface. - -- [ ] New OpenAPI specs: - - [ ] `openapi/openapi.credentials.yaml` - - [ ] `openapi/openapi.inbox.yaml` - - [ ] `openapi/openapi.sessions.yaml` additions - - [ ] `openapi/openapi.agents.yaml` changes - - [ ] `openapi/openapi.projects.yaml` changes - - [ ] `openapi/openapi.yaml` root spec updates -- [ ] Generated Go client (`pkg/api/openapi/`) — regenerate, do not hand-edit -- [ ] Proto definitions: - - [ ] `proto/ambient/v1/inbox.proto` - - [ ] `proto/ambient/v1/sessions.proto` changes -- [ ] Generated proto Go code (`pkg/api/grpc/ambient/v1/`) -- [ ] New plugins: - - [ ] `plugins/credentials/` (model, handler, service, dao, presenter, migration, tests) - - [ ] `plugins/inbox/` (model, handler, service, dao, presenter, migration, tests) -- [ ] Service layer additions: - - [ ] `plugins/sessions/service.go` — `ActiveByAgentID`, `Start`, `Stop` - - [ ] `plugins/sessions/presenter.go` — new fields -- [ ] `cmd/ambient-api-server/main.go` — new plugin imports -- [ ] `Makefile` updates -- [ ] Verify: `make test` passes -- [ ] Verify: `golangci-lint run` passes -- [ ] Update this checklist -- [ ] Merge to main - -### PR 3 — ambient-sdk: Go + TypeScript Client Updates -> Depends on: PR 2 (api-server API surface) - -- [ ] Go SDK updates matching new API surface -- [ ] TypeScript SDK updates: - - [ ] `ts-sdk/src/session_message_api.ts` - - [ ] `ts-sdk/src/user.ts` - - [ ] `ts-sdk/src/user_api.ts` - - [ ] New/updated type definitions -- [ ] Removal of deprecated types (`ProjectAgent`, `ProjectDocument`, `Ignite`) - - [ ] Verify no main-branch code references removed types before merging -- [ ] New integration tests (`ts-sdk/tests/integration.test.ts`) -- [ ] Verify: SDK builds and tests pass -- [ ] Update this checklist -- [ ] Merge to main - -### PR 4 — ambient-control-plane: New Component -> Depends on: PR 2 (api-server API surface). Purely additive (0 deletions). - -- [ ] Core control plane: - - [ ] `cmd/` — entry point - - [ ] `internal/config/` — configuration - - [ ] `internal/watcher/watcher.go` — resource watcher - - [ ] `internal/handlers/` — reconciliation handlers -- [ ] Token server: - - [ ] `internal/tokenserver/server.go` - - [ ] `internal/tokenserver/handler.go` - - [ ] `internal/tokenserver/handler_test.go` -- [ ] Credential injection into runner pods -- [ ] Namespace provisioning -- [ ] Proxy environment forwarding (`HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`) -- [ ] RSA keypair auth for runner token endpoint -- [ ] Exponential backoff retry in informer -- [ ] Verify: `go vet ./...` and `golangci-lint run` pass -- [ ] Update this checklist -- [ ] Merge to main - -### PR 5 — ambient-cli: acpctl Enhancements -> Depends on: PR 2 (api-server), PR 3 (SDK) - -- [ ] `acpctl login --use-auth-code` — OAuth2 + PKCE flow (RHOAIENG-55812) -- [ ] Agent commands: - - [ ] `acpctl agent start` with `--all/-A` flag - - [ ] `acpctl agent stop` with `--all/-A` flag -- [ ] `acpctl session send -f` — follow mode -- [ ] Credential CLI verbs -- [ ] `pkg/config/config.go` — new config fields -- [ ] `pkg/config/token.go` + `token_test.go` — token management -- [ ] `pkg/connection/connection.go` — connection updates -- [ ] Security fixes and idempotent start -- [ ] Verify: `go vet ./...` and `golangci-lint run` pass -- [ ] Update this checklist -- [ ] Merge to main - -### PR 6 — runners: Auth, Credentials, gRPC, and SSE -> Depends on: PR 2 (api-server), PR 4 (control-plane token endpoint) - -- [ ] Credential system: - - [ ] `platform/auth.py` — `_fetch_credential` with caller/bot token fallback - - [ ] `platform/auth.py` — `populate_runtime_credentials`, `clear_runtime_credentials` - - [ ] `platform/auth.py` — `gh` CLI wrapper (`install_gh_wrapper`) - - [ ] `platform/auth.py` — `sanitize_user_context` -- [ ] `platform/utils.py` — `get_active_integrations` and helpers -- [ ] `platform/context.py` — `RunnerContext` updates -- [ ] `platform/prompts.py` — prompt additions -- [ ] `middleware/secret_redaction.py` — redaction changes -- [ ] `observability.py` — observability updates -- [ ] `tools/backend_api.py` — API tool updates -- [ ] gRPC transport and delta buffer -- [ ] SSE flush per chunk, unbounded tap queue -- [ ] CP OIDC token for backend credential fetches -- [ ] Tests: - - [ ] `tests/test_shared_session_credentials.py` - - [ ] `tests/test_bridge_claude.py` - - [ ] `tests/test_app_initial_prompt.py` - - [ ] `tests/test_events_endpoint.py` - - [ ] `tests/test_grpc_client.py` - - [ ] `tests/test_grpc_transport.py` - - [ ] `tests/test_grpc_writer.py` -- [ ] `pyproject.toml` dependency additions -- [ ] Verify: `python -m pytest tests/` passes -- [ ] Update this checklist -- [ ] Merge to main - -### PR 7 — manifests: Kustomize Overlays and RBAC -> Depends on: All component PRs (references their images/deployments) - -- [ ] `mpp-openshift` overlay: - - [ ] NetworkPolicy for runner→CP token server - - [ ] gRPC Route for ambient-api-server - - [ ] CP token Service + `CP_RUNTIME_NAMESPACE` + `CP_TOKEN_URL` - - [ ] MCP sidecar image wiring - - [ ] RBAC (`ambient-control-plane-rbac.yaml`) - - [ ] RoleBinding namespace fixes via Kustomize replacement - - [ ] Explicit namespaces per resource - - [ ] Remove hardcoded preprod hostname from route -- [ ] `production` overlay: - - [ ] `ambient-api-server-env-patch.yaml` - - [ ] `ambient-api-server-route.yaml` - - [ ] `kustomization.yaml` updates (components, patches, images) -- [ ] `openshift-dev` overlay: - - [ ] `kustomization.yaml` - - [ ] `ambient-api-server-env-patch.yaml` -- [ ] Verify: `kustomize build` succeeds for each overlay +- [x] Merge to main + +### PR 2 — ambient-api-server: OpenAPI Specs, Generated Client, New Kinds ✅ Merged +> Merged as PR #1368. + +- [x] All items completed +- [x] Merge to main + +### PR 3 — ambient-sdk: Go + TypeScript Client Updates ✅ Merged +> Merged as PR #1373. Adapted alpha code to main's generated SDK signatures +> (project-scoped basePath, URL naming). Deferred gRPC MessageWatcher/InboxWatcher +> (proto types not in published module). + +- [x] All items completed +- [x] Merge to main + +### PR 4 — ambient-control-plane: New Component ✅ Merged +> Merged as PR #1375. Purely additive, 21 new files. Fixed Credential API +> signature (removed projectID param) and gofmt. + +- [x] All items completed +- [x] Merge to main + +### PR 5 — ambient-cli: acpctl Enhancements ✅ Merged +> Merged as PR #1377. Adapted credential API calls (removed projectID), +> fixed Url→URL naming, replaced WatchSessionMessages with WatchMessages, +> removed agent.Version references, removed dead code for golangci-lint. + +- [x] All items completed +- [x] Merge to main + +### PR 6 — runners + manifests: Auth, Credentials, gRPC, SSE, and Kustomize Overlays +> Depends on: PR 2 (api-server), PR 4 (control-plane token endpoint). +> Combined with PR 7 (manifests) since all component PRs have landed. + +- [x] Credential system (`platform/auth.py`) +- [x] gRPC transport and delta buffer +- [x] SSE flush per chunk, unbounded tap queue +- [x] CP OIDC token for backend credential fetches +- [x] All runner tests pass (707 passed) +- [x] Ruff lint and format clean +- [x] `pyproject.toml` — added `cryptography`, `grpcio`, `protobuf` +- [x] `mpp-openshift` overlay (NetworkPolicy, gRPC Route, CP token, RBAC) +- [x] `production` overlay updates +- [x] `openshift-dev` overlay +- [x] `kustomize build` succeeds for all overlays - [ ] Update this checklist - [ ] Merge to main diff --git a/e2e/cypress/e2e/screenshots.cy.ts b/e2e/cypress/e2e/screenshots.cy.ts index bdf2d3671..25cfdb8c9 100644 --- a/e2e/cypress/e2e/screenshots.cy.ts +++ b/e2e/cypress/e2e/screenshots.cy.ts @@ -123,7 +123,7 @@ describe('Documentation Screenshots', () => { function setTheme(theme: 'light' | 'dark'): void { const label = theme === 'dark' ? 'Switch to dark theme' : 'Switch to light theme' - cy.get('button[aria-label="Toggle theme"]').first().click({ force: true }) + cy.get('button[aria-label="Toggle theme"]', { timeout: 10000 }).first().should('be.visible').click({ force: true }) // 10 s timeout: slow CI environments can take > 5 s for Radix to mount the dropdown content cy.get(`[aria-label="${label}"]`, { timeout: 10000 }).first().click({ force: true }) if (theme === 'dark') {