Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions agent/kb_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,21 @@ def kb_agent_tool(query: str, top_k: int = 5) -> str:
if not kb_gateway_url:
raise ValueError("KB_GATEWAY_URL environment variable is not set")

decoded = ops_context.decode_jwt_claims(access_token)
tenant_id = decoded.get("tenantId")

streamable_http_mcp_client = MCPClient(
lambda: streamablehttp_client(
kb_gateway_url,
headers={
"Authorization": f"{access_token}",
"X-Tenant-ID": tenant_id
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The access_token already has the tenant id, so lets not add the tenant_id in the headers.

},
)
)

decoded = ops_context.decode_jwt_claims(access_token)
tenant_id = decoded.get("tenantId")

with streamable_http_mcp_client:
tools = []

for t in streamable_http_mcp_client.list_tools_sync():
if t.tool_name != "x_amz_bedrock_agentcore_search":
tool = wrapped_tool.WrappedTool(t)
tool.bind_param("tenant_id", tenant_id)

tools.append(tool)
else:
tools.append(t)
tools = list(streamable_http_mcp_client.list_tools_sync())

kb_agent = Agent(
name="kb_agent",
Expand Down
18 changes: 5 additions & 13 deletions agent/log_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,22 @@ def log_agent_tool(query: str) -> str:
if not log_gateway_url:
raise ValueError("LOG_GATEWAY_URL environment variable is not set")

decoded = ops_context.decode_jwt_claims(access_token)
tenant_id = decoded.get("tenantId")

streamable_http_mcp_client = MCPClient(
lambda: streamablehttp_client(
log_gateway_url,
headers={
"Authorization": f"{access_token}",
"X-Tenant-ID": tenant_id
},
)
)

decoded = ops_context.decode_jwt_claims(access_token)
tenant_id = decoded.get("tenantId")

with streamable_http_mcp_client:
tools = []
tools = list(streamable_http_mcp_client.list_tools_sync())

for t in streamable_http_mcp_client.list_tools_sync():
if t.tool_name != "x_amz_bedrock_agentcore_search":
tool = wrapped_tool.WrappedTool(t)
tool.bind_param("tenant_id", tenant_id)

tools.append(tool)
else:
tools.append(t)

system_prompt = """You are a log analysis agent that searches tenant application logs using Amazon Athena-compatible SQL queries.

TENANT_LOGS SCHEMA:
Expand Down
34 changes: 34 additions & 0 deletions cdk/lambda/gateway-interceptor/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import json
import uuid

def lambda_handler(event, context):
# Extract the gateway request
mcp_data = event.get('mcp', {})
gateway_request = mcp_data.get('gatewayRequest', {})
headers = gateway_request.get('headers', {})
body = gateway_request.get('body', {})
extended_body = body

auth_header = headers.get('authorization', '') or headers.get('Authorization', '')

# Extract Tenant Id from custom header for propagation
tenant_id = headers.get('X-Tenant-ID', '')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get this tenant id from the access token which is in the Authorization header


if "params" in extended_body and "arguments" in extended_body["params"]:
# Add custom header to arguments for downstream processing
extended_body["params"]["arguments"]["tenant_id"] = tenant_id

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have kb-mcp-handler and log-mcp-handler tools(lambda) which uses tenant_id. So this tenant_id will be available as a part of lambda handler event?

# Return transformed request without passing the original authorization header
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add/move this assume role code from log-mcp-handler, to here. Then send the generated credentials as a arguments so that log-mcp-handler tool can directly use the generated credentials.

response = {
"interceptorOutputVersion": "1.0",
"mcp": {
"transformedGatewayRequest": {
"headers": {
"Accept": "application/json",
"Content-Type": "application/json"
},
"body": extended_body
}
}
}
return response
43 changes: 42 additions & 1 deletion cdk/lib/agentcore-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export class AgentCoreStack extends cdk.NestedStack {
public readonly agentCoreRoleArn: string;
public readonly logMcpLambdaArn: string;
public readonly kbMcpLambdaArn: string;
public readonly gatewayInterceptorArn: string;

constructor(scope: Construct, id: string, props: AgentCoreStackProps) {
super(scope, id, props);
Expand Down Expand Up @@ -82,12 +83,14 @@ export class AgentCoreStack extends cdk.NestedStack {
// const abacRole = this.createAbacRole(basicRole, s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup);
// const logMcpLambda = this.createLogMcpHandlerLambda(s3BucketName, props.athenaResultsBucketName, props.athenaDatabase, props.athenaTable, props.athenaWorkgroup, basicRole, abacRole);
const kbMcpLambda = this.createKbMcpHandlerLambda(kbId);
const gatewayInterceptorLambda = this.createGatewayInterceptorLambda();

// Create IAM role for AgentCore Gateway
const agentCoreRole = this.createAgentCoreRole();
logMcpLambda.grantInvoke(agentCoreRole);
kbMcpLambda.grantInvoke(agentCoreRole);

gatewayInterceptorLambda.grantInvoke(agentCoreRole);

// Store outputs as public properties
this.userPoolId = userPool.userPoolId;
this.userClientId = userClient.userPoolClientId;
Expand All @@ -96,6 +99,7 @@ export class AgentCoreStack extends cdk.NestedStack {
this.agentCoreRoleArn = agentCoreRole.roleArn;
this.logMcpLambdaArn = logMcpLambda.functionArn;
this.kbMcpLambdaArn = kbMcpLambda.functionArn;
this.gatewayInterceptorArn = gatewayInterceptorLambda.functionArn;

// Create CloudWatch log groups for gateway logs
const logGatewayLogGroup = this.createGatewayLogGroup("LogGateway");
Expand Down Expand Up @@ -150,6 +154,11 @@ export class AgentCoreStack extends cdk.NestedStack {
value: athenaResultsBucket.bucketName,
description: "The name of the Athena results bucket",
});

new cdk.CfnOutput(this, "GatewayInterceptorArn", {
value: gatewayInterceptorLambda.functionArn,
description: "The ARN of the Gateway Interceptor Lambda",
});
}

private createLogMcpHandlerBasicRole(): iam.Role {
Expand Down Expand Up @@ -445,6 +454,38 @@ export class AgentCoreStack extends cdk.NestedStack {
});
}

private createGatewayInterceptorRole(): iam.Role {
return new iam.Role(this, "GatewayInterceptorRole", {
roleName: "AgentCore-Gateway-Interceptor-Role",
assumedBy: new iam.ServicePrincipal("lambda.amazonaws.com"),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName(
"service-role/AWSLambdaBasicExecutionRole"
),
],
description: "IAM role for Gateway Interceptor Lambda",
});
}

private createGatewayInterceptorLambda(): lambda.Function {
const interceptorRole = this.createGatewayInterceptorRole();
return new lambda.Function(this, "GatewayInterceptor", {
functionName: "AgentCore-Gateway-Interceptor",
description: "Request Gateway interceptor to extract tenant_id from JWT",
runtime: lambda.Runtime.PYTHON_3_12,
handler: "handler.lambda_handler",
code: lambda.Code.fromAsset(
path.join(__dirname, "../lambda/gateway-interceptor")
),
role:interceptorRole,
timeout: cdk.Duration.seconds(30),
memorySize: 256,
environment: {
LOG_LEVEL: "INFO"
}
});
}

private createResourceServer(
userPool: cdk.aws_cognito.UserPool,
identifier: string
Expand Down
40 changes: 30 additions & 10 deletions scripts/agentcore-provisioning/deploy-agentcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_stack_outputs():
"AgentCoreRoleArn": os.environ.get("AGENT_CORE_ROLE_ARN"),
"LogMcpLambdaArn": os.environ.get("LOG_MCP_LAMBDA_ARN"),
"KbMcpLambdaArn": os.environ.get("KB_MCP_LAMBDA_ARN"),
"GatewayInterceptorArn": os.environ.get("GATEWAY_INTERCEPTOR_ARN"),
}

# Check if all required environment variables are set
Expand Down Expand Up @@ -296,6 +297,7 @@ def create_log_mcp_server(
m2m_client_id,
region,
log_lambda_arn,
interceptor_lambda_arn,
recreate=False,
):
logger.info("1.1: Creating Log MCP Server")
Expand Down Expand Up @@ -323,6 +325,18 @@ def create_log_mcp_server(
"allowedClients": [user_client_id, m2m_client_id],
}
},
interceptorConfigurations=[
{
'interceptor': {
'lambda': {
'arn': interceptor_lambda_arn
}
},
'interceptionPoints': ['REQUEST'],
'inputConfiguration': {
'passRequestHeaders': True
}
}],
exceptionLevel="DEBUG",
)
gateway_id = response["gatewayId"]
Expand Down Expand Up @@ -357,12 +371,7 @@ def create_log_mcp_server(
"type": "string",
"description": "Amazon Athena-compatible search query",
},
"tenant_id": {
"type": "string",
"description": "Tenant identifier for multi-tenant log isolation",
},
},
"required": ["query", "tenant_id"],
},
}
]
Expand All @@ -386,6 +395,7 @@ def create_kb_mcp_server(
m2m_client_id,
region,
kb_lambda_arn,
interceptor_lambda_arn,
recreate=False,
):
logger.info("1.2: Creating KB MCP Server")
Expand Down Expand Up @@ -413,6 +423,18 @@ def create_kb_mcp_server(
"allowedClients": [user_client_id, m2m_client_id],
}
},
interceptorConfigurations=[
{
'interceptor': {
'lambda': {
'arn': interceptor_lambda_arn
}
},
'interceptionPoints': ['REQUEST'],
'inputConfiguration': {
'passRequestHeaders': True
}
}],
exceptionLevel="DEBUG",
)
gateway_id = response["gatewayId"]
Expand Down Expand Up @@ -447,16 +469,11 @@ def create_kb_mcp_server(
"type": "string",
"description": "Free text search query",
},
"tenant_id": {
"type": "string",
"description": "Tenant identifier for multi-tenant knowledge base isolation",
},
"top_k": {
"type": "integer",
"description": "Maximum number of results to return",
},
},
"required": ["query", "tenant_id"],
},
}
]
Expand Down Expand Up @@ -679,6 +696,7 @@ def main():
role_arn = stack_outputs["AgentCoreRoleArn"]
log_lambda_arn = stack_outputs["LogMcpLambdaArn"]
kb_lambda_arn = stack_outputs["KbMcpLambdaArn"]
interceptor_lambda_arn = stack_outputs["GatewayInterceptorArn"]
discovery_url = f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}/.well-known/openid-configuration"

# Check if resources exist and recreate if needed
Expand All @@ -695,6 +713,7 @@ def main():
m2m_client_id,
region,
log_lambda_arn,
interceptor_lambda_arn,
log_exists and args.recreate,
)
kb_gateway_id = create_kb_mcp_server(
Expand All @@ -704,6 +723,7 @@ def main():
m2m_client_id,
region,
kb_lambda_arn,
interceptor_lambda_arn,
kb_exists and args.recreate,
)

Expand Down