diff --git a/README.md b/README.md index aff19b6..ccb8d51 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ If you find this GitHub repository useful, please consider giving it a free star - [x] Support Application Inference Profiles (**new**) - [x] Support Reasoning (**new**) - [x] Support Interleaved thinking (**new**) +- [x] Support for Guardrails (**new**) Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs. @@ -99,6 +100,17 @@ After creation, you'll see your secret in the Secrets Manager console. Make not That is it! 🎉 Once deployed, click the CloudFormation stack and go to **Outputs** tab, you can find the API Base URL from `APIBaseUrl`, the value should look like `http://xxxx.xxx.elb.amazonaws.com/api/v1`. +### Guardrails Integration +If you would like to create an AWS Bedrock Guardrail that is applied to FM interactions then please follow the following steps. +1. Create the AWS Bedrock Guardrail via the console or your favorite IaC language +2. Inject the following environment variables into your chosen deployment + +```bash +ENABLE_GUARDRAIL=true +GUARDRAIL_IDENTIFIER= +GUARDRAIL_VERSION= +``` + ### Troubleshooting If you encounter any issues, please check the [Troubleshooting Guide](./docs/Troubleshooting.md) for more details. diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 374fcd1..dedd376 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -45,7 +45,10 @@ DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE, - ENABLE_APPLICATION_INFERENCE_PROFILES, + ENABLE_APPLICATION_INFERENCE_PROFILES, + ENABLE_GUARDRAIL, + GUARDRAIL_IDENTIFIER, + GUARDRAIL_VERSION, ) logger = logging.getLogger(__name__) @@ -527,12 +530,27 @@ def _parse_request(self, chat_request: ChatRequest) -> dict: stop = [stop] inference_config["stopSequences"] = stop - args = { - "modelId": chat_request.model, - "messages": messages, - "system": system_prompts, - "inferenceConfig": inference_config, - } + if ENABLE_GUARDRAIL: + guardrail_config = { + "guardrailIdentifier": GUARDRAIL_IDENTIFIER, + "guardrailVersion": GUARDRAIL_VERSION, + } + + args = { + "modelId": chat_request.model, + "messages": messages, + "system": system_prompts, + "inferenceConfig": inference_config, + "guardrailConfig": guardrail_config, + } + else: + args = { + "modelId": chat_request.model, + "messages": messages, + "system": system_prompts, + "inferenceConfig": inference_config, + } + if chat_request.reasoning_effort: # From OpenAI api, the max_token is not supported in reasoning mode # Use max_completion_tokens if provided. @@ -878,12 +896,22 @@ def _invoke_model(self, args: dict, model_id: str): logger.info("Invoke Bedrock Model: " + model_id) logger.info("Bedrock request body: " + body) try: - return bedrock_runtime.invoke_model( - body=body, - modelId=model_id, - accept=self.accept, - contentType=self.content_type, - ) + if ENABLE_GUARDRAIL: + return bedrock_runtime.invoke_model( + body=body, + modelId=model_id, + accept=self.accept, + contentType=self.content_type, + guardrailIdentifier=GUARDRAIL_IDENTIFIER, + guardrailVersion=GUARDRAIL_VERSION, + ) + else: + return bedrock_runtime.invoke_model( + body=body, + modelId=model_id, + accept=self.accept, + contentType=self.content_type, + ) except bedrock_runtime.exceptions.ValidationException as e: logger.error("Validation Error: " + str(e)) raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/api/setting.py b/src/api/setting.py index 4e0a7bb..a929398 100644 --- a/src/api/setting.py +++ b/src/api/setting.py @@ -17,3 +17,6 @@ DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3") ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false" ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false" +ENABLE_GUARDRAIL = os.environ.get("ENABLE_GUARDRAIL", "false").lower() != "false" +GUARDRAIL_IDENTIFIER = os.environ.get("GUARDRAIL_IDENTIFIER") +GUARDRAIL_VERSION = os.environ.get("GUARDRAIL_VERSION")