diff --git a/agent/kb_agent.py b/agent/kb_agent.py index 9d27c2a..0ebd55c 100644 --- a/agent/kb_agent.py +++ b/agent/kb_agent.py @@ -34,29 +34,21 @@ def kb_agent_tool(query: str, top_k: int = 5) -> str: if not kb_gateway_url: raise ValueError("KB_GATEWAY_URL environment variable is not set") + decoded = ops_context.decode_jwt_claims(access_token) + tenant_id = decoded.get("tenantId") + streamable_http_mcp_client = MCPClient( lambda: streamablehttp_client( kb_gateway_url, headers={ "Authorization": f"{access_token}", + "X-Tenant-ID": tenant_id }, ) ) - decoded = ops_context.decode_jwt_claims(access_token) - tenant_id = decoded.get("tenantId") - with streamable_http_mcp_client: - tools = [] - - for t in streamable_http_mcp_client.list_tools_sync(): - if t.tool_name != "x_amz_bedrock_agentcore_search": - tool = wrapped_tool.WrappedTool(t) - tool.bind_param("tenant_id", tenant_id) - - tools.append(tool) - else: - tools.append(t) + tools = list(streamable_http_mcp_client.list_tools_sync()) kb_agent = Agent( name="kb_agent", diff --git a/agent/log_agent.py b/agent/log_agent.py index 1003263..1d0c80e 100644 --- a/agent/log_agent.py +++ b/agent/log_agent.py @@ -32,30 +32,22 @@ def log_agent_tool(query: str) -> str: if not log_gateway_url: raise ValueError("LOG_GATEWAY_URL environment variable is not set") + decoded = ops_context.decode_jwt_claims(access_token) + tenant_id = decoded.get("tenantId") + streamable_http_mcp_client = MCPClient( lambda: streamablehttp_client( log_gateway_url, headers={ "Authorization": f"{access_token}", + "X-Tenant-ID": tenant_id }, ) ) - decoded = ops_context.decode_jwt_claims(access_token) - tenant_id = decoded.get("tenantId") - with streamable_http_mcp_client: - tools = [] + tools = list(streamable_http_mcp_client.list_tools_sync()) - for t in streamable_http_mcp_client.list_tools_sync(): - if t.tool_name != "x_amz_bedrock_agentcore_search": - tool = wrapped_tool.WrappedTool(t) - tool.bind_param("tenant_id", tenant_id) - - tools.append(tool) - else: - tools.append(t) - system_prompt = """You are a log analysis agent that searches tenant application logs using Amazon Athena-compatible SQL queries. TENANT_LOGS SCHEMA: diff --git a/cdk/lambda/gateway-interceptor/handler.py b/cdk/lambda/gateway-interceptor/handler.py new file mode 100644 index 0000000..ad4aab9 --- /dev/null +++ b/cdk/lambda/gateway-interceptor/handler.py @@ -0,0 +1,34 @@ +import json +import uuid + +def lambda_handler(event, context): + # Extract the gateway request + mcp_data = event.get('mcp', {}) + gateway_request = mcp_data.get('gatewayRequest', {}) + headers = gateway_request.get('headers', {}) + body = gateway_request.get('body', {}) + extended_body = body + + auth_header = headers.get('authorization', '') or headers.get('Authorization', '') + + # Extract Tenant Id from custom header for propagation + tenant_id = headers.get('X-Tenant-ID', '') + + if "params" in extended_body and "arguments" in extended_body["params"]: + # Add custom header to arguments for downstream processing + extended_body["params"]["arguments"]["tenant_id"] = tenant_id + + # Return transformed request without passing the original authorization header + response = { + "interceptorOutputVersion": "1.0", + "mcp": { + "transformedGatewayRequest": { + "headers": { + "Accept": "application/json", + "Content-Type": "application/json" + }, + "body": extended_body + } + } + } + return response \ No newline at end of file diff --git a/cdk/lib/agentcore-stack.ts b/cdk/lib/agentcore-stack.ts index 5f44982..78be016 100644 --- a/cdk/lib/agentcore-stack.ts +++ b/cdk/lib/agentcore-stack.ts @@ -25,6 +25,7 @@ export class AgentCoreStack extends cdk.NestedStack { public readonly agentCoreRoleArn: string; public readonly logMcpLambdaArn: string; public readonly kbMcpLambdaArn: string; + public readonly gatewayInterceptorArn: string; constructor(scope: Construct, id: string, props: AgentCoreStackProps) { super(scope, id, props); @@ -82,12 +83,14 @@ export class AgentCoreStack extends cdk.NestedStack { // const abacRole = this.createAbacRole(basicRole, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup); // const logMcpLambda = this.createLogMcpHandlerLambda(s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup, basicRole, abacRole); const kbMcpLambda = this.createKbMcpHandlerLambda(kbId); + const gatewayInterceptorLambda = this.createGatewayInterceptorLambda(); // Create IAM role for AgentCore Gateway const agentCoreRole = this.createAgentCoreRole(); logMcpLambda.grantInvoke(agentCoreRole); kbMcpLambda.grantInvoke(agentCoreRole); - + gatewayInterceptorLambda.grantInvoke(agentCoreRole); + // Store outputs as public properties this.userPoolId = userPool.userPoolId; this.userClientId = userClient.userPoolClientId; @@ -96,6 +99,7 @@ export class AgentCoreStack extends cdk.NestedStack { this.agentCoreRoleArn = agentCoreRole.roleArn; this.logMcpLambdaArn = logMcpLambda.functionArn; this.kbMcpLambdaArn = kbMcpLambda.functionArn; + this.gatewayInterceptorArn = gatewayInterceptorLambda.functionArn; // Create CloudWatch log groups for gateway logs const logGatewayLogGroup = this.createGatewayLogGroup("LogGateway"); @@ -150,6 +154,11 @@ export class AgentCoreStack extends cdk.NestedStack { value: athenaResultsBucket.bucketName, description: "The name of the Athena results bucket", }); + + new cdk.CfnOutput(this, "GatewayInterceptorArn", { + value: gatewayInterceptorLambda.functionArn, + description: "The ARN of the Gateway Interceptor Lambda", + }); } private createLogMcpHandlerBasicRole(): iam.Role { @@ -445,6 +454,38 @@ export class AgentCoreStack extends cdk.NestedStack { }); } + private createGatewayInterceptorRole(): iam.Role { + return new iam.Role(this, "GatewayInterceptorRole", { + roleName: "AgentCore-Gateway-Interceptor-Role", + assumedBy: new iam.ServicePrincipal("lambda.amazonaws.com"), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName( + "service-role/AWSLambdaBasicExecutionRole" + ), + ], + description: "IAM role for Gateway Interceptor Lambda", + }); + } + + private createGatewayInterceptorLambda(): lambda.Function { + const interceptorRole = this.createGatewayInterceptorRole(); + return new lambda.Function(this, "GatewayInterceptor", { + functionName: "AgentCore-Gateway-Interceptor", + description: "Request Gateway interceptor to extract tenant_id from JWT", + runtime: lambda.Runtime.PYTHON_3_12, + handler: "handler.lambda_handler", + code: lambda.Code.fromAsset( + path.join(__dirname, "../lambda/gateway-interceptor") + ), + role:interceptorRole, + timeout: cdk.Duration.seconds(30), + memorySize: 256, + environment: { + LOG_LEVEL: "INFO" + } + }); + } + private createResourceServer( userPool: cdk.aws_cognito.UserPool, identifier: string diff --git a/scripts/agentcore-provisioning/deploy-agentcore.py b/scripts/agentcore-provisioning/deploy-agentcore.py index faa0be7..c87e17b 100755 --- a/scripts/agentcore-provisioning/deploy-agentcore.py +++ b/scripts/agentcore-provisioning/deploy-agentcore.py @@ -40,6 +40,7 @@ def get_stack_outputs(): "AgentCoreRoleArn": os.environ.get("AGENT_CORE_ROLE_ARN"), "LogMcpLambdaArn": os.environ.get("LOG_MCP_LAMBDA_ARN"), "KbMcpLambdaArn": os.environ.get("KB_MCP_LAMBDA_ARN"), + "GatewayInterceptorArn": os.environ.get("GATEWAY_INTERCEPTOR_ARN"), } # Check if all required environment variables are set @@ -296,6 +297,7 @@ def create_log_mcp_server( m2m_client_id, region, log_lambda_arn, + interceptor_lambda_arn, recreate=False, ): logger.info("1.1: Creating Log MCP Server") @@ -323,6 +325,18 @@ def create_log_mcp_server( "allowedClients": [user_client_id, m2m_client_id], } }, + interceptorConfigurations=[ + { + 'interceptor': { + 'lambda': { + 'arn': interceptor_lambda_arn + } + }, + 'interceptionPoints': ['REQUEST'], + 'inputConfiguration': { + 'passRequestHeaders': True + } + }], exceptionLevel="DEBUG", ) gateway_id = response["gatewayId"] @@ -357,12 +371,7 @@ def create_log_mcp_server( "type": "string", "description": "Amazon Athena-compatible search query", }, - "tenant_id": { - "type": "string", - "description": "Tenant identifier for multi-tenant log isolation", - }, }, - "required": ["query", "tenant_id"], }, } ] @@ -386,6 +395,7 @@ def create_kb_mcp_server( m2m_client_id, region, kb_lambda_arn, + interceptor_lambda_arn, recreate=False, ): logger.info("1.2: Creating KB MCP Server") @@ -413,6 +423,18 @@ def create_kb_mcp_server( "allowedClients": [user_client_id, m2m_client_id], } }, + interceptorConfigurations=[ + { + 'interceptor': { + 'lambda': { + 'arn': interceptor_lambda_arn + } + }, + 'interceptionPoints': ['REQUEST'], + 'inputConfiguration': { + 'passRequestHeaders': True + } + }], exceptionLevel="DEBUG", ) gateway_id = response["gatewayId"] @@ -447,16 +469,11 @@ def create_kb_mcp_server( "type": "string", "description": "Free text search query", }, - "tenant_id": { - "type": "string", - "description": "Tenant identifier for multi-tenant knowledge base isolation", - }, "top_k": { "type": "integer", "description": "Maximum number of results to return", }, }, - "required": ["query", "tenant_id"], }, } ] @@ -679,6 +696,7 @@ def main(): role_arn = stack_outputs["AgentCoreRoleArn"] log_lambda_arn = stack_outputs["LogMcpLambdaArn"] kb_lambda_arn = stack_outputs["KbMcpLambdaArn"] + interceptor_lambda_arn = stack_outputs["GatewayInterceptorArn"] discovery_url = f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}/.well-known/openid-configuration" # Check if resources exist and recreate if needed @@ -695,6 +713,7 @@ def main(): m2m_client_id, region, log_lambda_arn, + interceptor_lambda_arn, log_exists and args.recreate, ) kb_gateway_id = create_kb_mcp_server( @@ -704,6 +723,7 @@ def main(): m2m_client_id, region, kb_lambda_arn, + interceptor_lambda_arn, kb_exists and args.recreate, )