diff --git a/cdk/lambda/gateway-interceptor/handler.py b/cdk/lambda/gateway-interceptor/handler.py index bd2ea4c..a20749d 100644 --- a/cdk/lambda/gateway-interceptor/handler.py +++ b/cdk/lambda/gateway-interceptor/handler.py @@ -12,15 +12,25 @@ We decode it here (without verification) only to extract the tenantId claim. """ +import os import json import logging import base64 import uuid from typing import Any, Dict, Optional +import boto3 + logger = logging.getLogger() logger.setLevel(logging.INFO) +REGION = os.getenv("AWS_REGION", "us-west-2") + +ABAC_ROLE_ARN = os.getenv("ABAC_ROLE_ARN") + +# Tools that require tenant-scoped ABAC credentials +ABAC_TOOLS = {"query_logs", "LogSearchTarget___search_logs"} + def _decode_jwt_payload(token: str) -> Dict[str, Any]: """ @@ -58,12 +68,42 @@ def _extract_tenant_id(headers: Dict[str, str]) -> Optional[str]: return None +def _assume_tenant_role(tenant_id: str) -> Optional[Dict[str, str]]: + """ + Assume the ABAC role with tenant-scoped session tags. + + Returns temporary credentials dict with AccessKeyId, SecretAccessKey, + and SessionToken, or None if ABAC is not configured. + """ + try: + sts = boto3.client("sts", region_name=REGION) + response = sts.assume_role( + RoleArn=ABAC_ROLE_ARN, + RoleSessionName=f"tenant-{tenant_id}-session", + Tags=[{"Key": "tenant_id", "Value": tenant_id}], + ) + creds = response["Credentials"] + return { + "access_key_id": creds["AccessKeyId"], + "secret_access_key": creds["SecretAccessKey"], + "session_token": creds["SessionToken"], + } + except Exception as e: + logger.error(f"Failed to assume ABAC role for tenant {tenant_id}: {e}") + raise + return None + + def _inject_tenant_id_into_tool_call(body: Dict[str, Any], tenant_id: str) -> Dict[str, Any]: """ Inject tenant_id into MCP tools/call arguments. For tools/call requests, the arguments are in body.params.arguments. For other MCP methods (tools/list, initialize, etc.), pass through unchanged. + + For tools that require ABAC (listed in ABAC_TOOLS), also assumes a + tenant-scoped IAM role and injects temporary credentials so the + downstream tool handler doesn't need to manage tenant isolation itself. """ method = body.get("method", "") if method != "tools/call": @@ -71,10 +111,23 @@ def _inject_tenant_id_into_tool_call(body: Dict[str, Any], tenant_id: str) -> Di params = body.get("params", {}) arguments = params.get("arguments", {}) + tool_name = params.get("name", "") # Inject tenant_id (overwrites any agent-supplied value) arguments["tenant_id"] = tenant_id + # LAB 2: Uncomment block below to inject ABAC credentials for specific tools + # logger.info(json.dumps({ + # "tool_name": tool_name, + # "ABAC_TOOLS": list(ABAC_TOOLS), + # "tool_in_abac": tool_name in ABAC_TOOLS, + # "ABAC_ROLE_ARN": ABAC_ROLE_ARN, + # })) + # if tool_name in ABAC_TOOLS: + # creds = _assume_tenant_role(tenant_id) + # if creds: + # arguments["tenant_credentials"] = creds + # Return modified body modified = body.copy() modified["params"] = {**params, "arguments": arguments} diff --git a/cdk/lambda/log-mcp-handler/handler.py b/cdk/lambda/log-mcp-handler/handler.py index 5256ee3..68186f7 100644 --- a/cdk/lambda/log-mcp-handler/handler.py +++ b/cdk/lambda/log-mcp-handler/handler.py @@ -19,9 +19,6 @@ ATHENA_WORKGROUP = os.getenv("ATHENA_WORKGROUP", "primary") ATHENA_OUTPUT = os.getenv("ATHENA_OUTPUT", "s3://your-athena-query-output/") -# LAB 2: Uncomment for ABAC -#ABAC_ROLE_ARN = os.getenv("ABAC_ROLE_ARN") - def _wait(qid: str, athena_client, timeout_s: int = 180): start = time.time() while time.time() - start < timeout_s: @@ -103,22 +100,16 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: sql = append_tenant_filter(user_sql, tenant_id) logger.info(json.dumps({"tenant_id": event.get('tenant_id'), "sql": sql})) - # LAB 2: Uncomment block below and comment out the line after it - # sts = boto3.client("sts", region_name=REGION) - # response = sts.assume_role( - # RoleArn=ABAC_ROLE_ARN, - # RoleSessionName=f"tenant-{event.get('tenant_id')}-session", - # Tags=[{'Key': 'tenant_id', 'Value': event.get('tenant_id')}] - # ) - # creds = response['Credentials'] - # athena_client = boto3.client( - # "athena", region_name=REGION, - # aws_access_key_id=creds['AccessKeyId'], - # aws_secret_access_key=creds['SecretAccessKey'], - # aws_session_token=creds['SessionToken'] - # ) - # LAB 2: Comment this line - athena_client = boto3.client("athena", region_name=REGION) + tenant_creds = event.get("tenant_credentials") + if tenant_creds: + athena_client = boto3.client( + "athena", region_name=REGION, + aws_access_key_id=tenant_creds["access_key_id"], + aws_secret_access_key=tenant_creds["secret_access_key"], + aws_session_token=tenant_creds["session_token"], + ) + else: + athena_client = boto3.client("athena", region_name=REGION) db = event.get("database") or ATHENA_DB rows = _exec(sql, athena_client, database=db) diff --git a/cdk/lib/agentcore-stack.ts b/cdk/lib/agentcore-stack.ts index 2d97c9f..5377bec 100644 --- a/cdk/lib/agentcore-stack.ts +++ b/cdk/lib/agentcore-stack.ts @@ -78,14 +78,22 @@ export class AgentCoreStack extends cdk.NestedStack { props.athenaWorkgroup, logMcpHandlerRole ); + const basicRole = this.createLogMcpHandlerBasicRole(); + const interceptorLambda = this.createInterceptorLambda(); + const kbMcpLambda = this.createKbMcpHandlerLambda(kbId); + // UNCOMMENT: - // const basicRole = this.createLogMcpHandlerBasicRole(); - // const abacRole = this.createAbacRole(basicRole, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup); + // const abacRole = this.createAbacRole(interceptorLambda.role!, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup); // const logMcpLambda = this.createLogMcpHandlerLambda(s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup, basicRole, abacRole); - const kbMcpLambda = this.createKbMcpHandlerLambda(kbId); + // // Set ABAC_ROLE_ARN env var and grant STS permissions on the interceptor + // interceptorLambda.addEnvironment("ABAC_ROLE_ARN", abacRole.roleArn); + // interceptorLambda.addToRolePolicy( + // new iam.PolicyStatement({ + // actions: ["sts:AssumeRole", "sts:TagSession"], + // resources: [abacRole.roleArn], + // }) + // ); - // Create Gateway Interceptor Lambda - const interceptorLambda = this.createInterceptorLambda(); // Create IAM role for AgentCore Gateway const agentCoreRole = this.createAgentCoreRole(); @@ -171,23 +179,11 @@ export class AgentCoreStack extends cdk.NestedStack { "service-role/AWSLambdaBasicExecutionRole" ), ], - inlinePolicies: { - AssumeAbacRole: new iam.PolicyDocument({ - statements: [ - new iam.PolicyStatement({ - actions: ["sts:AssumeRole", "sts:TagSession"], - resources: [ - `arn:aws:iam::${this.account}:role/*LogMcpHandlerAbacRole*`, - ], - }), - ], - }), - }, }); } private createAbacRole( - basicRole: iam.Role, + interceptorRole: iam.IRole, s3BucketName: string, athenaResultsBucketName: string, athenaDatabase: string, @@ -195,7 +191,7 @@ export class AgentCoreStack extends cdk.NestedStack { athenaWorkgroup: string ): iam.Role { const role = new iam.Role(this, "LogMcpHandlerAbacRole", { - assumedBy: new iam.ArnPrincipal(basicRole.roleArn), + assumedBy: new iam.ArnPrincipal(interceptorRole.roleArn), inlinePolicies: { TenantSpecificAccess: new iam.PolicyDocument({ statements: [ @@ -267,7 +263,7 @@ export class AgentCoreStack extends cdk.NestedStack { Statement: [ { Effect: "Allow", - Principal: { AWS: basicRole.roleArn }, + Principal: { AWS: interceptorRole.roleArn }, Action: ["sts:AssumeRole", "sts:TagSession"], Condition: { StringLike: { diff --git a/scripts/agentcore-provisioning/deploy-agentcore.py b/scripts/agentcore-provisioning/deploy-agentcore.py index d01d359..7b74b51 100755 --- a/scripts/agentcore-provisioning/deploy-agentcore.py +++ b/scripts/agentcore-provisioning/deploy-agentcore.py @@ -24,7 +24,7 @@ logger.handlers = [handler] logger.propagate = False -region = os.environ.get("AWS_REGION", "AWS_DEFAULT_REGION") +region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-west-2")) def get_stack_outputs(): """