Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ If you find this GitHub repository useful, please consider giving it a free star
- [x] Support Application Inference Profiles (**new**)
- [x] Support Reasoning (**new**)
- [x] Support Interleaved thinking (**new**)
- [x] Support for Guardrails (**new**)

Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.

Expand Down Expand Up @@ -99,6 +100,17 @@ After creation, you'll see your secret in the Secrets Manager console. Make not

That is it! 🎉 Once deployed, click the CloudFormation stack and go to **Outputs** tab, you can find the API Base URL from `APIBaseUrl`, the value should look like `http://xxxx.xxx.elb.amazonaws.com/api/v1`.

### Guardrails Integration
If you would like to apply an AWS Bedrock Guardrail to foundation model (FM) interactions, follow these steps.
1. Create the AWS Bedrock Guardrail via the console or your favorite IaC language
2. Inject the following environment variables into your chosen deployment

```bash
ENABLE_GUARDRAIL=true
GUARDRAIL_IDENTIFIER=<arn of your guardrail - e.g. arn:aws:bedrock:us-east-1:123456789012:guardrail/ab4z4asdre90>
GUARDRAIL_VERSION=<version of your guardrail - e.g. 1>
```

### Troubleshooting

If you encounter any issues, please check the [Troubleshooting Guide](./docs/Troubleshooting.md) for more details.
Expand Down
54 changes: 41 additions & 13 deletions src/api/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@
DEBUG,
DEFAULT_MODEL,
ENABLE_CROSS_REGION_INFERENCE,
ENABLE_APPLICATION_INFERENCE_PROFILES,
ENABLE_APPLICATION_INFERENCE_PROFILES,
ENABLE_GUARDRAIL,
GUARDRAIL_IDENTIFIER,
GUARDRAIL_VERSION,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -527,12 +530,27 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
stop = [stop]
inference_config["stopSequences"] = stop

args = {
"modelId": chat_request.model,
"messages": messages,
"system": system_prompts,
"inferenceConfig": inference_config,
}
if ENABLE_GUARDRAIL:
guardrail_config = {
"guardrailIdentifier": GUARDRAIL_IDENTIFIER,
"guardrailVersion": GUARDRAIL_VERSION,
}

args = {
"modelId": chat_request.model,
"messages": messages,
"system": system_prompts,
"inferenceConfig": inference_config,
"guardrailConfig": guardrail_config,
}
else:
args = {
"modelId": chat_request.model,
"messages": messages,
"system": system_prompts,
"inferenceConfig": inference_config,
}

if chat_request.reasoning_effort:
# From OpenAI api, the max_token is not supported in reasoning mode
# Use max_completion_tokens if provided.
Expand Down Expand Up @@ -878,12 +896,22 @@ def _invoke_model(self, args: dict, model_id: str):
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
try:
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
if ENABLE_GUARDRAIL:
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
guardrailIdentifier=GUARDRAIL_IDENTIFIER,
guardrailVersion=GUARDRAIL_VERSION,
)
else:
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
Expand Down
3 changes: 3 additions & 0 deletions src/api/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
ENABLE_GUARDRAIL = os.environ.get("ENABLE_GUARDRAIL", "false").lower() != "false"
GUARDRAIL_IDENTIFIER = os.environ.get("GUARDRAIL_IDENTIFIER")
GUARDRAIL_VERSION = os.environ.get("GUARDRAIL_VERSION")