diff --git a/javelin_sdk/services/guardrails_service.py b/javelin_sdk/services/guardrails_service.py index 228a9a0..eb37370 100644 --- a/javelin_sdk/services/guardrails_service.py +++ b/javelin_sdk/services/guardrails_service.py @@ -27,7 +27,7 @@ def _handle_guardrails_response(self, response: httpx.Response) -> None: def apply_trustsafety( self, text: str, config: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - data: Dict[str, Any] = {"text": text} + data: Dict[str, Any] = {"input": {"text": text}} if config: data["config"] = config response = self.client._send_request_sync( @@ -43,7 +43,7 @@ def apply_trustsafety( def apply_promptinjectiondetection( self, text: str, config: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - data: Dict[str, Any] = {"text": text} + data: Dict[str, Any] = {"input": {"text": text}} if config: data["config"] = config response = self.client._send_request_sync( @@ -56,8 +56,12 @@ def apply_promptinjectiondetection( self._handle_guardrails_response(response) return response.json() - def apply_guardrails(self, text: str, guardrails: list) -> Dict[str, Any]: - data = {"text": text, "guardrails": guardrails} + def apply_guardrails( + self, text: str, guardrails: list, config: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + data: Dict[str, Any] = {"input": {"text": text}, "guardrails": guardrails} + if config: + data["config"] = config response = self.client._send_request_sync( Request( method=HttpMethod.POST,