Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions cdk/lambda/gateway-interceptor/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@
We decode it here (without verification) only to extract the tenantId claim.
"""

import os
import json
import logging
import base64
import uuid
from typing import Any, Dict, Optional

import boto3

logger = logging.getLogger()
logger.setLevel(logging.INFO)

REGION = os.getenv("AWS_REGION", "us-west-2")

ABAC_ROLE_ARN = os.getenv("ABAC_ROLE_ARN")

# Tools that require tenant-scoped ABAC credentials
ABAC_TOOLS = {"query_logs", "LogSearchTarget___search_logs"}


def _decode_jwt_payload(token: str) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -58,23 +68,66 @@ def _extract_tenant_id(headers: Dict[str, str]) -> Optional[str]:
return None


def _assume_tenant_role(tenant_id: str) -> Optional[Dict[str, str]]:
"""
Assume the ABAC role with tenant-scoped session tags.

Returns temporary credentials dict with AccessKeyId, SecretAccessKey,
and SessionToken, or None if ABAC is not configured.
"""
try:
sts = boto3.client("sts", region_name=REGION)
response = sts.assume_role(
RoleArn=ABAC_ROLE_ARN,
RoleSessionName=f"tenant-{tenant_id}-session",
Tags=[{"Key": "tenant_id", "Value": tenant_id}],
)
creds = response["Credentials"]
return {
"access_key_id": creds["AccessKeyId"],
"secret_access_key": creds["SecretAccessKey"],
"session_token": creds["SessionToken"],
}
except Exception as e:
logger.error(f"Failed to assume ABAC role for tenant {tenant_id}: {e}")
raise
return None


def _inject_tenant_id_into_tool_call(body: Dict[str, Any], tenant_id: str) -> Dict[str, Any]:
"""
Inject tenant_id into MCP tools/call arguments.

For tools/call requests, the arguments are in body.params.arguments.
For other MCP methods (tools/list, initialize, etc.), pass through unchanged.

For tools that require ABAC (listed in ABAC_TOOLS), also assumes a
tenant-scoped IAM role and injects temporary credentials so the
downstream tool handler doesn't need to manage tenant isolation itself.
"""
method = body.get("method", "")
if method != "tools/call":
return body

params = body.get("params", {})
arguments = params.get("arguments", {})
tool_name = params.get("name", "")

# Inject tenant_id (overwrites any agent-supplied value)
arguments["tenant_id"] = tenant_id

# LAB 2: Uncomment block below to inject ABAC credentials for specific tools
# logger.info(json.dumps({
# "tool_name": tool_name,
# "ABAC_TOOLS": list(ABAC_TOOLS),
# "tool_in_abac": tool_name in ABAC_TOOLS,
# "ABAC_ROLE_ARN": ABAC_ROLE_ARN,
# }))
# if tool_name in ABAC_TOOLS:
# creds = _assume_tenant_role(tenant_id)
# if creds:
# arguments["tenant_credentials"] = creds

# Return modified body
modified = body.copy()
modified["params"] = {**params, "arguments": arguments}
Expand Down
29 changes: 10 additions & 19 deletions cdk/lambda/log-mcp-handler/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
ATHENA_WORKGROUP = os.getenv("ATHENA_WORKGROUP", "primary")
ATHENA_OUTPUT = os.getenv("ATHENA_OUTPUT", "s3://your-athena-query-output/")

# LAB 2: Uncomment for ABAC
#ABAC_ROLE_ARN = os.getenv("ABAC_ROLE_ARN")

def _wait(qid: str, athena_client, timeout_s: int = 180):
start = time.time()
while time.time() - start < timeout_s:
Expand Down Expand Up @@ -103,22 +100,16 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
sql = append_tenant_filter(user_sql, tenant_id)
logger.info(json.dumps({"tenant_id": event.get('tenant_id'), "sql": sql}))

# LAB 2: Uncomment block below and comment out the line after it
# sts = boto3.client("sts", region_name=REGION)
# response = sts.assume_role(
# RoleArn=ABAC_ROLE_ARN,
# RoleSessionName=f"tenant-{event.get('tenant_id')}-session",
# Tags=[{'Key': 'tenant_id', 'Value': event.get('tenant_id')}]
# )
# creds = response['Credentials']
# athena_client = boto3.client(
# "athena", region_name=REGION,
# aws_access_key_id=creds['AccessKeyId'],
# aws_secret_access_key=creds['SecretAccessKey'],
# aws_session_token=creds['SessionToken']
# )
# LAB 2: Comment this line
athena_client = boto3.client("athena", region_name=REGION)
tenant_creds = event.get("tenant_credentials")
if tenant_creds:
athena_client = boto3.client(
"athena", region_name=REGION,
aws_access_key_id=tenant_creds["access_key_id"],
aws_secret_access_key=tenant_creds["secret_access_key"],
aws_session_token=tenant_creds["session_token"],
)
else:
athena_client = boto3.client("athena", region_name=REGION)

db = event.get("database") or ATHENA_DB
rows = _exec(sql, athena_client, database=db)
Expand Down
36 changes: 16 additions & 20 deletions cdk/lib/agentcore-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,22 @@ export class AgentCoreStack extends cdk.NestedStack {
props.athenaWorkgroup,
logMcpHandlerRole
);
const basicRole = this.createLogMcpHandlerBasicRole();
const interceptorLambda = this.createInterceptorLambda();
const kbMcpLambda = this.createKbMcpHandlerLambda(kbId);

// UNCOMMENT:
// const basicRole = this.createLogMcpHandlerBasicRole();
// const abacRole = this.createAbacRole(basicRole, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup);
// const abacRole = this.createAbacRole(interceptorLambda.role!, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup);
// const logMcpLambda = this.createLogMcpHandlerLambda(s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup, basicRole, abacRole);
const kbMcpLambda = this.createKbMcpHandlerLambda(kbId);
// // Set ABAC_ROLE_ARN env var and grant STS permissions on the interceptor
// interceptorLambda.addEnvironment("ABAC_ROLE_ARN", abacRole.roleArn);
// interceptorLambda.addToRolePolicy(
// new iam.PolicyStatement({
// actions: ["sts:AssumeRole", "sts:TagSession"],
// resources: [abacRole.roleArn],
// })
// );

// Create Gateway Interceptor Lambda
const interceptorLambda = this.createInterceptorLambda();

// Create IAM role for AgentCore Gateway
const agentCoreRole = this.createAgentCoreRole();
Expand Down Expand Up @@ -171,31 +179,19 @@ export class AgentCoreStack extends cdk.NestedStack {
"service-role/AWSLambdaBasicExecutionRole"
),
],
inlinePolicies: {
AssumeAbacRole: new iam.PolicyDocument({
statements: [
new iam.PolicyStatement({
actions: ["sts:AssumeRole", "sts:TagSession"],
resources: [
`arn:aws:iam::${this.account}:role/*LogMcpHandlerAbacRole*`,
],
}),
],
}),
},
});
}

private createAbacRole(
basicRole: iam.Role,
interceptorRole: iam.IRole,
s3BucketName: string,
athenaResultsBucketName: string,
athenaDatabase: string,
athenaTable: string,
athenaWorkgroup: string
): iam.Role {
const role = new iam.Role(this, "LogMcpHandlerAbacRole", {
assumedBy: new iam.ArnPrincipal(basicRole.roleArn),
assumedBy: new iam.ArnPrincipal(interceptorRole.roleArn),
inlinePolicies: {
TenantSpecificAccess: new iam.PolicyDocument({
statements: [
Expand Down Expand Up @@ -267,7 +263,7 @@ export class AgentCoreStack extends cdk.NestedStack {
Statement: [
{
Effect: "Allow",
Principal: { AWS: basicRole.roleArn },
Principal: { AWS: interceptorRole.roleArn },
Action: ["sts:AssumeRole", "sts:TagSession"],
Condition: {
StringLike: {
Expand Down
2 changes: 1 addition & 1 deletion scripts/agentcore-provisioning/deploy-agentcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
logger.handlers = [handler]
logger.propagate = False

region = os.environ.get("AWS_REGION", "AWS_DEFAULT_REGION")
region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-west-2"))

def get_stack_outputs():
"""
Expand Down