diff --git a/docs/user-guides/community/openai.md b/docs/user-guides/community/openai.md
new file mode 100644
index 000000000..2150e4be8
--- /dev/null
+++ b/docs/user-guides/community/openai.md
@@ -0,0 +1,16 @@
+## OpenAI API Compatibility for NeMo Guardrails
+
+The NeMo Guardrails server exposes OpenAI-compatible API endpoints, so applications that already use an OpenAI client can add guardrails to their LLM interactions with minimal changes. Point your OpenAI client to `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` endpoint.
+
+## Feature Support Matrix
+
+The following table outlines which OpenAI API features the NeMo Guardrails server currently supports:
+
+| Feature | Status | Notes |
+| :------ | :----: | :---- |
+| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied |
+| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` |
+| **Multimodal Input** | ✖ Unsupported | Text and image inputs (vision models) are supported with guardrails, but not yet through the OpenAI-compatible endpoint |
+| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support |
+| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration |
+| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic |
diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py
index 6980714bc..9cbbcb776 100644
--- a/nemoguardrails/colang/v2_x/runtime/runtime.py
+++ b/nemoguardrails/colang/v2_x/runtime/runtime.py
@@ -31,6 +31,7 @@
     ColangSyntaxError,
 )
 from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus
+from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
 from nemoguardrails.colang.v2_x.runtime.statemachine import (
     FlowConfig,
     InternalEvent,
@@ -394,10 +395,13 @@ async def process_events(
             state = State(flow_states={}, flow_configs=self.flow_configs, rails_config=self.config)
             initialize_state(state)
         elif isinstance(state, dict):
-            # TODO: Implement dict to State conversion
-            raise NotImplementedError()
-            # if isinstance(state, dict):
-            #     state = State.from_dict(state)
+            # Convert dict to State object
+            if state.get("version") == "2.x" and "state" in state:
+                # Handle the serialized state format from API calls
+                state = json_to_state(state["state"])
+            else:
+                # TODO: Implement other dict to State conversion formats if needed
+                raise NotImplementedError("Unsupported state dict format")
 
         assert isinstance(state, State)
         assert state.main_flow_state is not None
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index c4d33f83d..528a4c6d3 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -490,6 +490,11 @@ def _init_llms(self):
                 if not self.llm:
                     self.llm = llm_model
                     self.runtime.register_action_param("llm", self.llm)
+                    self._configure_main_llm_streaming(
+                        self.llm,
+                        model_name=llm_config.model,
+                        provider_name=llm_config.engine,
+                    )
             else:
                 model_name = f"{llm_config.type}_llm"
                 if not hasattr(self, model_name):
diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py
index 658cffd01..3fb726588 100644
--- a/nemoguardrails/server/api.py
+++ b/nemoguardrails/server/api.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. + import asyncio import contextvars import importlib.util @@ -20,23 +21,27 @@ import os.path import re import time +import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, Callable, List, Optional +from typing import Any, AsyncIterator, Callable, List, Optional from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage from pydantic import BaseModel, Field, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from nemoguardrails import LLMRails, RailsConfig, utils -from nemoguardrails.rails.llm.options import ( - GenerationLog, - GenerationOptions, - GenerationResponse, -) +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore +from nemoguardrails.server.schemas.openai import ( + GuardrailsModel, + ModelsResponse, + ResponseBody, +) from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) @@ -228,10 +233,53 @@ class RequestBody(BaseModel): default=None, description="A state object that should be used to continue the interaction.", ) + # Standard OpenAI completion parameters + model: Optional[str] = Field( + default="main", + description="The model to use for chat completion. Maps to the main model in the config.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature to use.", + ) + top_p: Optional[float] = Field( + default=None, + description="Top-p sampling parameter.", + ) + stop: Optional[str] = Field( + default=None, + description="Stop sequences.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="Presence penalty parameter.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="Frequency penalty parameter.", + ) + function_call: Optional[dict] = Field( + default=None, + description="Function call parameter.", + ) + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", + ) + log_probs: Optional[bool] = Field( + default=None, + description="Log probabilities parameter.", + ) @root_validator(pre=True) def ensure_config_id(cls, data: Any) -> Any: if isinstance(data, dict): + if data.get("model") is not None and data.get("config_id") is None: + data["config_id"] = data["model"] if data.get("config_id") is not None and data.get("config_ids") is not None: raise ValueError("Only one of config_id or config_ids should be specified") if data.get("config_id") is None and data.get("config_ids") is not None: @@ -248,21 +296,72 @@ def ensure_config_ids(cls, v, values): return v -class ResponseBody(BaseModel): - messages: Optional[List[dict]] = Field(default=None, description="The new messages in the conversation") - llm_output: Optional[dict] = Field( - default=None, - description="Contains any additional output coming from the LLM.", - ) - output_data: Optional[dict] = Field( - default=None, - description="The output data, i.e. 
a dict with the values corresponding to the `output_vars`.", - ) - log: Optional[GenerationLog] = Field(default=None, description="Additional logging information.") - state: Optional[dict] = Field( - default=None, - description="A state object that should be used to continue the interaction in the future.", - ) +@app.get( + "/v1/models", + response_model=ModelsResponse, + summary="List available models", + description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.", +) +async def get_models(): + """Returns the list of available models (guardrails configurations) in OpenAI-compatible format.""" + + # Use the same logic as get_rails_configs to find available configurations + if app.single_config_mode: + config_ids = [app.single_config_id] if app.single_config_id else [] + + else: + config_ids = [ + f + for f in os.listdir(app.rails_config_path) + if os.path.isdir(os.path.join(app.rails_config_path, f)) + and f[0] != "." + and f[0] != "_" + # Filter out all the configs for which there is no `config.yml` file. + and ( + os.path.exists(os.path.join(app.rails_config_path, f, "config.yml")) + or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml")) + ) + ] + + models = [] + for config_id in config_ids: + try: + # Load the RailsConfig to extract model information + if app.single_config_mode: + config_path = app.rails_config_path + else: + config_path = os.path.join(app.rails_config_path, config_id) + + rails_config = RailsConfig.from_path(config_path) + # Extract all models from this config + config_models = rails_config.models + + if len(config_models) == 0: + guardrails_model = GuardrailsModel( + id=config_id, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id=config_id, + ) + models.append(guardrails_model) + else: + for model in config_models: + # Only include models with a model name + if model.model: + guardrails_model = GuardrailsModel( + id=model.model, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id=config_id, + ) + models.append(guardrails_model) + except Exception as ex: + log.warning(f"Could not load model info for config {config_id}: {ex}") + continue + + return ModelsResponse(data=models) @app.get( @@ -305,6 +404,14 @@ def _generate_cache_key(config_ids: List[str]) -> str: return "-".join((config_ids)) # remove sorted +def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]: + """Extracts the main model name from a RailsConfig.""" + main_models = [m for m in rails_config.models if m.type == "main"] + if main_models and main_models[0].model: + return main_models[0].model + return None + + def _get_rails(config_ids: List[str]) -> LLMRails: """Returns the rails instance for the given config id.""" @@ -355,6 +462,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails: return llm_rails +async def _format_streaming_response( + streaming_handler: StreamingHandler, model_name: Optional[str] +) -> AsyncIterator[str]: + while True: + try: + chunk = await streaming_handler.__anext__() + except StopAsyncIteration: + # When the stream ends, yield the [DONE] message + yield "data: [DONE]\n\n" + break + + # Determine the payload format based on chunk type + if isinstance(chunk, dict): + # If chunk is a dict, wrap it in OpenAI chunk format with delta + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": chunk, + 
"index": 0, + "finish_reason": None, + } + ], + } + elif isinstance(chunk, str): + try: + # Try parsing as JSON - if it parses, it might be a pre-formed payload + payload = json.loads(chunk) + except Exception: + # treat as plain text content token + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": chunk}, + "index": 0, + "finish_reason": None, + } + ], + } + else: + # For any other type, treat as plain content + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": str(chunk)}, + "index": 0, + "finish_reason": None, + } + ], + } + + # Send the payload as JSON + data = json.dumps(payload, ensure_ascii=False) + yield f"data: {data}\n\n" + + @app.post( "/v1/chat/completions", response_model=ResponseBody, @@ -375,6 +549,7 @@ async def chat_completion(body: RequestBody, request: Request): # Use Request config_ids if set, otherwise use the FastAPI default config. # If neither is available we can't generate any completions as we have no config_id config_ids = body.config_ids + if not config_ids: if app.default_config_id: config_ids = [app.default_config_id] @@ -383,19 +558,33 @@ async def chat_completion(body: RequestBody, request: Request): try: llm_rails = _get_rails(config_ids) + except ValueError as ex: log.exception(ex) return ResponseBody( - messages=[ - { - "role": "assistant", - "content": f"Could not load the {config_ids} guardrails configuration. " - f"An internal error has occurred.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=config_ids[0] if config_ids else "unknown", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content=f"Could not load the {config_ids} guardrails configuration. " + f"An internal error has occurred.", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], ) try: + main_model_name = _get_main_model_name(llm_rails.config) + if main_model_name is None: + main_model_name = config_ids[0] if config_ids else "unknown" + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -406,16 +595,24 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id: if datastore is None: raise RuntimeError("No DataStore has been configured.") - # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: return ResponseBody( - messages=[ - { - "role": "assistant", - "content": "The `thread_id` must have a minimum length of 16 characters.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=main_model_name, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="The `thread_id` must have a minimum length of 16 characters.", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], ) # Fetch the existing thread messages. For easier management, we prepend @@ -426,6 +623,25 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. 
messages = thread_messages + messages + generation_options = body.options + + # Initialize llm_params if not already set + if generation_options.llm_params is None: + generation_options.llm_params = {} + + # Set OpenAI-compatible parameters in llm_params + if body.max_tokens: + generation_options.llm_params["max_tokens"] = body.max_tokens + if body.temperature is not None: + generation_options.llm_params["temperature"] = body.temperature + if body.top_p is not None: + generation_options.llm_params["top_p"] = body.top_p + if body.stop: + generation_options.llm_params["stop"] = body.stop + if body.presence_penalty is not None: + generation_options.llm_params["presence_penalty"] = body.presence_penalty + if body.frequency_penalty is not None: + generation_options.llm_params["frequency_penalty"] = body.frequency_penalty if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming: # Create the streaming handler instance streaming_handler = StreamingHandler() @@ -435,16 +651,17 @@ async def chat_completion(body: RequestBody, request: Request): llm_rails.generate_async( messages=messages, streaming_handler=streaming_handler, - options=body.options, + options=generation_options, state=body.state, ) ) - # TODO: Add support for thread_ids in streaming mode - - return StreamingResponse(streaming_handler) + return StreamingResponse( + _format_streaming_response(streaming_handler, model_name=main_model_name), + media_type="text/event-stream", + ) else: - res = await llm_rails.generate_async(messages=messages, options=body.options, state=body.state) + res = await llm_rails.generate_async(messages=messages, options=generation_options, state=body.state) if isinstance(res, GenerationResponse): bot_message_content = res.response[0] @@ -462,20 +679,53 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = ResponseBody(messages=[bot_message]) - - # If we have additional GenerationResponse fields, we return as well + # Build the response with OpenAI-compatible format + response_kwargs = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": main_model_name, + "choices": [ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=bot_message["content"], + ), + finish_reason="stop", + logprobs=None, + ) + ], + } + + # If we have additional GenerationResponse fields, include them for backward compatibility if isinstance(res, GenerationResponse): - result.llm_output = res.llm_output - result.output_data = res.output_data - result.log = res.log - result.state = res.state + response_kwargs["llm_output"] = res.llm_output + response_kwargs["output_data"] = res.output_data + response_kwargs["log"] = res.log + response_kwargs["state"] = res.state - return result + return ResponseBody(**response_kwargs) except Exception as ex: log.exception(ex) - return ResponseBody(messages=[{"role": "assistant", "content": "Internal server error."}]) + return ResponseBody( + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=config_ids[0] if config_ids else "unknown", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="Internal server error", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], + ) # By default, there are no challenges diff --git 
a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py new file mode 100644 index 000000000..fff6d020b --- /dev/null +++ b/nemoguardrails/server/schemas/openai.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenAI API schema definitions for the NeMo Guardrails server.""" + +from typing import List, Optional + +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.model import Model +from pydantic import BaseModel, Field + + +class ResponseBody(ChatCompletion): + """OpenAI API response body with NeMo-Guardrails extensions.""" + + guardrails_config_id: Optional[str] = Field( + default=None, + description="The guardrails configuration ID associated with this response.", + ) + state: Optional[dict] = Field(default=None, description="State object for continuing the conversation.") + llm_output: Optional[dict] = Field(default=None, description="Additional LLM output data.") + output_data: Optional[dict] = Field(default=None, description="Additional output data.") + log: Optional[dict] = Field(default=None, description="Generation log data.") + + +class GuardrailsModel(Model): + """OpenAI API model with NeMo-Guardrails extensions.""" + + guardrails_config_id: Optional[str] = Field( + default=None, + description="[NeMo Guardrails extension] The guardrails configuration ID associated with this model.", + ) + + +class ModelsResponse(BaseModel): + """OpenAI API models list response with NeMo-Guardrails extensions.""" + + object: str = Field(default="list", description="The object type, which is always 'list'.") + data: List[GuardrailsModel] = Field(description="The list of models.") diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 7cf8ac7c3..5a862c18d 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -173,18 +173,37 @@ async def __anext__(self): async def _process( self, - chunk: Union[str, object], + chunk: Union[str, dict, object], generation_info: Optional[Dict[str, Any]] = None, ): - """Process a chunk of text. + """Process a chunk of text or dict. If we're in buffering mode, record the text. Otherwise, update the full completion, check for stop tokens, and enqueue the chunk. + Dict chunks bypass completion tracking and go directly to the queue. 
""" if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass buffering and completion tracking + if isinstance(chunk, dict): + if self.pipe_to: + asyncio.create_task(self.pipe_to.push_chunk(chunk)) + else: + if self.include_generation_metadata: + await self.queue.put( + { + "text": chunk, + "generation_info": ( + self.current_generation_info.copy() if self.current_generation_info else {} + ), + } + ) + else: + await self.queue.put(chunk) + return + if self.enable_buffer: if chunk is not END_OF_STREAM: self.buffer += chunk if chunk is not None else "" @@ -254,10 +273,28 @@ async def _process( async def push_chunk( self, - chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None], + chunk: Union[ + str, + dict, + GenerationChunk, + AIMessageChunk, + ChatGenerationChunk, + None, + object, + ], generation_info: Optional[Dict[str, Any]] = None, ): - """Push a new chunk to the stream.""" + """Push a new chunk to the stream. + + Args: + chunk: The chunk to push. Can be: + - str: Plain text content + - dict: Dictionary with fields like role, content, etc. + - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types + - None: Signals end of stream (converted to END_OF_STREAM) + - object: END_OF_STREAM sentinel + generation_info: Optional metadata about the generation + """ # if generation_info is not explicitly passed, # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk @@ -281,6 +318,9 @@ async def push_chunk( elif isinstance(chunk, str): # empty string is a valid chunk and should be processed normally pass + elif isinstance(chunk, dict): + # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming) + pass else: raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}") @@ -291,6 +331,11 @@ async def push_chunk( if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass prefix/suffix processing and go directly to _process + if isinstance(chunk, dict): + await self._process(chunk, generation_info) + return + # Process prefix: accumulate until the expected prefix is received, then remove it. if self.prefix: if chunk is not None and chunk is not END_OF_STREAM: diff --git a/poetry.lock b/poetry.lock index 9e24d2a40..8cfb2a9c6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -918,7 +918,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1722,7 +1722,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." 
-optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -2921,7 +2921,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -4023,13 +4023,13 @@ dev = ["build", "flake8", "mypy", "pytest", "twine"] [[package]] name = "pyright" -version = "1.1.405" +version = "1.1.407" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, - {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, + {file = "pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21"}, + {file = "pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262"}, ] [package.dependencies] @@ -6194,16 +6194,16 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["aiofiles", "opentelemetry-api"] +tracing = ["opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "d5e8dc8fdbad5781141f4c65671d115060aa4c99abca0bd72ec025781352b775" +content-hash = "a048d4ecee654c25ea1be4a65cfccf4bb51289b2aa4db72afd5d096f3d2add1a" diff --git a/pyproject.toml b/pyproject.toml index f3452a964..78d78bbd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,10 +71,11 @@ starlette = ">=0.49.1" typer = ">=0.8" uvicorn = ">=0.23" watchdog = ">=3.0.0," +aiofiles = ">=24.1.0" +openai = ">=1.0.0, <2.0.0" # tracing opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true } -aiofiles = { version = ">=24.1.0", optional = true } # openai langchain-openai = { version = ">=0.1.0", optional = true } @@ -110,7 +111,7 @@ sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] openai = ["langchain-openai"] gcp = ["google-cloud-language"] -tracing = ["opentelemetry-api", "aiofiles"] +tracing = ["opentelemetry-api"] nvidia = ["langchain-nvidia-ai-endpoints"] jailbreak = ["yara-python"] # Poetry does not support recursive dependencies, so we need to add all the dependencies here. 
@@ -125,7 +126,6 @@ all = [ "langchain-openai", "google-cloud-language", "opentelemetry-api", - "aiofiles", "langchain-nvidia-ai-endpoints", "yara-python", ] diff --git a/tests/test_api.py b/tests/test_api.py index b6619fe7a..8c50b5c01 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,13 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import json import os import pytest from fastapi.testclient import TestClient from nemoguardrails.server import api -from nemoguardrails.server.api import RequestBody +from nemoguardrails.server.api import RequestBody, _format_streaming_response +from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler + +LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE") or os.environ.get("TEST_LIVE_MODE") client = TestClient(api.app) @@ -41,7 +46,31 @@ def test_get(): assert len(result) > 0 -@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_get_models(): + """Test the OpenAI-compatible /v1/models endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + result = response.json() + + # Check OpenAI models list format + assert result["object"] == "list" + assert "data" in result + assert len(result["data"]) > 0 + + # Check each model has the required OpenAI format + for model in result["data"]: + assert "id" in model + assert "guardrails_config_id" in model + assert model["object"] == "model" + assert "created" in model + assert model["owned_by"] == "nemo-guardrails" + + +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion(): response = client.post( "/v1/chat/completions", @@ -57,11 +86,20 @@ def test_chat_completion(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" -@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion_with_default_configs(): api.set_default_config_id("general") @@ -78,8 +116,14 @@ def test_chat_completion_with_default_configs(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" def test_request_body_validation(): @@ -113,6 +157,31 @@ def test_request_body_validation(): assert request_body.config_ids is None +def test_openai_model_field_mapping(): + """Test OpenAI-compatible model field mapping to config_id.""" + + # Test model field maps to config_id + data = { + "model": "test_model", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = 
RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_model" + assert request_body.config_ids == ["test_model"] + + # Test model and config_id both provided (config_id takes precedence) + data = { + "model": "test_model", + "config_id": "test_config", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_config" + assert request_body.config_ids == ["test_config"] + + def test_request_body_state(): """Test RequestBody state handling.""" data = { @@ -142,3 +211,301 @@ def test_request_body_messages(): } request_body = RequestBody.model_validate(data) assert len(request_body.messages) == 1 + + +@pytest.mark.asyncio +async def test_openai_sse_format_basic_chunks(): + """Test basic string chunks are properly formatted as SSE events.""" + handler = StreamingHandler() + + # Collect yielded SSE messages + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a couple of chunks and then signal completion + await handler.push_chunk("Hello ") + await handler.push_chunk("world") + await handler.push_chunk(END_OF_STREAM) + + # Wait for the collector task to finish + await task + + # We expect three messages: two data: {json}\n\n events and final data: [DONE]\n\n + assert len(collected) == 3 + # First two are JSON SSE events + evt1 = collected[0] + evt2 = collected[1] + done = collected[2] + + assert evt1.startswith("data: ") + j1 = json.loads(evt1[len("data: ") :].strip()) + assert j1["object"] == "chat.completion.chunk" + assert j1["choices"][0]["delta"]["content"] == "Hello " + + assert evt2.startswith("data: ") + j2 = json.loads(evt2[len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "world" + + assert done == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_model_name(): + """Test that model name is properly included in the response.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="gpt-4"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["model"] == "gpt-4" + assert j["choices"][0]["delta"]["content"] == "Test" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_dict_chunk(): + """Test that dict chunks with role and content are properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a dict chunk that includes role and content + await handler.push_chunk({"role": "assistant", "content": "Hi!"}) + await handler.push_chunk(None) + + await task + + # We expect two messages: one data chunk and final data: [DONE] + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["object"] == "chat.completion.chunk" + assert j["choices"][0]["delta"]["role"] == "assistant" + assert j["choices"][0]["delta"]["content"] == 
"Hi!" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_empty_string(): + """Test that empty strings are handled correctly.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_none_triggers_done(): + """Test that None (converted to END_OF_STREAM) triggers [DONE].""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Content") + await handler.push_chunk(None) # None converts to END_OF_STREAM + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "Content" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_multiple_dict_chunks(): + """Test multiple dict chunks with different fields.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push multiple dict chunks + await handler.push_chunk({"role": "assistant"}) + await handler.push_chunk({"content": "Hello"}) + await handler.push_chunk({"content": " world"}) + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 4 + + # Check first chunk (role only) + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["role"] == "assistant" + assert "content" not in j1["choices"][0]["delta"] + + # Check second chunk (content only) + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "Hello" + + # Check third chunk (content only) + j3 = json.loads(collected[2][len("data: ") :].strip()) + assert j3["choices"][0]["delta"]["content"] == " world" + + # Check [DONE] message + assert collected[3] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_special_characters(): + """Test that special characters are properly escaped in JSON.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push chunks with special characters + await handler.push_chunk("Line 1\nLine 2") + await handler.push_chunk('Quote: "test"') + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 3 + + # Verify first chunk with newline + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["content"] == "Line 1\nLine 2" + + # Verify second chunk with quotes + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == 'Quote: "test"' + + assert collected[2] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async 
def test_openai_sse_format_events(): + """Test that all events follow proper SSE format.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + # All events except [DONE] should be valid JSON with proper SSE format + for event in collected[:-1]: + assert event.startswith("data: ") + assert event.endswith("\n\n") + # Verify it's valid JSON + json_str = event[len("data: ") :].strip() + j = json.loads(json_str) + assert "object" in j + assert "choices" in j + assert isinstance(j["choices"], list) + assert len(j["choices"]) > 0 + + # Last event should be [DONE] + assert collected[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_chunk_metadata(): + """Test that chunk metadata is properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + + # Verify all required fields are present + assert j["id"] is None # id can be None for chunks + assert j["object"] == "chat.completion.chunk" + assert isinstance(j["created"], int) + assert j["model"] == "test-model" + assert isinstance(j["choices"], list) + assert len(j["choices"]) == 1 + + choice = j["choices"][0] + assert "delta" in choice + assert choice["index"] == 0 + assert choice["finish_reason"] is None + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_chat_completion_with_streaming(): + response = client.post( + "/v1/chat/completions", + json={ + "config_id": "general", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/event-stream" + for chunk in response.iter_lines(): + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + assert "data: [DONE]\n\n" in response.text diff --git a/tests/test_openai_integration.py b/tests/test_openai_integration.py new file mode 100644 index 000000000..9d2523668 --- /dev/null +++ b/tests/test_openai_integration.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
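+"""Integration tests for the OpenAI-compatible server endpoints.
+
+These tests drive the FastAPI app through the official OpenAI Python client,
+using the TestClient as the HTTP transport, to check that /v1/models and
+/v1/chat/completions behave like their OpenAI counterparts.
+"""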
+import os +import time + +import pytest +from fastapi.testclient import TestClient +from openai import OpenAI +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.model import Model + +from nemoguardrails.server import api + + +@pytest.fixture(scope="function", autouse=True) +def set_rails_config_path(): + """Set the rails config path to the test configs directory.""" + original_path = api.app.rails_config_path + api.app.rails_config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), "test_configs/simple_server")) + yield + + # Restore the original path and clear cache after the test + api.app.rails_config_path = original_path + api.llm_rails_instances.clear() + api.llm_rails_events_history_cache.clear() + + +@pytest.fixture(scope="function") +def test_client(): + """Create a FastAPI TestClient for the guardrails server.""" + return TestClient(api.app) + + +@pytest.fixture(scope="function") +def openai_client(test_client): + client = OpenAI( + api_key="dummy-key", + base_url="http://dummy-url/v1", + http_client=test_client, + ) + return client + + +def test_openai_client_list_models(openai_client): + models = openai_client.models.list() + + # Verify the response structure matches the GuardrailsModel schema + assert models.data[0] == Model( + id="config_1", + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + guardrails_config_id="config_1", + ) + + +def test_openai_client_chat_completion(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + assert isinstance(response, ChatCompletion) + assert response.id is not None + + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Hello!", + refusal=None, + role="assistant", + annotations=None, + audio=None, + function_call=None, + tool_calls=None, + ), + ) + assert hasattr(response, "created") + + +def test_openai_client_chat_completion_parameterized(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + # Verify response exists + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Hello!", + refusal=None, + role="assistant", + annotations=None, + ), + ) + assert hasattr(response, "created") + + +def test_openai_client_chat_completion_input_rails(openai_client): + response = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=False, + ) + + # Verify response exists + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert isinstance(response.choices[0], Choice) + assert hasattr(response, "created") + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_openai_client_chat_completion_streaming(openai_client): + stream = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Tell me a short joke."}], + stream=True, + ) + + chunks = list(stream) + assert len(chunks) > 0 + + # Verify at least one chunk has content + 
has_content = any(hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content for chunk in chunks) + assert has_content, "At least one chunk should contain content" + + +def test_openai_client_error_handling_invalid_model(openai_client): + response = openai_client.chat.completions.create( + model="nonexistent_config", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # The error should be in the content + assert ( + "Could not load" in response.choices[0].message.content + or "error" in response.choices[0].message.content.lower() + ) diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py index 051096432..736f2592c 100644 --- a/tests/test_server_calls_with_state.py +++ b/tests/test_server_calls_with_state.py @@ -37,12 +37,15 @@ def _test_call(config_id): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + print(res) + assert len(res["choices"][0]["message"]) == 2 + assert res["choices"][0]["message"]["content"] == "Hello!" assert res.get("state") # When making a second call with the returned state, the conversations should continue # and we should get the "Hello again!" message. + # For Colang 2.x, we only send the new user message, not the conversation history + # since the state maintains the conversation context. response = client.post( "/v1/chat/completions", json={ @@ -57,7 +60,7 @@ def _test_call(config_id): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" def test_1(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7f59a7d1..fa7ffaa49 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -815,3 +815,62 @@ def test_main_llm_supports_streaming_flag_disabled_when_no_streaming(): assert rails.main_llm_supports_streaming is False, ( "main_llm_supports_streaming should be False when streaming is disabled" ) + + +def test_main_llm_supports_streaming_with_multiple_model_types( + custom_streaming_providers, +): + """Test that streaming is properly configured when config has multiple model types.""" + config = RailsConfig.from_content( + config={ + "models": [ + { + "type": "main", + "engine": "custom_streaming", + "model": "test-model", + }, + { + "type": "content_safety", + "engine": "custom_streaming", + "model": "safety-model", + }, + ], + "streaming": True, + } + ) + + rails = LLMRails(config) + + assert rails.main_llm_supports_streaming is True, ( + "main_llm_supports_streaming should be True when streaming is enabled " + "and config has multiple model types including a streaming-capable main LLM" + ) + # Verify the main LLM's streaming attribute was set + assert hasattr(rails.llm, "streaming") and rails.llm.streaming is True, ( + "Main LLM's streaming attribute should be set to True" + ) + + +def test_main_llm_supports_streaming_with_specialized_models_only( + custom_streaming_providers, +): + """Test streaming config when only specialized models are defined (no main).""" + config = RailsConfig.from_content( + config={ + "models": [ + { + "type": "content_safety", + "engine": "custom_streaming", + "model": "safety-model", + }, + ], + "streaming": True, + } + ) + + rails = LLMRails(config) + + # Verify that main_llm_supports_streaming is False when no main LLM is configured + assert rails.main_llm_supports_streaming is False, ( + "main_llm_supports_streaming should be False when no 
main LLM is configured" + ) diff --git a/tests/test_threads.py b/tests/test_threads.py index 88946007b..baace32b7 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -51,8 +51,9 @@ def test_1(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + assert "choices" in res + assert "message" in res["choices"][0] + assert res["choices"][0]["message"]["content"] == "Hello!" # When making a second call with the same thread_id, the conversations should continue # and we should get the "Hello again!" message. @@ -70,7 +71,7 @@ def test_1(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" @pytest.mark.parametrize( @@ -138,4 +139,4 @@ def test_with_redis(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!"
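A minimal sketch of the client-side usage this change enables, assuming a Guardrails server running locally on port 8000 (for example via `nemoguardrails server`) and a guardrails configuration named `config_1`; the configuration name is illustrative, and any entry returned by `/v1/models` works:

```python
from openai import OpenAI

# Point the standard OpenAI client at the NeMo Guardrails server.
client = OpenAI(api_key="dummy-key", base_url="http://localhost:8000/v1")

# Guardrails configurations are listed as models.
for model in client.models.list().data:
    print(model.id)

# Non-streaming chat completion; `model` selects the guardrails configuration
# (it maps to `config_id` on the server).
response = client.chat.completions.create(
    model="config_1",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)

# Streaming chat completion: chunks arrive as `chat.completion.chunk` deltas,
# terminated by a `[DONE]` event that the client library consumes.
stream = client.chat.completions.create(
    model="config_1",
    messages=[{"role": "user", "content": "Tell me a short joke."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
print()
```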