diff --git a/examples/openai_client.py b/examples/openai_client.py new file mode 100644 index 0000000..cb1ba32 --- /dev/null +++ b/examples/openai_client.py @@ -0,0 +1,45 @@ +import json +import os +from openai import OpenAI +import dotenv +from javelin_sdk import ( + JavelinClient, + JavelinConfig, +) + +javelin_api_key = os.getenv('JAVELIN_API_KEY') +llm_api_key = os.getenv("OPENAI_API_KEY") + +# Create OpenAI client +openai_client = OpenAI(api_key=llm_api_key) + +# Initialize Javelin Client +config = JavelinConfig( + # base_url="https://api-dev.javelin.live", + base_url="http://localhost:8000", + javelin_api_key=javelin_api_key, +) +client = JavelinClient(config) +client.register_openai(openai_client, route_name="openai") + +# Call OpenAI endpoints +chat_completions = openai_client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "What is machine learning?"}], +) +print(chat_completions.model_dump_json(indent=2)) + +completions = openai_client.completions.create( + model="gpt-3.5-turbo-instruct", + prompt="What is machine learning?", + max_tokens=7, + temperature=0 +) +print(completions.model_dump_json(indent=2)) + +embeddings = openai_client.embeddings.create( + model="text-embedding-ada-002", + input="The food was delicious and the waiter...", + encoding_format="float" +) +print(embeddings.model_dump_json(indent=2)) diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index bb4bfc3..4d6e733 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -28,6 +28,7 @@ class JavelinClient: def __init__(self, config: JavelinConfig) -> None: self.config = config self.base_url = urljoin(config.base_url, config.api_version or "/v1") + self._headers = { "x-api-key": config.javelin_api_key, } @@ -42,6 +43,7 @@ def __init__(self, config: JavelinConfig) -> None: self.bedrock_session = None self.default_bedrock_route = None self.use_default_bedrock_route = False + self.openai_base_url = None self.gateway_service = GatewayService(self) self.provider_service = ProviderService(self) @@ -90,6 +92,116 @@ async def aclose(self): def close(self): if self._client: self._client.close() + + def register_openai(self, + openai_client: Any, + route_name: str = None) -> Any: + """ + Register the passed-in OpenAI client so that calls to: + - client.chat.completions.create(...) + - client.completions.create(...) + - client.embeddings.create(...) + are intercepted by Javelin. + + Additionally sets: + - openai_client.base_url to self.base_url + - openai_client._custom_headers to include self._headers + """ + + # Store the OpenAI base URL + if self.openai_base_url is None: + self.openai_base_url = openai_client.base_url + + # Point the OpenAI client to Javelin's base URL + openai_client.base_url=self.base_url + + if not hasattr(openai_client, "_custom_headers"): + openai_client._custom_headers = {} + openai_client._custom_headers.update(self._headers) + + base_url_str = str(self.openai_base_url) + # Remove trailing slash if present + if base_url_str.endswith("/"): + base_url_str = base_url_str[:-1] + + # Update Javelin headers into the client's _custom_headers + openai_client._custom_headers["x-javelin-provider"] = base_url_str + openai_client._custom_headers["x-javelin-route"] = route_name + + # Print out the headers you’ve set (for debug) + print("DEBUG - Patched OpenAI client headers:", openai_client._custom_headers) + + # Store references to the original methods + original_chat_completions_create = openai_client.chat.completions.create + original_completions_create = openai_client.completions.create + original_embeddings_create = openai_client.embeddings.create + + # Define patched versions, injecting Javelin logs/traces + + def patched_chat_completions_create(*args, **kwargs): + # BEFORE calling original + ''' + TODO: self.trace_service.log_trace( + message="OpenAI chat.completions.create called", + extra={"args": args, "kwargs": kwargs}, + ) + ''' + + # Call the real method + response = original_chat_completions_create(*args, **kwargs) + + # AFTER calling original + ''' + TODO: self.trace_service.log_trace( + message="OpenAI chat.completions.create response", + extra={"response": response}, + ) + ''' + return response + + def patched_completions_create(*args, **kwargs): + ''' + TODO: self.trace_service.log_trace( + message="OpenAI completions.create called", + extra={"args": args, "kwargs": kwargs}, + ) + ''' + + response = original_completions_create(*args, **kwargs) + + ''' + TODO: self.trace_service.log_trace( + message="OpenAI completions.create response", + extra={"response": response}, + ) + ''' + return response + + def patched_embeddings_create(*args, **kwargs): + ''' + TODO: self.trace_service.log_trace( + message="OpenAI embeddings.create called", + extra={"args": args, "kwargs": kwargs}, + ) + ''' + + response = original_embeddings_create(*args, **kwargs) + + ''' + TODO: self.trace_service.log_trace( + message="OpenAI embeddings.create response", + extra={"response": response}, + ) + ''' + return response + + # patch the client’s methods + openai_client.chat.completions.create = patched_chat_completions_create + openai_client.completions.create = patched_completions_create + openai_client.embeddings.create = patched_embeddings_create + + # Return the patched client + return openai_client def register_bedrock(self, bedrock_runtime_client: Any,