From e3628b57c506fa172d698a051370d7b3baeec73a Mon Sep 17 00:00:00 2001 From: kamalgovindgj Date: Tue, 7 Jan 2025 23:09:40 -0500 Subject: [PATCH 01/11] Add AISPM service and configuration updates --- javelin_cli/_internal/commands.py | 120 ++++++++++++++++++++ javelin_cli/cli.py | 97 +++++++++++++++++ javelin_sdk/client.py | 59 +++++++--- javelin_sdk/models.py | 130 ++++++++++++++++++++++ javelin_sdk/services/aispm_service.py | 151 ++++++++++++++++++++++++++ 5 files changed, 543 insertions(+), 14 deletions(-) create mode 100644 javelin_sdk/services/aispm_service.py diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 049082f..01cd768 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -1,5 +1,7 @@ import json import os +import asyncio + from pathlib import Path from pydantic import ValidationError @@ -28,6 +30,7 @@ Secrets, Template, Templates, + Customer, AWSConfig, AzureConfig, UsageResponse, AlertResponse,Request,HttpMethod, ) @@ -87,6 +90,123 @@ def get_javelin_client(): return JavelinClient(config) +#aispm commands +def create_customer(args): + try: + client = get_javelin_client() + customer = Customer( + name=args.name, + description=args.description, + metrics_interval=args.metrics_interval, + security_interval=args.security_interval + ) + response = client._send_request_sync(Request( + method=HttpMethod.POST, + route="v1/admin/aispm/customer", + data=customer.dict() + )) + print(f"Customer '{args.name}' created successfully.") + except Exception as e: + print(f"Error creating customer: {e}") + +def get_customer(args): + try: + client = get_javelin_client() + route = "v1/admin/aispm/customer" + print(f"Making request to: {client.config.base_url}{route}") + print(f"Headers: {client._headers}") # Access _headers directly + + response = client._send_request_sync(Request( + method=HttpMethod.GET, + route=route + )) + print(f"Status: {response.status_code}") + print(f"Response: {response.content.decode('utf-8')}") + except Exception as e: + print(f"Error getting customer: {e}") + +def configure_aws(args): + try: + client = get_javelin_client() + config = json.loads(args.config) + configs = [AWSConfig(**config)] + result = client.aispm.configure_aws(configs) + print(f"AWS configuration created successfully.") + except Exception as e: + print(f"Error configuring AWS: {e}") + +def get_aws_config(args): + try: + client = get_javelin_client() + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/aws" + ) + response = client._send_request_sync(request) + print(json.dumps(response.json(), indent=2)) + except Exception as e: + print(f"Error getting AWS config: {e}") + +def delete_aws_config(args): + try: + client = get_javelin_client() + request = Request( + method=HttpMethod.DELETE, + route=f"v1/admin/aispm/config/aws/{args.name}" + ) + response = client._send_request_sync(request) + print(f"AWS configuration '{args.name}' deleted successfully.") + except Exception as e: + print(f"Error deleting AWS config: {e}") + +def get_azure_config(args): + try: + client = get_javelin_client() + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/azure" + ) + response = client._send_request_sync(request) + print(json.dumps(response.json(), indent=2)) + except Exception as e: + print(f"Error getting Azure config: {e}") + +def configure_azure(args): + try: + client = get_javelin_client() + config = json.loads(args.config) + configs = [AzureConfig(**config)] + result = client.aispm.configure_azure(configs) + print(f"Azure configuration created successfully.") + except Exception as e: + print(f"Error configuring Azure: {e}") + +def get_usage(args): + try: + client = get_javelin_client() + usage = client.aispm.get_usage( + provider=args.provider, + cloud_account=args.account, + model=args.model, + region=args.region + ) + print(json.dumps(usage.dict(), indent=2)) + except Exception as e: + print(f"Error getting usage: {e}") + +def get_alerts(args): + try: + client = get_javelin_client() + alerts = client.aispm.get_alerts( + provider=args.provider, + cloud_account=args.account, + model=args.model, + region=args.region + ) + print(json.dumps(alerts.dict(), indent=2)) + except Exception as e: + print(f"Error getting alerts: {e}") + def create_gateway(args): try: client = get_javelin_client() diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index 4481f76..dff023e 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -38,9 +38,30 @@ update_route, update_secret, update_template, + create_customer, + get_customer, + configure_aws, + configure_azure, + get_usage, + get_alerts, + get_aws_config, + get_azure_config, + delete_aws_config ) + +#def check_permissions(): +# """Check if user has permissions""" +# home_dir = Path.home() +# cache_file = home_dir / ".javelin" / "cache.json" + +# if not cache_file.exists(): +# print("❌ Not authenticated. Please run 'javelin auth' first.") +# sys.exit(1) + +# return True # Skip role check + def check_permissions(): """Check if user has superadmin permissions""" home_dir = Path.home() @@ -89,6 +110,80 @@ def main(): auth_parser = subparsers.add_parser("auth", help="Authenticate with Javelin.") auth_parser.add_argument("--force", action="store_true", help="Force re-authentication, overriding existing credentials") auth_parser.set_defaults(func=authenticate) + #aispm CRUD + # AISPM commands + aispm_parser = subparsers.add_parser("aispm", help="Manage AISPM functionality") + aispm_subparsers = aispm_parser.add_subparsers() + + # Customer commands + customer_parser = aispm_subparsers.add_parser("customer", help="Manage customers") + customer_subparsers = customer_parser.add_subparsers() + + customer_create = customer_subparsers.add_parser("create", help="Create customer") + customer_create.add_argument("--name", required=True, help="Customer name") + customer_create.add_argument("--description", help="Customer description") + customer_create.add_argument("--metrics-interval", default="5m", help="Metrics interval") + customer_create.add_argument("--security-interval", default="1m", help="Security interval") + customer_create.set_defaults(func=create_customer) + + customer_get = customer_subparsers.add_parser("get", help="Get customer details") + customer_get.set_defaults(func=get_customer) + + + # Cloud config commands + config_parser = aispm_subparsers.add_parser("config", help="Manage cloud configurations") + config_subparsers = config_parser.add_subparsers() + + aws_parser = config_subparsers.add_parser("aws", help="Configure AWS") + + azure_parser = config_subparsers.add_parser("azure", help="Configure Azure") + + #azure_parser.add_argument("--config", type=str, required=True, help="Azure config JSON") + #azure_parser.set_defaults(func=configure_azure) + + + aws_subparsers = aws_parser.add_subparsers() + + # GET AWS Config + aws_get_parser = aws_subparsers.add_parser("get", help="Get AWS configuration") + aws_get_parser.set_defaults(func=get_aws_config) + + # Existing AWS Config (for creating) + aws_config_parser = aws_subparsers.add_parser("create", help="Configure AWS") + aws_config_parser.add_argument("--config", type=str, required=True, help="AWS config JSON") + aws_config_parser.set_defaults(func=configure_aws) + + aws_delete_parser = aws_subparsers.add_parser("delete", help="Delete AWS configuration") + aws_delete_parser.add_argument("--name", type=str, required=True, help="Name of AWS configuration to delete") + aws_delete_parser.set_defaults(func=delete_aws_config) + + azure_subparsers = azure_parser.add_subparsers(dest='azure_command') + +# Get Azure Config + azure_get_parser = azure_subparsers.add_parser("get", help="Get Azure configuration") + azure_get_parser.set_defaults(func=get_azure_config) + + # Create Azure Config + azure_create_parser = azure_subparsers.add_parser("create", help="Configure Azure") + azure_create_parser.add_argument("--config", type=str, required=True, help="Azure config JSON") + azure_create_parser.set_defaults(func=configure_azure) + + + # Usage metrics + usage_parser = aispm_subparsers.add_parser("usage", help="Get usage metrics") + usage_parser.add_argument("--provider", help="Cloud provider") + usage_parser.add_argument("--account", help="Cloud account name") + usage_parser.add_argument("--model", help="Model ID") + usage_parser.add_argument("--region", help="Region") + usage_parser.set_defaults(func=get_usage) + + # Alerts + alerts_parser = aispm_subparsers.add_parser("alerts", help="Get alerts") + alerts_parser.add_argument("--provider", help="Cloud provider") + alerts_parser.add_argument("--account", help="Cloud account name") + alerts_parser.add_argument("--model", help="Model ID") + alerts_parser.add_argument("--region", help="Region") + alerts_parser.set_defaults(func=get_alerts) # Gateway CRUD gateway_parser = subparsers.add_parser( "gateway", @@ -373,6 +468,8 @@ def main(): ) template_delete.set_defaults(func=delete_template) + + args = parser.parse_args() if hasattr(args, "func"): diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index c5a6a1f..be019be 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -15,6 +15,8 @@ from javelin_sdk.services.secret_service import SecretService from javelin_sdk.services.template_service import TemplateService from javelin_sdk.services.trace_service import TraceService +from javelin_sdk.services.aispm_service import AISPMService + API_BASEURL = "https://api-dev.javelin.live" API_BASE_PATH = "/v1" @@ -45,6 +47,9 @@ def __init__(self, config: JavelinConfig) -> None: self.chat = Chat(self) self.completions = Completions(self) + self.aispm = AISPMService(self) + + @property def client(self): if self._client is None: @@ -84,24 +89,41 @@ def close(self): self._client.close() def _prepare_request(self, request: Request) -> tuple: - url = self._construct_url( - gateway_name=request.gateway, - provider_name=request.provider, - route_name=request.route, - secret_name=request.secret, - template_name=request.template, - trace=request.trace, - query=request.is_query, - archive=request.archive, - query_params=request.query_params, - is_transformation_rules=request.is_transformation_rules, - is_reload=request.is_reload, - ) + if request.route.startswith("v1/admin/aispm"): + url = f"{self.config.base_url.rstrip('/')}/{request.route}" + if request.query_params: + query_string = "&".join(f"{k}={v}" for k, v in request.query_params.items()) + url += f"?{query_string}" + + + else: + url = self._construct_url( + gateway_name=request.gateway, + provider_name=request.provider, + route_name=request.route, + secret_name=request.secret, + template_name=request.template, + trace=request.trace, + query=request.is_query, + archive=request.archive, + query_params=request.query_params, + is_transformation_rules=request.is_transformation_rules, + is_reload=request.is_reload, + ) + headers = {**self._headers, **(request.headers or {})} return url, headers def _send_request_sync(self, request: Request) -> httpx.Response: - return self._core_send_request(self.client, request) + url, headers = self._prepare_request(request) + print(f"Making request to: {url}") + print(f"With headers: {headers}") + response = self._core_send_request(self.client, request) + print(f"Response status: {response.status_code}") + print(f"Response body: {response.text}") + return response + + async def _send_request_async(self, request: Request) -> httpx.Response: return await self._core_send_request(self.aclient, request) @@ -199,6 +221,15 @@ def _construct_url( return url + def _construct_aispm_url(self, request: Request) -> str: + url = request.route + + if request.query_params: + query_string = "&".join(f"{k}={v}" for k, v in request.query_params.items()) + url += f"?{query_string}" + + return url + # Gateway methods create_gateway = lambda self, gateway: self.gateway_service.create_gateway(gateway) acreate_gateway = lambda self, gateway: self.gateway_service.acreate_gateway( diff --git a/javelin_sdk/models.py b/javelin_sdk/models.py index 80c8d7c..6a71f45 100644 --- a/javelin_sdk/models.py +++ b/javelin_sdk/models.py @@ -1,6 +1,12 @@ from enum import Enum, auto from typing import Any, Dict, List, Optional +from datetime import datetime + +from typing import Dict, List, Optional + + + from pydantic import BaseModel, Field, field_validator from javelin_sdk.exceptions import UnauthorizedError @@ -554,3 +560,127 @@ class EndpointType(str, Enum): INVOKE_STREAM = "invoke_stream" CONVERSE_STREAM = "converse_stream" ALL = "all" + + +#aispm models + +class TimeRange(BaseModel): + start_time: str # Change from datetime to str + end_time: str # Change from datetime to str + +class BaseResponse(BaseModel): + message: Optional[str] = None + +# Customer Models +class Customer(BaseModel): + name: str + description: Optional[str] + metrics_interval: str = "5m" + security_interval: str = "1m" + initial_scan: str = "24h" + +class CustomerResponse(Customer): + status: str + created_at: datetime + modified_at: datetime + +# Cloud Config Models +class BaseCloudConfig(BaseModel): + cloud_account_name: str + team: str + +class AWSConfig(BaseCloudConfig): + role_arn: str + region: Optional[str] = None # Make region optional + +class AzureConfig(BaseCloudConfig): + subscription_id: str + tenant_id: str + client_id: str + client_secret: str + location: str + +class GCPConfig(BaseCloudConfig): + project_id: str + service_account_key: str + +class CloudConfigResponse(BaseModel): + name: Optional[str] = Field(None, alias='cloud_account_name') + provider: str + status: str + created_at: datetime + modified_at: datetime + +# Usage Models +class ModelMetrics(BaseModel): + latency_avg_ms: float + cost_per_request: float + tokens_per_request: float + attempt_count: int + failure_count: int + success_count: int + success_rate_pct: float + cost_total: float + request_count: int + token_count: int + +class CloudAccountUsage(BaseModel): + region_count: int + regions: List[str] + model_count: int + models: List[str] + model_metrics: ModelMetrics + +class UsageResponse(BaseModel): + cloud_provider: Dict[str, Any] # Change to allow any structure + time_range: Optional[TimeRange] = None + +# Alert Models +class AlertSeverity(str, Enum): + CRITICAL = "CRITICAL" + HIGH = "HIGH" + MEDIUM = "MEDIUM" + LOW = "LOW" + +class AlertState(str, Enum): + ALARM = "ALARM" + OK = "OK" + INSUFFICIENT_DATA = "INSUFFICIENT_DATA" + +class AlertScope(str, Enum): + GLOBAL = "GLOBAL" + MODEL = "MODEL" + REGION = "REGION" + +class AlertMetrics(BaseModel): + total_alerts: int + active_alerts: int + resolved_alerts: int + critical_alerts: int + high_alerts: int + medium_alerts: int + low_alerts: int + +class Alert(BaseModel): + title: str + state: AlertState + state_reason: str + severity: AlertSeverity + scope: AlertScope + region: Optional[str] + model_id: Optional[str] + detected_at: datetime + +class CloudProviderAlerts(BaseModel): + cloud_account_count: int + cloud_accounts: List[str] + region_count: int + regions: List[str] + model_count: int + models: List[str] + alert_metrics: AlertMetrics + alerts: List[Alert] + +class AlertResponse(BaseModel): + cloud_provider: Dict[str, CloudProviderAlerts] + time_range: TimeRange diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py new file mode 100644 index 0000000..ae09937 --- /dev/null +++ b/javelin_sdk/services/aispm_service.py @@ -0,0 +1,151 @@ +from typing import Dict, List, Optional, Union +from httpx import Response +import json + +from javelin_sdk.models import ( + Customer, CustomerResponse, + AWSConfig, AzureConfig, GCPConfig, CloudConfigResponse, + UsageResponse, AlertResponse, TimeRange, + HttpMethod, Request +) + +class AISPMService: + def __init__(self, client): + self.client = client + + def _handle_response(self, response: Response) -> None: + if response.status_code >= 400: + error = response.json().get("error", "Unknown error") + raise Exception(f"API error: {error}") + + # Customer Methods + def create_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/customer", + data=customer.dict() + ) + print(f"Sending request: {request.method} {request.route}") + response = self.client._send_request_sync(request) + print(f"Raw response: {response.text}") + self._handle_response(response) + return CustomerResponse(**response.json()) + + def get_customer(self) -> CustomerResponse: + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/customer" + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + def update_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.PUT, + route="v1/admin/aispm/customer", + data=customer.dict() + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + # Cloud Config Methods + def configure_aws(self, configs: List[AWSConfig]) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/aws", + data=[config.dict() for config in configs] + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def configure_azure(self, configs: List[AzureConfig]) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/azure", + data=[config.dict() for config in configs] + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def configure_gcp(self, configs: List[GCPConfig]) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/gcp", + data=[config.dict() for config in configs] + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + # Usage Methods + def get_usage(self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None) -> UsageResponse: + + route = "v1/admin/aispm/usage" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return UsageResponse(**response.json()) + + # Alert Methods + def get_alerts(self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None) -> AlertResponse: + + route = "v1/admin/aispm/alerts" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return AlertResponse(**response.json()) + + # Helpers + def _validate_provider(self, provider: str) -> None: + valid_providers = ["aws", "azure", "gcp", "openai"] + if provider.lower() not in valid_providers: + raise ValueError(f"Invalid provider. Must be one of: {valid_providers}") + + def _construct_error(self, response: Response) -> Dict: + try: + error = response.json() + return error.get("error", str(response.content)) + except json.JSONDecodeError: + return str(response.content) \ No newline at end of file From 14ca66c49eebd6fc0682a7f8412fdb5cf2a703a8 Mon Sep 17 00:00:00 2001 From: kamalgovindgj Date: Mon, 13 Jan 2025 03:19:34 -0500 Subject: [PATCH 02/11] feat:supporting aispm in CLI --- javelin_cli/_internal/commands.py | 60 ++++++++++++++++++++++++++----- javelin_cli/cli.py | 18 +++++----- javelin_sdk/client.py | 8 ++--- 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 01cd768..49b9130 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -34,6 +34,48 @@ ) +def get_javelin_client_aispm(): + # Path to cache.json file + home_dir = Path.home() + json_file_path = home_dir / ".javelin" / "cache.json" + + # Load cache.json + if not json_file_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {json_file_path}") + + with open(json_file_path, "r") as json_file: + cache_data = json.load(json_file) + + # Retrieve the list of gateways + gateways = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}).get("public_metadata", {}).get("Gateways", []) + if not gateways: + raise ValueError("No gateways found in the configuration.") + + # Automatically select the first gateway (index 0) + selected_gateway = gateways[0] + base_url = selected_gateway["base_url"] + javelin_api_key = selected_gateway["api_key_value"] + + # Ensure the API key is set before initializing + if not javelin_api_key or javelin_api_key == "": + raise UnauthorizedError( + response=None, + message=( + "Please provide a valid Javelin API Key. " + "When you sign into Javelin, you can find your API Key in the " + "Account->Developer settings" + ), + ) + + # Initialize and return the JavelinClient + config = JavelinConfig( + base_url=base_url, + javelin_api_key=javelin_api_key, + ) + + return JavelinClient(config) + + def get_javelin_client(): # Path to cache.json file home_dir = Path.home() @@ -93,7 +135,7 @@ def get_javelin_client(): #aispm commands def create_customer(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() customer = Customer( name=args.name, description=args.description, @@ -111,7 +153,7 @@ def create_customer(args): def get_customer(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() route = "v1/admin/aispm/customer" print(f"Making request to: {client.config.base_url}{route}") print(f"Headers: {client._headers}") # Access _headers directly @@ -127,7 +169,7 @@ def get_customer(args): def configure_aws(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() config = json.loads(args.config) configs = [AWSConfig(**config)] result = client.aispm.configure_aws(configs) @@ -137,7 +179,7 @@ def configure_aws(args): def get_aws_config(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() request = Request( method=HttpMethod.GET, route="v1/admin/aispm/config/aws" @@ -149,7 +191,7 @@ def get_aws_config(args): def delete_aws_config(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() request = Request( method=HttpMethod.DELETE, route=f"v1/admin/aispm/config/aws/{args.name}" @@ -161,7 +203,7 @@ def delete_aws_config(args): def get_azure_config(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() request = Request( method=HttpMethod.GET, route="v1/admin/aispm/config/azure" @@ -173,7 +215,7 @@ def get_azure_config(args): def configure_azure(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() config = json.loads(args.config) configs = [AzureConfig(**config)] result = client.aispm.configure_azure(configs) @@ -183,7 +225,7 @@ def configure_azure(args): def get_usage(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() usage = client.aispm.get_usage( provider=args.provider, cloud_account=args.account, @@ -196,7 +238,7 @@ def get_usage(args): def get_alerts(args): try: - client = get_javelin_client() + client = get_javelin_client_aispm() alerts = client.aispm.get_alerts( provider=args.provider, cloud_account=args.account, diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index dff023e..94f8778 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -51,18 +51,18 @@ -#def check_permissions(): -# """Check if user has permissions""" -# home_dir = Path.home() -# cache_file = home_dir / ".javelin" / "cache.json" +def check_permissions(): + """Check if user has permissions""" + home_dir = Path.home() + cache_file = home_dir / ".javelin" / "cache.json" -# if not cache_file.exists(): -# print("❌ Not authenticated. Please run 'javelin auth' first.") -# sys.exit(1) + if not cache_file.exists(): + print("❌ Not authenticated. Please run 'javelin auth' first.") + sys.exit(1) -# return True # Skip role check + return True # Skip role check -def check_permissions(): +#def check_permissions(): """Check if user has superadmin permissions""" home_dir = Path.home() cache_file = home_dir / ".javelin" / "cache.json" diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index be019be..fb02962 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -116,11 +116,11 @@ def _prepare_request(self, request: Request) -> tuple: def _send_request_sync(self, request: Request) -> httpx.Response: url, headers = self._prepare_request(request) - print(f"Making request to: {url}") - print(f"With headers: {headers}") + #print(f"Making request to: {url}") + #print(f"With headers: {headers}") response = self._core_send_request(self.client, request) - print(f"Response status: {response.status_code}") - print(f"Response body: {response.text}") + #print(f"Response status: {response.status_code}") + #print(f"Response body: {response.text}") return response From fa2aaa4851dc61af2b5bbacd5d27ed3f25958059 Mon Sep 17 00:00:00 2001 From: kamalgovindgj Date: Tue, 14 Jan 2025 20:48:21 -0500 Subject: [PATCH 03/11] feat: asipm support for cli --- javelin_cli/_internal/commands.py | 69 ++++++++++++++------------- javelin_sdk/services/aispm_service.py | 16 ++++++- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 49b9130..e6b8f65 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -132,41 +132,43 @@ def get_javelin_client(): return JavelinClient(config) -#aispm commands def create_customer(args): - try: - client = get_javelin_client_aispm() - customer = Customer( - name=args.name, - description=args.description, - metrics_interval=args.metrics_interval, - security_interval=args.security_interval - ) - response = client._send_request_sync(Request( - method=HttpMethod.POST, - route="v1/admin/aispm/customer", - data=customer.dict() - )) - print(f"Customer '{args.name}' created successfully.") - except Exception as e: - print(f"Error creating customer: {e}") + client = get_javelin_client_aispm() + customer = Customer( + name=args.name, + description=args.description, + metrics_interval=args.metrics_interval, + security_interval=args.security_interval + ) + return client.aispm.create_customer(customer) + + + def get_customer(args): + """ + Gets customer details using the AISPM service. + """ try: client = get_javelin_client_aispm() - route = "v1/admin/aispm/customer" - print(f"Making request to: {client.config.base_url}{route}") - print(f"Headers: {client._headers}") # Access _headers directly + response = client.aispm.get_customer() - response = client._send_request_sync(Request( - method=HttpMethod.GET, - route=route - )) - print(f"Status: {response.status_code}") - print(f"Response: {response.content.decode('utf-8')}") + # Pretty print the response for CLI output + formatted_response = { + "name": response.name, + "description": response.description, + "metrics_interval": response.metrics_interval, + "security_interval": response.security_interval, + "status": response.status, + "created_at": response.created_at.isoformat(), + "modified_at": response.modified_at.isoformat() + } + + print(json.dumps(formatted_response, indent=2)) except Exception as e: print(f"Error getting customer: {e}") + def configure_aws(args): try: client = get_javelin_client_aispm() @@ -178,16 +180,17 @@ def configure_aws(args): print(f"Error configuring AWS: {e}") def get_aws_config(args): + """ + Gets AWS configurations using the AISPM service. + """ try: client = get_javelin_client_aispm() - request = Request( - method=HttpMethod.GET, - route="v1/admin/aispm/config/aws" - ) - response = client._send_request_sync(request) - print(json.dumps(response.json(), indent=2)) + response = client.aispm.get_aws_configs() + # Simply print the JSON response + print(json.dumps(response, indent=2)) + except Exception as e: - print(f"Error getting AWS config: {e}") + print(f"Error getting AWS configurations: {e}") def delete_aws_config(args): try: diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py index ae09937..bfed398 100644 --- a/javelin_sdk/services/aispm_service.py +++ b/javelin_sdk/services/aispm_service.py @@ -30,6 +30,8 @@ def create_customer(self, customer: Customer) -> CustomerResponse: print(f"Raw response: {response.text}") self._handle_response(response) return CustomerResponse(**response.json()) + + def get_customer(self) -> CustomerResponse: request = Request( @@ -71,6 +73,18 @@ def configure_azure(self, configs: List[AzureConfig]) -> List[CloudConfigRespons self._handle_response(response) return [CloudConfigResponse(**config) for config in response.json()] + def get_aws_configs(self) -> Dict: + """ + Retrieves AWS configurations. + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/aws" + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() + def configure_gcp(self, configs: List[GCPConfig]) -> List[CloudConfigResponse]: request = Request( method=HttpMethod.POST, @@ -80,7 +94,7 @@ def configure_gcp(self, configs: List[GCPConfig]) -> List[CloudConfigResponse]: response = self.client._send_request_sync(request) self._handle_response(response) return [CloudConfigResponse(**config) for config in response.json()] - + # Usage Methods def get_usage(self, provider: Optional[str] = None, From 51a33e9f6d06410231446fa7c54123470669e21c Mon Sep 17 00:00:00 2001 From: kamalgovindgj Date: Mon, 20 Jan 2025 11:21:30 -0500 Subject: [PATCH 04/11] fix: improved AISPM service URL handling and commands --- javelin_cli/_internal/commands.py | 34 +++++++++-------------- javelin_sdk/client.py | 5 ---- javelin_sdk/services/aispm_service.py | 40 ++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index e6b8f65..7ed6243 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -57,16 +57,7 @@ def get_javelin_client_aispm(): javelin_api_key = selected_gateway["api_key_value"] # Ensure the API key is set before initializing - if not javelin_api_key or javelin_api_key == "": - raise UnauthorizedError( - response=None, - message=( - "Please provide a valid Javelin API Key. " - "When you sign into Javelin, you can find your API Key in the " - "Account->Developer settings" - ), - ) - + # Initialize and return the JavelinClient config = JavelinConfig( base_url=base_url, @@ -192,27 +183,28 @@ def get_aws_config(args): except Exception as e: print(f"Error getting AWS configurations: {e}") +# Add these functions to commands.py + def delete_aws_config(args): + """ + Deletes an AWS configuration. + """ try: client = get_javelin_client_aispm() - request = Request( - method=HttpMethod.DELETE, - route=f"v1/admin/aispm/config/aws/{args.name}" - ) - response = client._send_request_sync(request) + client.aispm.delete_aws_config(args.name) print(f"AWS configuration '{args.name}' deleted successfully.") except Exception as e: print(f"Error deleting AWS config: {e}") def get_azure_config(args): + """ + Gets Azure configurations using the AISPM service. + """ try: client = get_javelin_client_aispm() - request = Request( - method=HttpMethod.GET, - route="v1/admin/aispm/config/azure" - ) - response = client._send_request_sync(request) - print(json.dumps(response.json(), indent=2)) + response = client.aispm.get_azure_config() + # Format and print the response nicely + print(json.dumps(response, indent=2)) except Exception as e: print(f"Error getting Azure config: {e}") diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index fb02962..8ee0db3 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -115,12 +115,7 @@ def _prepare_request(self, request: Request) -> tuple: return url, headers def _send_request_sync(self, request: Request) -> httpx.Response: - url, headers = self._prepare_request(request) - #print(f"Making request to: {url}") - #print(f"With headers: {headers}") response = self._core_send_request(self.client, request) - #print(f"Response status: {response.status_code}") - #print(f"Response body: {response.text}") return response diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py index bfed398..d513583 100644 --- a/javelin_sdk/services/aispm_service.py +++ b/javelin_sdk/services/aispm_service.py @@ -73,6 +73,9 @@ def configure_azure(self, configs: List[AzureConfig]) -> List[CloudConfigRespons self._handle_response(response) return [CloudConfigResponse(**config) for config in response.json()] + + + def get_aws_configs(self) -> Dict: """ Retrieves AWS configurations. @@ -162,4 +165,39 @@ def _construct_error(self, response: Response) -> Dict: error = response.json() return error.get("error", str(response.content)) except json.JSONDecodeError: - return str(response.content) \ No newline at end of file + return str(response.content) + + def delete_aws_config(self, name: str) -> None: + """ + Deletes an AWS configuration by name. + + Args: + name (str): The name of the AWS configuration to delete + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.DELETE, + route=f"v1/admin/aispm/config/aws/{name}" + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + + def get_azure_config(self) -> Dict: + """ + Retrieves Azure configurations. + + Returns: + Dict: The Azure configuration data + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/azure" + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() \ No newline at end of file From 91ab62ae1dbece231b3bef16b5d90a54dbdaeae4 Mon Sep 17 00:00:00 2001 From: Darshana Date: Wed, 17 Dec 2025 20:43:21 +0530 Subject: [PATCH 05/11] feat: fix AISPM auth and update dev url (#230) --- javelin_cli/_internal/commands.py | 30 +++++++++++++++---- javelin_cli/cli.py | 2 +- javelin_sdk/client.py | 14 +++++++-- javelin_sdk/services/aispm_service.py | 43 +++++++++++++++++++-------- pyproject.toml | 6 ++-- 5 files changed, 71 insertions(+), 24 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 7ed6243..7d510eb 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -54,17 +54,37 @@ def get_javelin_client_aispm(): # Automatically select the first gateway (index 0) selected_gateway = gateways[0] base_url = selected_gateway["base_url"] - javelin_api_key = selected_gateway["api_key_value"] - - # Ensure the API key is set before initializing + + # Get account_id from role_arn or account_id field + account_id = selected_gateway.get("account_id") + role_arn = selected_gateway.get("role_arn") + + # Extract account_id from role ARN if provided (format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME) + if role_arn and not account_id: + try: + parts = role_arn.split(":") + if len(parts) >= 5 and parts[2] == "iam": + account_id = parts[4] + except (IndexError, AttributeError): + pass + + javelin_api_key = selected_gateway.get("api_key_value", "placeholder") # Initialize and return the JavelinClient config = JavelinConfig( base_url=base_url, javelin_api_key=javelin_api_key, ) - - return JavelinClient(config) + + client = JavelinClient(config) + + # Store account_id in client for AISPM service to use + if account_id: + client._aispm_account_id = account_id + client._aispm_user = "test-user" + client._aispm_userrole = "org:superadmin" + + return client def get_javelin_client(): diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index 94f8778..047e519 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -491,7 +491,7 @@ def authenticate(args): print("Use --force to re-authenticate and override existing cache.") return - default_url = "https://dev.javelin.live/" + default_url = "https://dev.highflame.dev/" print(" O") print(" /|\\") diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index bf6a9c4..75a8240 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -53,9 +53,10 @@ def __init__(self, config: JavelinConfig) -> None: @property def client(self): if self._client is None: + # Don't set headers at client level - they'll be added per-request + # This allows us to exclude x-api-key for AISPM requests self._client = httpx.Client( base_url=self.base_url, - headers=self._headers, timeout= self.config.timeout if self.config.timeout else API_TIMEOUT, ) return self._client @@ -63,8 +64,9 @@ def client(self): @property def aclient(self): if self._aclient is None: + # Don't set headers at client level - they'll be added per-request self._aclient = httpx.AsyncClient( - base_url=self.base_url, headers=self._headers, timeout=API_TIMEOUT + base_url=self.base_url, timeout=API_TIMEOUT ) return self._aclient @@ -112,6 +114,12 @@ def _prepare_request(self, request: Request) -> tuple: ) headers = {**self._headers, **(request.headers or {})} + + # For AISPM requests: if account_id header is present, remove x-api-key + # AISPM uses account_id-based authentication instead of API key + if request.route.startswith("v1/admin/aispm") and "x-javelin-accountid" in headers: + headers.pop("x-api-key", None) + return url, headers def _send_request_sync(self, request: Request) -> httpx.Response: @@ -128,6 +136,8 @@ def _core_send_request( ) -> Union[httpx.Response, Coroutine[Any, Any, httpx.Response]]: url, headers = self._prepare_request(request) + # For httpx.Client, headers passed to request methods override client-level headers + # So we need to ensure we're passing the correct headers if request.method == HttpMethod.GET: return client.get(url, headers=headers) elif request.method == HttpMethod.POST: diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py index d513583..b874a52 100644 --- a/javelin_sdk/services/aispm_service.py +++ b/javelin_sdk/services/aispm_service.py @@ -13,6 +13,17 @@ class AISPMService: def __init__(self, client): self.client = client + def _get_aispm_headers(self) -> Dict[str, str]: + """Get headers for AISPM requests, including account_id if available.""" + headers = {} + # Check if account_id is stored in client (set by get_javelin_client_aispm) + account_id = getattr(self.client, '_aispm_account_id', None) + if account_id: + headers["x-javelin-accountid"] = account_id + headers["x-javelin-user"] = getattr(self.client, '_aispm_user', "test-user") + headers["x-javelin-userrole"] = getattr(self.client, '_aispm_userrole', "org:superadmin") + return headers + def _handle_response(self, response: Response) -> None: if response.status_code >= 400: error = response.json().get("error", "Unknown error") @@ -23,11 +34,10 @@ def create_customer(self, customer: Customer) -> CustomerResponse: request = Request( method=HttpMethod.POST, route="v1/admin/aispm/customer", - data=customer.dict() + data=customer.dict(), + headers=self._get_aispm_headers() ) - print(f"Sending request: {request.method} {request.route}") response = self.client._send_request_sync(request) - print(f"Raw response: {response.text}") self._handle_response(response) return CustomerResponse(**response.json()) @@ -36,7 +46,8 @@ def create_customer(self, customer: Customer) -> CustomerResponse: def get_customer(self) -> CustomerResponse: request = Request( method=HttpMethod.GET, - route="v1/admin/aispm/customer" + route="v1/admin/aispm/customer", + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -57,7 +68,8 @@ def configure_aws(self, configs: List[AWSConfig]) -> List[CloudConfigResponse]: request = Request( method=HttpMethod.POST, route="v1/admin/aispm/config/aws", - data=[config.dict() for config in configs] + data=[config.dict() for config in configs], + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -67,7 +79,8 @@ def configure_azure(self, configs: List[AzureConfig]) -> List[CloudConfigRespons request = Request( method=HttpMethod.POST, route="v1/admin/aispm/config/azure", - data=[config.dict() for config in configs] + data=[config.dict() for config in configs], + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -82,7 +95,8 @@ def get_aws_configs(self) -> Dict: """ request = Request( method=HttpMethod.GET, - route="v1/admin/aispm/config/aws" + route="v1/admin/aispm/config/aws", + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -92,7 +106,8 @@ def configure_gcp(self, configs: List[GCPConfig]) -> List[CloudConfigResponse]: request = Request( method=HttpMethod.POST, route="v1/admin/aispm/config/gcp", - data=[config.dict() for config in configs] + data=[config.dict() for config in configs], + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -120,7 +135,8 @@ def get_usage(self, request = Request( method=HttpMethod.GET, route=route, - query_params=params + query_params=params, + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -148,7 +164,8 @@ def get_alerts(self, request = Request( method=HttpMethod.GET, route=route, - query_params=params + query_params=params, + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -179,7 +196,8 @@ def delete_aws_config(self, name: str) -> None: """ request = Request( method=HttpMethod.DELETE, - route=f"v1/admin/aispm/config/aws/{name}" + route=f"v1/admin/aispm/config/aws/{name}", + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) @@ -196,7 +214,8 @@ def get_azure_config(self) -> Dict: """ request = Request( method=HttpMethod.GET, - route="v1/admin/aispm/config/azure" + route="v1/admin/aispm/config/azure", + headers=self._get_aispm_headers() ) response = self.client._send_request_sync(request) self._handle_response(response) diff --git a/pyproject.toml b/pyproject.toml index 3fa7ff8..3ec080d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,11 @@ -[project.urls] -"Homepage" = "https://getjavelin.io" - [tool.poetry] name = "javelin-sdk" -version = "RELEASE_VERSION" +version = "0.2.7-dev" description = "Python client for Javelin" authors = ["Sharath Rajasekar "] readme = "README.md" license = "Apache-2.0" +homepage = "https://dev.highflame.dev/sign-in" packages = [ { include = "javelin_cli" }, { include = "javelin_sdk" }, From 6f17fa27b80df213d83293154330e1eddf1340bb Mon Sep 17 00:00:00 2001 From: Darshana Date: Wed, 17 Dec 2025 21:00:18 +0530 Subject: [PATCH 06/11] fix: extract account_id from org metadata for AISPM auth --- javelin_cli/_internal/commands.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 7d510eb..1a3c415 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -55,11 +55,21 @@ def get_javelin_client_aispm(): selected_gateway = gateways[0] base_url = selected_gateway["base_url"] - # Get account_id from role_arn or account_id field + # Get organization metadata (where account_id might be stored) + organization = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}) + org_metadata = organization.get("public_metadata", {}) + + # Get account_id from multiple possible locations (in order of preference): + # 1. Gateway's account_id field + # 2. Organization's public_metadata account_id + # 3. Extract from role_arn if provided account_id = selected_gateway.get("account_id") + if not account_id: + account_id = org_metadata.get("account_id") + role_arn = selected_gateway.get("role_arn") - # Extract account_id from role ARN if provided (format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME) + # Extract account_id from role ARN if still not found (format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME) if role_arn and not account_id: try: parts = role_arn.split(":") From 284c0fa0b92c0997cbb0cd172a4b7ed8e872e8b7 Mon Sep 17 00:00:00 2001 From: Darshana Date: Wed, 17 Dec 2025 21:00:49 +0530 Subject: [PATCH 07/11] fix: improve error handling for AISPM API responses --- javelin_sdk/services/aispm_service.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py index b874a52..05ed670 100644 --- a/javelin_sdk/services/aispm_service.py +++ b/javelin_sdk/services/aispm_service.py @@ -26,7 +26,12 @@ def _get_aispm_headers(self) -> Dict[str, str]: def _handle_response(self, response: Response) -> None: if response.status_code >= 400: - error = response.json().get("error", "Unknown error") + try: + error_data = response.json() + # Handle different error response formats + error = error_data.get("error") or error_data.get("message") or str(error_data) + except: + error = f"HTTP {response.status_code}: {response.text}" raise Exception(f"API error: {error}") # Customer Methods @@ -51,7 +56,12 @@ def get_customer(self) -> CustomerResponse: ) response = self.client._send_request_sync(request) self._handle_response(response) - return CustomerResponse(**response.json()) + response_data = response.json() + # Check if response indicates failure (even with 200 status) + if isinstance(response_data, dict) and response_data.get("success") is False: + error_msg = response_data.get("message") or response_data.get("error") or "Request failed" + raise Exception(f"API error: {error_msg}") + return CustomerResponse(**response_data) def update_customer(self, customer: Customer) -> CustomerResponse: request = Request( From 079e636d147cdb3d249ee9a7f1470c20d474af4b Mon Sep 17 00:00:00 2001 From: Darshana Date: Wed, 17 Dec 2025 21:28:50 +0530 Subject: [PATCH 08/11] fix: reformat javelin_cli/_internal/commands.py using black cmd --- javelin_cli/_internal/commands.py | 74 ++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 1a3c415..ac47d55 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -30,7 +30,13 @@ Secrets, Template, Templates, - Customer, AWSConfig, AzureConfig, UsageResponse, AlertResponse,Request,HttpMethod, + Customer, + AWSConfig, + AzureConfig, + UsageResponse, + AlertResponse, + Request, + HttpMethod, ) @@ -47,18 +53,26 @@ def get_javelin_client_aispm(): cache_data = json.load(json_file) # Retrieve the list of gateways - gateways = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}).get("public_metadata", {}).get("Gateways", []) + gateways = ( + cache_data.get("memberships", {}) + .get("data", [{}])[0] + .get("organization", {}) + .get("public_metadata", {}) + .get("Gateways", []) + ) if not gateways: raise ValueError("No gateways found in the configuration.") # Automatically select the first gateway (index 0) selected_gateway = gateways[0] base_url = selected_gateway["base_url"] - + # Get organization metadata (where account_id might be stored) - organization = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}) + organization = ( + cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}) + ) org_metadata = organization.get("public_metadata", {}) - + # Get account_id from multiple possible locations (in order of preference): # 1. Gateway's account_id field # 2. Organization's public_metadata account_id @@ -66,9 +80,9 @@ def get_javelin_client_aispm(): account_id = selected_gateway.get("account_id") if not account_id: account_id = org_metadata.get("account_id") - + role_arn = selected_gateway.get("role_arn") - + # Extract account_id from role ARN if still not found (format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME) if role_arn and not account_id: try: @@ -77,17 +91,17 @@ def get_javelin_client_aispm(): account_id = parts[4] except (IndexError, AttributeError): pass - + javelin_api_key = selected_gateway.get("api_key_value", "placeholder") - + # Initialize and return the JavelinClient config = JavelinConfig( base_url=base_url, javelin_api_key=javelin_api_key, ) - + client = JavelinClient(config) - + # Store account_id in client for AISPM service to use if account_id: client._aispm_account_id = account_id @@ -110,7 +124,13 @@ def get_javelin_client(): cache_data = json.load(json_file) # Retrieve the list of gateways - gateways = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}).get("public_metadata", {}).get("Gateways", []) + gateways = ( + cache_data.get("memberships", {}) + .get("data", [{}])[0] + .get("organization", {}) + .get("public_metadata", {}) + .get("Gateways", []) + ) if not gateways: raise ValueError("No gateways found in the configuration.") @@ -159,13 +179,11 @@ def create_customer(args): name=args.name, description=args.description, metrics_interval=args.metrics_interval, - security_interval=args.security_interval + security_interval=args.security_interval, ) return client.aispm.create_customer(customer) - - def get_customer(args): """ Gets customer details using the AISPM service. @@ -173,7 +191,7 @@ def get_customer(args): try: client = get_javelin_client_aispm() response = client.aispm.get_customer() - + # Pretty print the response for CLI output formatted_response = { "name": response.name, @@ -182,9 +200,9 @@ def get_customer(args): "security_interval": response.security_interval, "status": response.status, "created_at": response.created_at.isoformat(), - "modified_at": response.modified_at.isoformat() + "modified_at": response.modified_at.isoformat(), } - + print(json.dumps(formatted_response, indent=2)) except Exception as e: print(f"Error getting customer: {e}") @@ -200,6 +218,7 @@ def configure_aws(args): except Exception as e: print(f"Error configuring AWS: {e}") + def get_aws_config(args): """ Gets AWS configurations using the AISPM service. @@ -213,8 +232,10 @@ def get_aws_config(args): except Exception as e: print(f"Error getting AWS configurations: {e}") + # Add these functions to commands.py + def delete_aws_config(args): """ Deletes an AWS configuration. @@ -226,6 +247,7 @@ def delete_aws_config(args): except Exception as e: print(f"Error deleting AWS config: {e}") + def get_azure_config(args): """ Gets Azure configurations using the AISPM service. @@ -238,6 +260,7 @@ def get_azure_config(args): except Exception as e: print(f"Error getting Azure config: {e}") + def configure_azure(args): try: client = get_javelin_client_aispm() @@ -248,6 +271,7 @@ def configure_azure(args): except Exception as e: print(f"Error configuring Azure: {e}") + def get_usage(args): try: client = get_javelin_client_aispm() @@ -255,12 +279,13 @@ def get_usage(args): provider=args.provider, cloud_account=args.account, model=args.model, - region=args.region + region=args.region, ) print(json.dumps(usage.dict(), indent=2)) except Exception as e: print(f"Error getting usage: {e}") + def get_alerts(args): try: client = get_javelin_client_aispm() @@ -268,12 +293,13 @@ def get_alerts(args): provider=args.provider, cloud_account=args.account, model=args.model, - region=args.region + region=args.region, ) print(json.dumps(alerts.dict(), indent=2)) except Exception as e: print(f"Error getting alerts: {e}") + def create_gateway(args): try: client = get_javelin_client() @@ -325,7 +351,13 @@ def list_gateways(args): cache_data = json.load(json_file) # Retrieve the list of gateways - gateways = cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}).get("public_metadata", {}).get("Gateways", []) + gateways = ( + cache_data.get("memberships", {}) + .get("data", [{}])[0] + .get("organization", {}) + .get("public_metadata", {}) + .get("Gateways", []) + ) if not gateways: print("No gateways found in the configuration.") return From 7ad667bdc47cbf5cdabad2659fd40b983c2a783c Mon Sep 17 00:00:00 2001 From: Darshana Date: Wed, 17 Dec 2025 21:37:36 +0530 Subject: [PATCH 09/11] refactor: move AISPM url construction logic to _construct_url --- javelin_sdk/client.py | 51 ++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index 75a8240..32b5eb6 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -91,27 +91,19 @@ def close(self): self._client.close() def _prepare_request(self, request: Request) -> tuple: - if request.route.startswith("v1/admin/aispm"): - url = f"{self.config.base_url.rstrip('/')}/{request.route}" - if request.query_params: - query_string = "&".join(f"{k}={v}" for k, v in request.query_params.items()) - url += f"?{query_string}" - - - else: - url = self._construct_url( - gateway_name=request.gateway, - provider_name=request.provider, - route_name=request.route, - secret_name=request.secret, - template_name=request.template, - trace=request.trace, - query=request.is_query, - archive=request.archive, - query_params=request.query_params, - is_transformation_rules=request.is_transformation_rules, - is_reload=request.is_reload, - ) + url = self._construct_url( + gateway_name=request.gateway, + provider_name=request.provider, + route_name=request.route, + secret_name=request.secret, + template_name=request.template, + trace=request.trace, + query=request.is_query, + archive=request.archive, + query_params=request.query_params, + is_transformation_rules=request.is_transformation_rules, + is_reload=request.is_reload, + ) headers = {**self._headers, **(request.headers or {})} @@ -163,6 +155,14 @@ def _construct_url( is_transformation_rules: bool = False, is_reload: bool = False, ) -> str: + # Handle AISPM routes: they use the route directly with base_url + if route_name and route_name.startswith("v1/admin/aispm"): + url = f"{self.config.base_url.rstrip('/')}/{route_name}" + if query_params: + query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) + url += f"?{query_string}" + return url + url_parts = [self.base_url] @@ -226,15 +226,6 @@ def _construct_url( return url - def _construct_aispm_url(self, request: Request) -> str: - url = request.route - - if request.query_params: - query_string = "&".join(f"{k}={v}" for k, v in request.query_params.items()) - url += f"?{query_string}" - - return url - # Gateway methods create_gateway = lambda self, gateway: self.gateway_service.create_gateway(gateway) acreate_gateway = lambda self, gateway: self.gateway_service.acreate_gateway( From 5590146750f2a1321efacdef6c7596fbb21f841a Mon Sep 17 00:00:00 2001 From: Darshana Date: Thu, 18 Dec 2025 20:46:35 +0530 Subject: [PATCH 10/11] fix: resolve merge conflicts --- javelin_cli/_internal/commands.py | 42 +- javelin_cli/cli.py | 162 +-- javelin_sdk/client.py | 1650 +++++++++++++++++++++++++---- javelin_sdk/models.py | 310 ++++-- 4 files changed, 1746 insertions(+), 418 deletions(-) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index ac47d55..99f9717 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -1,26 +1,19 @@ import json -import os -import asyncio - from pathlib import Path -from pydantic import ValidationError - from javelin_sdk.client import JavelinClient from javelin_sdk.exceptions import ( BadRequest, - GatewayNotFoundError, NetworkError, - ProviderNotFoundError, - RouteNotFoundError, - SecretNotFoundError, - TemplateNotFoundError, UnauthorizedError, ) from javelin_sdk.models import ( + AWSConfig, + AlertResponse, Gateway, GatewayConfig, JavelinConfig, + Customer, Model, Provider, ProviderConfig, @@ -29,15 +22,11 @@ Secret, Secrets, Template, - Templates, - Customer, - AWSConfig, + TemplateConfig, AzureConfig, UsageResponse, - AlertResponse, - Request, - HttpMethod, ) +from pydantic import ValidationError def get_javelin_client_aispm(): @@ -399,7 +388,7 @@ def update_gateway(args): name=args.name, type=args.type, enabled=args.enabled, config=config ) - client.update_gateway(args.name, gateway_data) + client.update_gateway(gateway) print(f"Gateway '{args.name}' updated successfully.") except UnauthorizedError as e: @@ -447,7 +436,8 @@ def create_provider(args): config=config, ) - # Assuming client.create_provider accepts a Pydantic model and handles it internally + # Assuming client.create_provider accepts a Pydantic model and handles it + # internally client.create_provider(provider) print(f"Provider '{args.name}' created successfully.") @@ -513,7 +503,7 @@ def update_provider(args): config=config, ) - result = client.update_provider(provider) + client.update_provider(provider) print(f"Provider '{args.name}' updated successfully.") except json.JSONDecodeError as e: @@ -564,7 +554,8 @@ def create_route(args): config=config, ) - # Assuming client.create_route accepts a Pydantic model and handles it internally + # Assuming client.create_route accepts a Pydantic model and handles it + # internally client.create_route(route) print(f"Route '{args.name}' created successfully.") @@ -631,7 +622,7 @@ def update_route(args): config=config, ) - result = client.update_route(route) + client.update_route(route) print(f"Route '{args.name}' updated successfully.") except json.JSONDecodeError as e: @@ -659,9 +650,6 @@ def delete_route(args): print(f"Unexpected error: {e}") -from collections import namedtuple - - def create_secret(args): try: client = get_javelin_client() @@ -769,7 +757,7 @@ def update_secret(args): enabled=args.enabled if args.enabled is not None else None, ) - result = client.update_secret(secret) + client.update_secret(secret) print(f"Secret '{args.api_key}' updated successfully.") except UnauthorizedError as e: @@ -819,7 +807,7 @@ def create_template(args): config=config, ) - result = client.create_template(template) + client.create_template(template) print(f"Template '{args.name}' created successfully.") except json.JSONDecodeError as e: @@ -886,7 +874,7 @@ def update_template(args): config=config, ) - result = client.update_template(template) + client.update_template(template) print(f"Template '{args.name}' updated successfully.") except json.JSONDecodeError as e: diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index 047e519..b74b243 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -1,15 +1,14 @@ import argparse +import http.server import importlib.metadata -import os -import webbrowser -from pathlib import Path import json -import http.server +import random import socketserver +import sys import threading import urllib.parse -import random -import sys +import webbrowser +from pathlib import Path import requests @@ -66,24 +65,24 @@ def check_permissions(): """Check if user has superadmin permissions""" home_dir = Path.home() cache_file = home_dir / ".javelin" / "cache.json" - + if not cache_file.exists(): print("❌ Not authenticated. Please run 'javelin auth' first.") sys.exit(1) - + try: with open(cache_file) as f: cache = json.load(f) # Check memberships - memberships = cache.get('memberships', {}).get('data', []) + memberships = cache.get("memberships", {}).get("data", []) for membership in memberships: - if membership.get('role') == 'org:superadmin': + if membership.get("role") == "org:superadmin": return True - + print("❌ Permission denied: Javelin CLI requires superadmin privileges.") print("Please contact your administrator for access.") sys.exit(1) - + except Exception as e: print(f"❌ Error reading credentials: {e}") sys.exit(1) @@ -98,7 +97,10 @@ def main(): parser = argparse.ArgumentParser( description="The CLI for Javelin.", formatter_class=argparse.RawTextHelpFormatter, - epilog="See https://docs.getjavelin.io/docs/javelin-python/cli for more detailed documentation.", + epilog=( + "See https://docs.getjavelin.io/docs/javelin-python/cli for more " + "detailed documentation." + ), ) parser.add_argument( "--version", action="version", version=f"Javelin CLI v{package_version}" @@ -108,9 +110,13 @@ def main(): # Auth command auth_parser = subparsers.add_parser("auth", help="Authenticate with Javelin.") - auth_parser.add_argument("--force", action="store_true", help="Force re-authentication, overriding existing credentials") - auth_parser.set_defaults(func=authenticate) - #aispm CRUD + auth_parser.add_argument( + "--force", + action="store_true", + help="Force re-authentication, overriding existing credentials", + ) + auth_parser.set_defaults(func=authenticate) + # AISPM commands aispm_parser = subparsers.add_parser("aispm", help="Manage AISPM functionality") aispm_subparsers = aispm_parser.add_subparsers() @@ -122,54 +128,57 @@ def main(): customer_create = customer_subparsers.add_parser("create", help="Create customer") customer_create.add_argument("--name", required=True, help="Customer name") customer_create.add_argument("--description", help="Customer description") - customer_create.add_argument("--metrics-interval", default="5m", help="Metrics interval") - customer_create.add_argument("--security-interval", default="1m", help="Security interval") + customer_create.add_argument( + "--metrics-interval", default="5m", help="Metrics interval" + ) + customer_create.add_argument( + "--security-interval", default="1m", help="Security interval" + ) customer_create.set_defaults(func=create_customer) customer_get = customer_subparsers.add_parser("get", help="Get customer details") customer_get.set_defaults(func=get_customer) - # Cloud config commands - config_parser = aispm_subparsers.add_parser("config", help="Manage cloud configurations") + config_parser = aispm_subparsers.add_parser( + "config", help="Manage cloud configurations" + ) config_subparsers = config_parser.add_subparsers() aws_parser = config_subparsers.add_parser("aws", help="Configure AWS") - azure_parser = config_subparsers.add_parser("azure", help="Configure Azure") - - #azure_parser.add_argument("--config", type=str, required=True, help="Azure config JSON") - #azure_parser.set_defaults(func=configure_azure) - aws_subparsers = aws_parser.add_subparsers() - - # GET AWS Config aws_get_parser = aws_subparsers.add_parser("get", help="Get AWS configuration") aws_get_parser.set_defaults(func=get_aws_config) - # Existing AWS Config (for creating) - aws_config_parser = aws_subparsers.add_parser("create", help="Configure AWS") - aws_config_parser.add_argument("--config", type=str, required=True, help="AWS config JSON") - aws_config_parser.set_defaults(func=configure_aws) + aws_create_parser = aws_subparsers.add_parser("create", help="Configure AWS") + aws_create_parser.add_argument( + "--config", type=str, required=True, help="AWS config JSON" + ) + aws_create_parser.set_defaults(func=configure_aws) - aws_delete_parser = aws_subparsers.add_parser("delete", help="Delete AWS configuration") - aws_delete_parser.add_argument("--name", type=str, required=True, help="Name of AWS configuration to delete") + aws_delete_parser = aws_subparsers.add_parser( + "delete", help="Delete AWS configuration" + ) + aws_delete_parser.add_argument( + "--name", type=str, required=True, help="Name of AWS configuration to delete" + ) aws_delete_parser.set_defaults(func=delete_aws_config) - azure_subparsers = azure_parser.add_subparsers(dest='azure_command') - -# Get Azure Config - azure_get_parser = azure_subparsers.add_parser("get", help="Get Azure configuration") + azure_subparsers = azure_parser.add_subparsers(dest="azure_command") + azure_get_parser = azure_subparsers.add_parser( + "get", help="Get Azure configuration" + ) azure_get_parser.set_defaults(func=get_azure_config) - # Create Azure Config azure_create_parser = azure_subparsers.add_parser("create", help="Configure Azure") - azure_create_parser.add_argument("--config", type=str, required=True, help="Azure config JSON") + azure_create_parser.add_argument( + "--config", type=str, required=True, help="Azure config JSON" + ) azure_create_parser.set_defaults(func=configure_azure) - - # Usage metrics + # Usage metrics usage_parser = aispm_subparsers.add_parser("usage", help="Get usage metrics") usage_parser.add_argument("--provider", help="Cloud provider") usage_parser.add_argument("--account", help="Cloud account name") @@ -180,14 +189,17 @@ def main(): # Alerts alerts_parser = aispm_subparsers.add_parser("alerts", help="Get alerts") alerts_parser.add_argument("--provider", help="Cloud provider") - alerts_parser.add_argument("--account", help="Cloud account name") + alerts_parser.add_argument("--account", help="Cloud account name") alerts_parser.add_argument("--model", help="Model ID") alerts_parser.add_argument("--region", help="Region") alerts_parser.set_defaults(func=get_alerts) # Gateway CRUD gateway_parser = subparsers.add_parser( "gateway", - help="Manage gateways: create, list, update, and delete gateways for routing requests.", + help=( + "Manage gateways: create, list, update, and delete gateways for " + "routing requests." + ), ) gateway_subparsers = gateway_parser.add_subparsers() @@ -239,7 +251,10 @@ def main(): # Provider CRUD provider_parser = subparsers.add_parser( "provider", - help="Manage model providers: configure and manage large language model providers.", + help=( + "Manage model providers: configure and manage large language model " + "providers." + ), ) provider_subparsers = provider_parser.add_subparsers() @@ -297,7 +312,10 @@ def main(): # Route CRUD route_parser = subparsers.add_parser( "route", - help="Manage routing rules: define and control the routing logic for handling requests.", + help=( + "Manage routing rules: define and control the routing logic for " + "handling requests." + ), ) route_subparsers = route_parser.add_subparsers() @@ -355,7 +373,10 @@ def main(): # Secret CRUD secret_parser = subparsers.add_parser( "secret", - help="Manage API secrets: securely handle and manage API keys and credentials for access control.", + help=( + "Manage API secrets: securely handle and manage API keys and " + "credentials for access control." + ), ) secret_subparsers = secret_parser.add_subparsers() @@ -409,7 +430,10 @@ def main(): # Template CRUD template_parser = subparsers.add_parser( "template", - help="Manage templates: configure and manage templates for sensitive data protection.", + help=( + "Manage templates: configure and manage templates for sensitive " + "data protection." + ), ) template_subparsers = template_parser.add_subparsers() @@ -471,7 +495,7 @@ def main(): args = parser.parse_args() - + if hasattr(args, "func"): # Skip permission check for auth command if args.func != authenticate: @@ -490,9 +514,8 @@ def authenticate(args): print("✅ User is already authenticated!") print("Use --force to re-authenticate and override existing cache.") return - + default_url = "https://dev.highflame.dev/" - print(" O") print(" /|\\") print(" / \\ ========> Welcome to Javelin! 🚀") @@ -500,7 +523,7 @@ def authenticate(args): print("Press Enter to open the default login URL in your browser...") print(f"Default URL: {default_url}") print("Or enter a new URL (leave blank to use the default): ", end="") - + new_url = input().strip() url_to_open = new_url if new_url else default_url @@ -508,14 +531,14 @@ def authenticate(args): redirect_uri = f"http://localhost:{port}" encoded_redirect = urllib.parse.quote(redirect_uri) - + url_to_open = f"{url_to_open}sign-in?localhost_url={encoded_redirect}&cli=1" print(f"\n🚀 Opening {url_to_open} in your browser...") webbrowser.open(url_to_open) print("\n⚡ Waiting for authentication... (Server is running)") - + server_thread.join() if cache_file.exists(): @@ -527,14 +550,15 @@ def authenticate(args): def start_local_server(): # Find an available port port = random.randint(8000, 9000) - + class AuthHandler(http.server.SimpleHTTPRequestHandler): def log_message(self, format, *args): pass + def end_headers(self): - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'GET, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'Content-Type') + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") super().end_headers() def do_OPTIONS(self): @@ -544,20 +568,22 @@ def do_OPTIONS(self): def do_GET(self): query = urllib.parse.urlparse(self.path).query params = urllib.parse.parse_qs(query) - - if 'secrets' in params: - secrets = params['secrets'][0] + + if "secrets" in params: + secrets = params["secrets"][0] store_credentials(secrets) self.send_response(200) - self.send_header('Content-type', 'text/html') + self.send_header("Content-type", "text/html") self.end_headers() - self.wfile.write(b"Authentication successful. You can close this window.") - + self.wfile.write( + b"Authentication successful. You can close this window." + ) + # Shutdown the server threading.Thread(target=self.server.shutdown).start() else: self.send_response(400) - self.send_header('Content-type', 'text/html') + self.send_header("Content-type", "text/html") self.end_headers() self.wfile.write(b"Invalid request. Missing 'secrets' parameter.") @@ -568,7 +594,7 @@ def run_server(): server_thread = threading.Thread(target=run_server) server_thread.start() - + return server_thread, port @@ -576,9 +602,9 @@ def store_credentials(secrets): home_dir = Path.home() javelin_dir = home_dir / ".javelin" javelin_dir.mkdir(exist_ok=True) - + cache_file = javelin_dir / "cache.json" - + try: cache_data = json.loads(secrets) with open(cache_file, "w") as f: @@ -601,4 +627,4 @@ def get_profile(url): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index 32b5eb6..4d5e4ce 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -1,41 +1,97 @@ +import functools +import inspect +import json +import re +import asyncio from typing import Any, Coroutine, Dict, Optional, Union -from urllib.parse import urljoin +from urllib.parse import unquote, urljoin, urlparse, urlunparse import httpx +from opentelemetry.semconv._incubating.attributes import gen_ai_attributes +from opentelemetry.trace import SpanKind, Status, StatusCode -from javelin_sdk.chat_completions import Chat, Completions -from javelin_sdk.models import ( - HttpMethod, - JavelinConfig, - Request, -) +from javelin_sdk.chat_completions import Chat, Completions, Embeddings +from javelin_sdk.models import HttpMethod, JavelinConfig, Request from javelin_sdk.services.gateway_service import GatewayService +from javelin_sdk.services.modelspec_service import ModelSpecService from javelin_sdk.services.provider_service import ProviderService from javelin_sdk.services.route_service import RouteService from javelin_sdk.services.secret_service import SecretService from javelin_sdk.services.template_service import TemplateService from javelin_sdk.services.trace_service import TraceService from javelin_sdk.services.aispm_service import AISPMService - +from javelin_sdk.services.guardrails_service import GuardrailsService +from javelin_sdk.tracing_setup import configure_span_exporter API_BASEURL = "https://api-dev.javelin.live" API_BASE_PATH = "/v1" API_TIMEOUT = 10 +class JavelinRequestWrapper: + """A wrapper around Botocore's request object to store additional metadata.""" + + def __init__(self, original_request, span): + self.original_request = original_request + self.span = span + + class JavelinClient: + BEDROCK_RUNTIME_OPERATIONS = frozenset( + {"InvokeModel", "InvokeModelWithResponseStream", "Converse", "ConverseStream"} + ) + PROFILE_ARN_PATTERN = re.compile( + r"/model/arn:aws:bedrock:[^:]+:\d+:application-inference-profile/[^/]+" + ) + MODEL_ARN_PATTERN = re.compile( + r"/model/arn:aws:bedrock:[^:]+::foundation-model/[^/]+" + ) + + # Mapping provider_name to well-known gen_ai.system values + GEN_AI_SYSTEM_MAPPING = { + "openai": "openai", + "azureopenai": "az.ai.openai", + "bedrock": "aws.bedrock", + "gemini": "gemini", + "deepseek": "deepseek", + "cohere": "cohere", + "mistral_ai": "mistral_ai", + "anthropic": "anthropic", + "vertex_ai": "vertex_ai", + "perplexity": "perplexity", + "groq": "groq", + "ibm": "ibm.watsonx.ai", + "xai": "xai", + } + + # Mapping method names to well-known operation names + GEN_AI_OPERATION_MAPPING = { + "chat.completions.create": "chat", + "completions.create": "text_completion", + "embeddings.create": "embeddings", + "images.generate": "image_generation", + "images.edit": "image_editing", + "images.create_variation": "image_variation", + } + def __init__(self, config: JavelinConfig) -> None: self.config = config self.base_url = urljoin(config.base_url, config.api_version or "/v1") - self._headers = { - "x-api-key": config.javelin_api_key, - } + + self._headers = {"x-javelin-apikey": config.javelin_api_key} if config.llm_api_key: self._headers["Authorization"] = f"Bearer {config.llm_api_key}" if config.javelin_virtualapikey: self._headers["x-javelin-virtualapikey"] = config.javelin_virtualapikey self._client = None self._aclient = None + self.bedrock_client = None + self.bedrock_runtime_client = None + self.bedrock_session = None + self.default_bedrock_route = None + self.use_default_bedrock_route = False + self.client_is_async = None + self.openai_base_url = None self.gateway_service = GatewayService(self) self.provider_service = ProviderService(self) @@ -43,9 +99,19 @@ def __init__(self, config: JavelinConfig) -> None: self.secret_service = SecretService(self) self.template_service = TemplateService(self) self.trace_service = TraceService(self) + self.modelspec_service = ModelSpecService(self) + self.guardrails_service = GuardrailsService(self) self.chat = Chat(self) self.completions = Completions(self) + self.embeddings = Embeddings(self) + + self.tracer = configure_span_exporter() + + self.patched_clients = set() # Track already patched clients + self.patched_methods = set() # Track already patched methods + + self.original_methods = {} self.aispm = AISPMService(self) @@ -57,7 +123,8 @@ def client(self): # This allows us to exclude x-api-key for AISPM requests self._client = httpx.Client( base_url=self.base_url, - timeout= self.config.timeout if self.config.timeout else API_TIMEOUT, + headers=self._headers, + timeout=self.config.timeout if self.config.timeout else API_TIMEOUT, ) return self._client @@ -90,6 +157,795 @@ def close(self): if self._client: self._client.close() + @staticmethod + def set_span_attribute_if_not_none(span, key, value): + """Helper function to set span attributes only if the value is not None.""" + if value is not None: + span.set_attribute(key, value) + + @staticmethod + def add_event_with_attributes(span, event_name, attributes): + """Helper function to add events only with non-None attributes.""" + filtered_attributes = {k: v for k, v in attributes.items() if v is not None} + if filtered_attributes: # Add event only if there are valid attributes + span.add_event(name=event_name, attributes=filtered_attributes) + + def _setup_client_headers(self, openai_client, route_name): + """Setup client headers and base URL.""" + + self.openai_base_url = openai_client.base_url + + openai_client.base_url = f"{self.base_url}" + + if not hasattr(openai_client, "_custom_headers"): + openai_client._custom_headers = {} + else: + pass + + openai_client._custom_headers.update(self._headers) + + if route_name is not None: + openai_client._custom_headers["x-javelin-route"] = route_name + + # Ensure the client uses the custom headers + if hasattr(openai_client, "default_headers"): + # Filter out None values and openai.Omit objects + filtered_headers = {} + for key, value in openai_client._custom_headers.items(): + if value is not None and not ( + hasattr(value, "__class__") and value.__class__.__name__ == "Omit" + ): + filtered_headers[key] = value + openai_client.default_headers.update(filtered_headers) + elif hasattr(openai_client, "_default_headers"): + # Filter out None values and openai.Omit objects + filtered_headers = {} + for key, value in openai_client._custom_headers.items(): + if value is not None and not ( + hasattr(value, "__class__") and value.__class__.__name__ == "Omit" + ): + filtered_headers[key] = value + openai_client._default_headers.update(filtered_headers) + else: + pass + + def _store_original_methods(self, openai_client, provider_name): + """Store original methods for the provider if not already stored.""" + if provider_name not in self.original_methods: + self.original_methods[provider_name] = { + "chat_completions_create": openai_client.chat.completions.create, + "completions_create": openai_client.completions.create, + "embeddings_create": openai_client.embeddings.create, + "images_generate": openai_client.images.generate, + "images_edit": openai_client.images.edit, + "images_create_variation": openai_client.images.create_variation, + } + + def _create_patched_method(self, method_name, original_method, openai_client): + """Create a patched method with tracing support.""" + if inspect.iscoroutinefunction(original_method): + + async def async_patched_method(*args, **kwargs): + return await self._execute_with_tracing( + original_method, method_name, args, kwargs, openai_client + ) + + return async_patched_method + else: + + def sync_patched_method(*args, **kwargs): + return self._execute_with_tracing( + original_method, method_name, args, kwargs, openai_client + ) + + return sync_patched_method + + def _execute_with_tracing( + self, + original_method, + method_name, + args, + kwargs, + openai_client, + ): + """Execute method with tracing support.""" + model = kwargs.get("model") + + self._setup_custom_headers(openai_client, model) + + operation_name = self.GEN_AI_OPERATION_MAPPING.get(method_name, method_name) + system_name = self.GEN_AI_SYSTEM_MAPPING.get( + self.provider_name, self.provider_name + ) + span_name = f"{operation_name} {model}" + + if self.tracer: + return self._execute_with_tracer( + original_method, + args, + kwargs, + span_name, + system_name, + operation_name, + model, + ) + else: + return self._execute_without_tracer(original_method, args, kwargs) + + def _setup_custom_headers(self, openai_client, model): + """Setup custom headers for the OpenAI client.""" + if model and hasattr(openai_client, "_custom_headers"): + openai_client._custom_headers["x-javelin-model"] = model + + if not hasattr(openai_client, "_custom_headers"): + return + + filtered_headers = self._filter_custom_headers(openai_client._custom_headers) + + if hasattr(openai_client, "default_headers"): + openai_client.default_headers.update(filtered_headers) + elif hasattr(openai_client, "_default_headers"): + openai_client._default_headers.update(filtered_headers) + + def _filter_custom_headers(self, custom_headers): + """Filter out None values and openai.Omit objects from custom headers.""" + filtered_headers = {} + for key, value in custom_headers.items(): + if value is not None and not self._is_omit_object(value): + filtered_headers[key] = value + return filtered_headers + + def _is_omit_object(self, value): + """Check if value is an openai.Omit object.""" + return hasattr(value, "__class__") and value.__class__.__name__ == "Omit" + + def _execute_with_tracer( + self, + original_method, + args, + kwargs, + span_name, + system_name, + operation_name, + model, + ): + """Execute method with tracer enabled.""" + if self.tracer is None: + return self._execute_without_tracer(original_method, args, kwargs) + + with self.tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span: + self._setup_span_attributes( + span, system_name, operation_name, model, kwargs + ) + try: + if inspect.iscoroutinefunction(original_method): + return asyncio.run( + self._async_execution(span, original_method, args, kwargs) + ) + else: + return self._sync_execution(span, original_method, args, kwargs) + except Exception as e: + span.set_status(Status(StatusCode.ERROR, str(e))) + span.set_attribute("is_exception", True) + raise + + def _execute_without_tracer(self, original_method, args, kwargs): + """Execute method without tracer.""" + if inspect.iscoroutinefunction(original_method): + return asyncio.run(original_method(*args, **kwargs)) + else: + return original_method(*args, **kwargs) + + async def _async_execution(self, span, original_method, args, kwargs): + """Execute async method with response capture.""" + response = await original_method(*args, **kwargs) + self._capture_response_details(span, response, kwargs, self.provider_name) + return response + + def _sync_execution(self, span, original_method, args, kwargs): + """Execute sync method with response capture.""" + response = original_method(*args, **kwargs) + self._capture_response_details(span, response, kwargs, self.provider_name) + return response + + def _setup_span_attributes(self, span, system_name, operation_name, model, kwargs): + """Setup span attributes for tracing.""" + span.set_attribute(gen_ai_attributes.GEN_AI_SYSTEM, system_name) + span.set_attribute(gen_ai_attributes.GEN_AI_OPERATION_NAME, operation_name) + span.set_attribute(gen_ai_attributes.GEN_AI_REQUEST_MODEL, model) + + # Request attributes + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS, + kwargs.get("max_completion_tokens"), + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY, + kwargs.get("presence_penalty"), + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY, + kwargs.get("frequency_penalty"), + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES, + json.dumps(kwargs.get("stop", [])) if kwargs.get("stop") else None, + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE, + kwargs.get("temperature"), + ) + self.set_span_attribute_if_not_none( + span, gen_ai_attributes.GEN_AI_REQUEST_TOP_K, kwargs.get("top_k") + ) + self.set_span_attribute_if_not_none( + span, gen_ai_attributes.GEN_AI_REQUEST_TOP_P, kwargs.get("top_p") + ) + + def _capture_response_details(self, span, response, kwargs, system_name): + """Capture response details for tracing.""" + try: + response_data = self._extract_response_data(response) + if response_data is None: + span.set_attribute("javelin.response.body", str(response)) + return + + self._set_basic_response_attributes(span, response_data) + self._set_usage_attributes(span, response_data) + self._add_message_events(span, kwargs, system_name) + self._add_choice_events(span, response_data, system_name) + + except Exception as e: + span.set_attribute("javelin.response.body", str(response)) + span.set_attribute("javelin.error", str(e)) + + def _extract_response_data(self, response): + """Extract response data from various response types.""" + if hasattr(response, "to_dict"): + return self._extract_from_to_dict(response) + elif hasattr(response, "model_dump"): + return self._extract_from_model_dump(response) + elif hasattr(response, "dict"): + return self._extract_from_dict(response) + elif isinstance(response, dict): + return response + elif hasattr(response, "__iter__") and not isinstance( + response, (str, bytes, dict, list) + ): + return self._handle_streaming_response(response) + else: + return self._extract_from_json(response) + + def _extract_from_to_dict(self, response): + """Extract data using to_dict method.""" + try: + response_data = response.to_dict() + return response_data if response_data else None + except Exception: + return None + + def _extract_from_model_dump(self, response): + """Extract data using model_dump method.""" + try: + return response.model_dump() + except Exception: + return None + + def _extract_from_dict(self, response): + """Extract data using dict method.""" + try: + return response.dict() + except Exception: + return None + + def _extract_from_json(self, response): + """Extract data by parsing JSON string.""" + try: + return json.loads(str(response)) + except (TypeError, ValueError): + return None + + def _handle_streaming_response(self, response): + """Handle streaming response data.""" + response_data = { + "object": "thread.message.delta", + "streamed_text": "", + } + + for index, chunk in enumerate(response): + if hasattr(chunk, "to_dict"): + chunk = chunk.to_dict() + + if not isinstance(chunk, dict): + continue + + choices = chunk.get("choices", []) + if not choices: + continue + + delta_dict = choices[0].get("delta", {}) + streamed_text = delta_dict.get("content", "") + response_data["streamed_text"] += streamed_text + + return response_data + + def _set_basic_response_attributes(self, span, response_data): + """Set basic response attributes on span.""" + self.set_span_attribute_if_not_none( + span, gen_ai_attributes.GEN_AI_RESPONSE_MODEL, response_data.get("model") + ) + self.set_span_attribute_if_not_none( + span, gen_ai_attributes.GEN_AI_RESPONSE_ID, response_data.get("id") + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_OPENAI_REQUEST_SERVICE_TIER, + response_data.get("service_tier"), + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_OPENAI_RESPONSE_SYSTEM_FINGERPRINT, + response_data.get("system_fingerprint"), + ) + + finish_reasons = [ + choice.get("finish_reason") + for choice in response_data.get("choices", []) + if choice.get("finish_reason") + ] + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_RESPONSE_FINISH_REASONS, + json.dumps(finish_reasons) if finish_reasons else None, + ) + + def _set_usage_attributes(self, span, response_data): + """Set usage attributes on span.""" + usage = response_data.get("usage", {}) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_USAGE_INPUT_TOKENS, + usage.get("prompt_tokens"), + ) + self.set_span_attribute_if_not_none( + span, + gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS, + usage.get("completion_tokens"), + ) + + def _add_message_events(self, span, kwargs, system_name): + """Add message events to span.""" + messages = kwargs.get("messages", []) + + system_message = next( + (msg.get("content") for msg in messages if msg.get("role") == "system"), + None, + ) + self.add_event_with_attributes( + span, + "gen_ai.system.message", + {"gen_ai.system": system_name, "content": system_message}, + ) + + user_message = next( + (msg.get("content") for msg in messages if msg.get("role") == "user"), None + ) + self.add_event_with_attributes( + span, + "gen_ai.user.message", + {"gen_ai.system": system_name, "content": user_message}, + ) + + def _add_choice_events(self, span, response_data, system_name): + """Add choice events to span.""" + choices = response_data.get("choices", []) + for index, choice in enumerate(choices): + choice_attributes = {"gen_ai.system": system_name, "index": index} + message = choice.pop("message", {}) + choice.update(message) + + for key, value in choice.items(): + if isinstance(value, (dict, list)): + value = json.dumps(value) + choice_attributes[key] = value if value is not None else None + + self.add_event_with_attributes(span, "gen_ai.choice", choice_attributes) + + def _patch_methods(self, openai_client, provider_name): + """Patch client methods with tracing support.""" + + def get_nested_attr(obj, attr_path): + attrs = attr_path.split(".") + for attr in attrs: + obj = getattr(obj, attr) + return obj + + for method_name in [ + "chat.completions.create", + "completions.create", + "embeddings.create", + ]: + method_ref = get_nested_attr(openai_client, method_name) + method_id = id(method_ref) + + if method_id in self.patched_methods: + continue + + original_method = self.original_methods[provider_name][ + method_name.replace(".", "_") + ] + patched_method = self._create_patched_method( + method_name, original_method, openai_client + ) + + parent_attr, method_attr = method_name.rsplit(".", 1) + parent_obj = get_nested_attr(openai_client, parent_attr) + setattr(parent_obj, method_attr, patched_method) + + self.patched_methods.add(method_id) + + def register_provider( + self, openai_client: Any, provider_name: str, route_name: str = None + ) -> Any: + """ + Generalized function to register OpenAI, Azure OpenAI, and Gemini clients. + + Additionally sets: + - openai_client.base_url to self.base_url + - openai_client._custom_headers to include self._headers + """ + client_id = id(openai_client) + if client_id in self.patched_clients: + return openai_client + + self.patched_clients.add(client_id) + self.provider_name = provider_name # Store for use in helper methods + if provider_name == "azureopenai": + # Add /v1/openai to the base_url if not already present + base_url = self.base_url.rstrip("/") + if not base_url.endswith("openai"): + self.base_url = f"{base_url}/openai" + + self._setup_client_headers(openai_client, route_name) + self._store_original_methods(openai_client, provider_name) + self._patch_methods(openai_client, provider_name) + + return openai_client + + def register_openai(self, openai_client: Any, route_name: str = None) -> Any: + return self.register_provider( + openai_client, provider_name="openai", route_name=route_name + ) + + def register_azureopenai(self, openai_client: Any, route_name: str = None) -> Any: + return self.register_provider( + openai_client, provider_name="azureopenai", route_name=route_name + ) + + def register_gemini(self, openai_client: Any, route_name: str = None) -> Any: + return self.register_provider( + openai_client, provider_name="gemini", route_name=route_name + ) + + def register_deepseek(self, openai_client: Any, route_name: str = None) -> Any: + return self.register_provider( + openai_client, provider_name="deepseek", route_name=route_name + ) + + def _setup_bedrock_clients( + self, bedrock_runtime_client, bedrock_client, bedrock_session + ): + """Setup bedrock clients and validate the runtime client.""" + if bedrock_session is not None: + self.bedrock_session = bedrock_session + self.bedrock_client = bedrock_session.client("bedrock") + self.bedrock_runtime_client = bedrock_session.client("bedrock-runtime") + else: + if bedrock_runtime_client is None: + raise AssertionError("Bedrock Runtime client cannot be None") + + # Store the bedrock client + self.bedrock_client = bedrock_client + self.bedrock_session = bedrock_session + self.bedrock_runtime_client = bedrock_runtime_client + + # Validate bedrock-runtime client type and attributes + if not all( + [ + hasattr(bedrock_runtime_client, "meta"), + hasattr(bedrock_runtime_client.meta, "service_model"), + getattr(bedrock_runtime_client.meta.service_model, "service_name", None) + == "bedrock-runtime", + ] + ): + raise AssertionError( + "Invalid client type. Expected boto3 bedrock-runtime client, got: " + f"{type(bedrock_runtime_client).__name__}" + ) + + def _setup_bedrock_route(self, route_name): + """Setup the default bedrock route.""" + if not route_name: + route_name = "awsbedrock" + + # Store the default bedrock route + if route_name is not None: + self.use_default_bedrock_route = True + self.default_bedrock_route = route_name + + def _create_bedrock_model_functions(self): + """Create cached functions for getting model information.""" + + @functools.lru_cache() + def get_inference_model(inference_profile_identifier: str) -> str | None: + try: + if self.bedrock_client is None: + return None + # Get the inference profile response + response = self.bedrock_client.get_inference_profile( + inferenceProfileIdentifier=inference_profile_identifier + ) + model_identifier = response["models"][0]["modelArn"] + + # Get the foundation model response + foundation_model_response = self.bedrock_client.get_foundation_model( + modelIdentifier=model_identifier + ) + model_id = foundation_model_response["modelDetails"]["modelId"] + return model_id + except Exception: + # Fail silently if the model is not found + return None + + @functools.lru_cache() + def get_foundation_model(model_identifier: str) -> str | None: + try: + if self.bedrock_client is None: + return None + response = self.bedrock_client.get_foundation_model( + modelIdentifier=model_identifier + ) + return response["modelDetails"]["modelId"] + except Exception: + # Fail silently if the model is not found + return None + + return get_inference_model, get_foundation_model + + def _extract_model_id_from_path( + self, path, get_inference_model, get_foundation_model + ): + """Extract model ID from the URL path.""" + model_id = None + + # Check for inference profile ARN + if re.match(self.PROFILE_ARN_PATTERN, path): + match = re.match(self.PROFILE_ARN_PATTERN, path) + if match: + model_id = get_inference_model(match.group(0).replace("/model/", "")) + + # Check for model ARN + elif re.match(self.MODEL_ARN_PATTERN, path): + match = re.match(self.MODEL_ARN_PATTERN, path) + if match: + model_id = get_foundation_model(match.group(0).replace("/model/", "")) + + # If the model ID is not found, try to extract it from the path + if model_id is None: + path = path.replace("/model/", "") + # Get the the last index of / in the path + end_index = path.rfind("/") + path = path[:end_index] + model_id = path.replace("/model/", "") + + return model_id + + def _create_bedrock_request_handlers( + self, get_inference_model, get_foundation_model + ): + """Create request handlers for bedrock operations.""" + + def add_custom_headers(request: Any, **kwargs) -> None: + """Add Javelin headers to each request.""" + request.headers.update(self._headers) + + def override_endpoint_url(request: Any, **kwargs) -> None: + """ + Redirect Bedrock operations to the Javelin endpoint + while preserving path and query. + """ + try: + original_url = urlparse(request.url) + + # Construct the base URL (scheme + netloc) + base_url = f"{original_url.scheme}://{original_url.netloc}" + + # Set the header + request.headers["x-javelin-provider"] = base_url + + if self.use_default_bedrock_route and self.default_bedrock_route: + request.headers["x-javelin-route"] = self.default_bedrock_route + + path = original_url.path + path = unquote(path) + + model_id = self._extract_model_id_from_path( + path, get_inference_model, get_foundation_model + ) + + if model_id: + model_id = re.sub(r"-\d{8}(?=-)", "", model_id) + request.headers["x-javelin-model"] = model_id + + # Update the request URL to use the Javelin endpoint. + parsed_base = urlparse(self.base_url) + updated_url = original_url._replace( + scheme=parsed_base.scheme, + netloc=parsed_base.netloc, + path=f"/v1{original_url.path}", + ) + request.url = urlunparse(updated_url) + + except Exception: + pass + + return add_custom_headers, override_endpoint_url + + def _create_bedrock_tracing_handlers(self): + """Create tracing handlers for bedrock operations.""" + + def bedrock_before_call(**kwargs): + """ + Start a new OTel span and store it in the Botocore context dict + so it can be retrieved in after-call. + """ + if self.tracer is None: + return # If no tracer, skip + + context = kwargs.get("context") + if context is None: + return + + event_name = kwargs.get("event_name", "") + # e.g., "before-call.bedrock-runtime.InvokeModel" + operation_name = event_name.split(".")[-1] if event_name else "Unknown" + + # Create & start the OTel span + span = self.tracer.start_span(operation_name, kind=SpanKind.CLIENT) + + # Store it in the context + context["javelin_request_wrapper"] = JavelinRequestWrapper(None, span) + + def bedrock_after_call(**kwargs): + """ + End the OTel span by retrieving it from Botocore's context dict. + """ + context = kwargs.get("context") + if not context: + return + + wrapper = context.get("javelin_request_wrapper") + if not wrapper: + return + + span = getattr(wrapper, "span", None) + if not span: + return + + # Optionally set status from the HTTP response + http_response = kwargs.get("http_response") + if http_response is not None and hasattr(http_response, "status_code"): + if http_response.status_code >= 400: + span.set_status( + Status( + StatusCode.ERROR, + "HTTP %d" % http_response.status_code, + ) + ) + else: + span.set_status( + Status(StatusCode.OK, "HTTP %d" % http_response.status_code) + ) + + # End the span + span.end() + + return bedrock_before_call, bedrock_after_call + + def _register_bedrock_event_handlers( + self, + add_custom_headers, + override_endpoint_url, + bedrock_before_call, + bedrock_after_call, + ): + """Register event handlers for bedrock operations.""" + if self.bedrock_runtime_client is None: + return + + for op in self.BEDROCK_RUNTIME_OPERATIONS: + event_name_before_send = f"before-send.bedrock-runtime.{op}" + event_name_before_call = f"before-call.bedrock-runtime.{op}" + event_name_after_call = f"after-call.bedrock-runtime.{op}" + events_client = self.bedrock_runtime_client.meta.events + + # Add headers + override endpoint + events_client.register( + event_name_before_send, + add_custom_headers, + ) + events_client.register( + event_name_before_send, + override_endpoint_url, + ) + + # Add OTel instrumentation + events_client.register( + event_name_before_call, + bedrock_before_call, + ) + events_client.register( + event_name_after_call, + bedrock_after_call, + ) + + def register_bedrock( + self, + bedrock_runtime_client: Any, + bedrock_client: Any = None, + bedrock_session: Any = None, + route_name: Optional[str] = None, + ) -> None: + """ + Register an AWS Bedrock Runtime client + for request interception and modification. + + Args: + bedrock_runtime_client: A boto3 bedrock-runtime client instance + bedrock_client: A boto3 bedrock client instance + bedrock_session: A boto3 bedrock session instance + route_name: The name of the route to use for the bedrock client + Returns: + The modified boto3 client with registered event handlers + Raises: + AssertionError: If client is None or not a valid bedrock-runtime client + ValueError: If URL parsing/manipulation fails + + Example: + >>> bedrock = boto3.client('bedrock-runtime') + >>> modified_client = javelin_client.register_bedrock_client(bedrock) + >>> javelin_client.register_bedrock_client(bedrock) + >>> bedrock.invoke_model( + """ + self._setup_bedrock_clients( + bedrock_runtime_client, bedrock_client, bedrock_session + ) + self._setup_bedrock_route(route_name) + + get_inference_model, get_foundation_model = ( + self._create_bedrock_model_functions() + ) + add_custom_headers, override_endpoint_url = ( + self._create_bedrock_request_handlers( + get_inference_model, get_foundation_model + ) + ) + bedrock_before_call, bedrock_after_call = ( + self._create_bedrock_tracing_handlers() + ) + + self._register_bedrock_event_handlers( + add_custom_headers, + override_endpoint_url, + bedrock_before_call, + bedrock_after_call, + ) + def _prepare_request(self, request: Request) -> tuple: url = self._construct_url( gateway_name=request.gateway, @@ -102,15 +958,22 @@ def _prepare_request(self, request: Request) -> tuple: archive=request.archive, query_params=request.query_params, is_transformation_rules=request.is_transformation_rules, + is_model_specs=request.is_model_specs, is_reload=request.is_reload, + univ_model=request.univ_model_config, + guardrail=request.guardrail, + list_guardrails=request.list_guardrails, ) headers = {**self._headers, **(request.headers or {})} - # For AISPM requests: if account_id header is present, remove x-api-key - # AISPM uses account_id-based authentication instead of API key - if request.route.startswith("v1/admin/aispm") and "x-javelin-accountid" in headers: - headers.pop("x-api-key", None) + # For AISPM requests: if account-id auth is used, do not send API key. + if ( + request.route + and request.route.startswith("v1/admin/aispm") + and "x-javelin-accountid" in headers + ): + headers.pop("x-javelin-apikey", None) return url, headers @@ -127,9 +990,6 @@ def _core_send_request( self, client: Union[httpx.Client, httpx.AsyncClient], request: Request ) -> Union[httpx.Response, Coroutine[Any, Any, httpx.Response]]: url, headers = self._prepare_request(request) - - # For httpx.Client, headers passed to request methods override client-level headers - # So we need to ensure we're passing the correct headers if request.method == HttpMethod.GET: return client.get(url, headers=headers) elif request.method == HttpMethod.POST: @@ -153,7 +1013,11 @@ def _construct_url( archive: Optional[str] = "", query_params: Optional[Dict[str, Any]] = None, is_transformation_rules: bool = False, + is_model_specs: bool = False, is_reload: bool = False, + univ_model: Optional[Dict[str, Any]] = None, + guardrail: Optional[str] = None, + list_guardrails: bool = False, ) -> str: # Handle AISPM routes: they use the route directly with base_url if route_name and route_name.startswith("v1/admin/aispm"): @@ -164,218 +1028,458 @@ def _construct_url( return url url_parts = [self.base_url] - - - if query: - url_parts.append("query") - if route_name is not None: - url_parts.append(route_name) - elif gateway_name: - url_parts.extend(["admin", "gateways"]) - if gateway_name != "###": - url_parts.append(gateway_name) - elif provider_name and not secret_name: - if is_reload: - url_parts.extend(["providers"]) - else: - url_parts.extend(["admin", "providers"]) - if provider_name != "###": - url_parts.append(provider_name) - if is_transformation_rules: - url_parts.append("transformation-rules") - elif route_name: - if is_reload: - url_parts.extend(["routes"]) - else: - url_parts.extend(["admin", "routes"]) - if route_name != "###": - url_parts.append(route_name) - elif secret_name: - if is_reload: - url_parts.extend(["secrets"]) - else: - url_parts.extend(["admin", "providers"]) - if provider_name != "###": - url_parts.append(provider_name) - url_parts.append("keyvault") - if secret_name != "###": - url_parts.append(secret_name) - else: - url_parts.append("keys") - elif template_name: - if is_reload: - url_parts.extend(["processors", "dp", "templates"]) - else: - url_parts.extend(["admin", "processors", "dp", "templates"]) - if template_name != "###": - url_parts.append(template_name) - elif trace: - url_parts.extend(["admin", "traces"]) - elif archive: - url_parts.extend(["admin", "archives"]) - if archive != "###": - url_parts.append(archive) - else: - url_parts.extend(["admin", "routes"]) + + # Determine the main URL path based on the primary resource type + main_path = self._get_main_url_path( + gateway_name=gateway_name, + provider_name=provider_name, + route_name=route_name, + secret_name=secret_name, + template_name=template_name, + trace=trace, + query=query, + archive=archive, + is_transformation_rules=is_transformation_rules, + is_model_specs=is_model_specs, + is_reload=is_reload, + guardrail=guardrail, + list_guardrails=list_guardrails, + ) + url_parts.extend(main_path) + + # Add resource-specific path segments + resource_path = self._get_resource_path( + gateway_name=gateway_name, + provider_name=provider_name, + route_name=route_name, + secret_name=secret_name, + template_name=template_name, + archive=archive, + guardrail=guardrail, + query=query, + ) + if resource_path: + url_parts.extend(resource_path) url = "/".join(url_parts) + if univ_model: + endpoint_url = self.construct_endpoint_url(univ_model) + url = urljoin(url, endpoint_url) + if query_params: query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) url += f"?{query_string}" return url + def _get_main_url_path( + self, + gateway_name: Optional[str] = "", + provider_name: Optional[str] = "", + route_name: Optional[str] = "", + secret_name: Optional[str] = "", + template_name: Optional[str] = "", + trace: Optional[str] = "", + query: bool = False, + archive: Optional[str] = "", + is_transformation_rules: bool = False, + is_model_specs: bool = False, + is_reload: bool = False, + guardrail: Optional[str] = None, + list_guardrails: bool = False, + ) -> list: + """Determine the main URL path based on the primary resource type.""" + # Define path strategies based on resource type + path_strategies = [ + (is_model_specs, self._get_model_specs_path), + (query, self._get_query_path), + (gateway_name, self._get_gateway_path), + ( + provider_name and not secret_name, + lambda: self._get_provider_path(is_reload, is_transformation_rules), + ), + (route_name, lambda: self._get_route_path(is_reload)), + (secret_name, lambda: self._get_secret_main_path(is_reload)), + (template_name, lambda: self._get_template_path(is_reload)), + (trace, self._get_trace_path), + (archive, self._get_archive_path), + (guardrail, lambda: self._get_guardrail_path(guardrail)), + (list_guardrails, self._get_list_guardrails_path), + ] + + # Find the first matching strategy and execute it + for condition, strategy in path_strategies: + if condition: + return strategy() + + # Default fallback + return ["admin", "routes"] + + def _get_model_specs_path(self) -> list: + """Get path for model specs.""" + return ["admin", "modelspec"] + + def _get_query_path(self) -> list: + """Get path for queries.""" + return ["query"] + + def _get_gateway_path(self) -> list: + """Get path for gateways.""" + return ["admin", "gateways"] + + def _get_provider_path( + self, is_reload: bool, is_transformation_rules: bool + ) -> list: + """Get path for providers.""" + base_path = ["providers"] if is_reload else ["admin", "providers"] + if is_transformation_rules: + base_path.append("transformation-rules") + return base_path + + def _get_route_path(self, is_reload: bool) -> list: + """Get path for routes.""" + return ["routes"] if is_reload else ["admin", "routes"] + + def _get_secret_main_path(self, is_reload: bool) -> list: + """Get main path for secrets.""" + return ["secrets"] if is_reload else ["admin", "providers"] + + def _get_template_path(self, is_reload: bool) -> list: + """Get path for templates.""" + return ( + ["processors", "dp", "templates"] + if is_reload + else ["admin", "processors", "dp", "templates"] + ) + + def _get_trace_path(self) -> list: + """Get path for traces.""" + return ["admin", "traces"] + + def _get_archive_path(self) -> list: + """Get path for archives.""" + return ["admin", "archives"] + + def _get_guardrail_path(self, guardrail: Optional[str]) -> list: + """Get path for guardrails.""" + if guardrail == "all": + return ["guardrails", "apply"] + else: + return ["guardrail", guardrail, "apply"] + + def _get_list_guardrails_path(self) -> list: + """Get path for listing guardrails.""" + return ["guardrails", "list"] + + def _get_resource_path( + self, + gateway_name: Optional[str] = "", + provider_name: Optional[str] = "", + route_name: Optional[str] = "", + secret_name: Optional[str] = "", + template_name: Optional[str] = "", + archive: Optional[str] = "", + guardrail: Optional[str] = None, + query: bool = False, + ) -> list: + """Get the resource-specific path segments.""" + if query and route_name is not None: + return [route_name] + elif gateway_name and gateway_name != "###": + return [gateway_name] + elif provider_name and provider_name != "###" and not secret_name: + return [provider_name] + elif route_name and route_name != "###": + return [route_name] + elif secret_name: + return self._get_secret_path(provider_name, secret_name) + elif template_name and template_name != "###": + return [template_name] + elif archive and archive != "###": + return [archive] + elif guardrail and guardrail != "all": + return [] # Already handled in main path + else: + return [] + + def _get_secret_path(self, provider_name: Optional[str], secret_name: str) -> list: + """Get the path for secret-related operations.""" + path = [] + if provider_name and provider_name != "###": + path.append(provider_name) + path.append("keyvault") + if secret_name != "###": + path.append(secret_name) + else: + path.append("keys") + return path + # Gateway methods - create_gateway = lambda self, gateway: self.gateway_service.create_gateway(gateway) - acreate_gateway = lambda self, gateway: self.gateway_service.acreate_gateway( - gateway - ) - get_gateway = lambda self, gateway_name: self.gateway_service.get_gateway( - gateway_name - ) - aget_gateway = lambda self, gateway_name: self.gateway_service.aget_gateway( - gateway_name - ) - list_gateways = lambda self: self.gateway_service.list_gateways() - alist_gateways = lambda self: self.gateway_service.alist_gateways() - update_gateway = lambda self, gateway: self.gateway_service.update_gateway(gateway) - aupdate_gateway = lambda self, gateway: self.gateway_service.aupdate_gateway( - gateway - ) - delete_gateway = lambda self, gateway_name: self.gateway_service.delete_gateway( - gateway_name - ) - adelete_gateway = lambda self, gateway_name: self.gateway_service.adelete_gateway( - gateway_name - ) + def create_gateway(self, gateway): + return self.gateway_service.create_gateway(gateway) + + def acreate_gateway(self, gateway): + return self.gateway_service.acreate_gateway(gateway) + + def get_gateway(self, gateway_name): + return self.gateway_service.get_gateway(gateway_name) + + def aget_gateway(self, gateway_name): + return self.gateway_service.aget_gateway(gateway_name) + + def list_gateways(self): + return self.gateway_service.list_gateways() + + def alist_gateways(self): + return self.gateway_service.alist_gateways() + + def update_gateway(self, gateway): + return self.gateway_service.update_gateway(gateway) + + def aupdate_gateway(self, gateway): + return self.gateway_service.aupdate_gateway(gateway) + + def delete_gateway(self, gateway_name): + return self.gateway_service.delete_gateway(gateway_name) + + def adelete_gateway(self, gateway_name): + return self.gateway_service.adelete_gateway(gateway_name) # Provider methods - create_provider = lambda self, provider: self.provider_service.create_provider( - provider - ) - acreate_provider = lambda self, provider: self.provider_service.acreate_provider( - provider - ) - get_provider = lambda self, provider_name: self.provider_service.get_provider( - provider_name - ) - aget_provider = lambda self, provider_name: self.provider_service.aget_provider( - provider_name - ) - list_providers = lambda self: self.provider_service.list_providers() - alist_providers = lambda self: self.provider_service.alist_providers() - update_provider = lambda self, provider: self.provider_service.update_provider( - provider - ) - aupdate_provider = lambda self, provider: self.provider_service.aupdate_provider( - provider - ) - delete_provider = lambda self, provider_name: self.provider_service.delete_provider( - provider_name - ) - adelete_provider = ( - lambda self, provider_name: self.provider_service.adelete_provider( - provider_name + def create_provider(self, provider): + return self.provider_service.create_provider(provider) + + def acreate_provider(self, provider): + return self.provider_service.acreate_provider(provider) + + def get_provider(self, provider_name): + return self.provider_service.get_provider(provider_name) + + def aget_provider(self, provider_name): + return self.provider_service.aget_provider(provider_name) + + def list_providers(self): + return self.provider_service.list_providers() + + def alist_providers(self): + return self.provider_service.alist_providers() + + def update_provider(self, provider): + return self.provider_service.update_provider(provider) + + def aupdate_provider(self, provider): + return self.provider_service.aupdate_provider(provider) + + def delete_provider(self, provider_name): + return self.provider_service.delete_provider(provider_name) + + def adelete_provider(self, provider_name): + return self.provider_service.adelete_provider(provider_name) + + def alist_provider_secrets(self, provider_name): + return self.provider_service.alist_provider_secrets(provider_name) + + def get_transformation_rules(self, provider_name, model_name, endpoint): + return self.provider_service.get_transformation_rules( + provider_name, model_name, endpoint ) - ) - alist_provider_secrets = ( - lambda self, provider_name: self.provider_service.alialist_provider_secrets( - provider_name + + def aget_transformation_rules(self, provider_name, model_name, endpoint): + return self.provider_service.aget_transformation_rules( + provider_name, model_name, endpoint ) - ) - get_transformation_rules = lambda self, provider_name, model_name, endpoint: self.provider_service.get_transformation_rules( - provider_name, model_name, endpoint - ) - aget_transformation_rules = lambda self, provider_name, model_name, endpoint: self.provider_service.aget_transformation_rules( - provider_name, model_name, endpoint - ) + + def get_model_specs(self, provider_url, model_name): + return self.modelspec_service.get_model_specs(provider_url, model_name) + + def aget_model_specs(self, provider_url, model_name): + return self.modelspec_service.aget_model_specs(provider_url, model_name) # Route methods - create_route = lambda self, route: self.route_service.create_route(route) - acreate_route = lambda self, route: self.route_service.acreate_route(route) - get_route = lambda self, route_name: self.route_service.get_route(route_name) - aget_route = lambda self, route_name: self.route_service.aget_route(route_name) - list_routes = lambda self: self.route_service.list_routes() - alist_routes = lambda self: self.route_service.alist_routes() - update_route = lambda self, route: self.route_service.update_route(route) - aupdate_route = lambda self, route: self.route_service.aupdate_route(route) - delete_route = lambda self, route_name: self.route_service.delete_route(route_name) - adelete_route = lambda self, route_name: self.route_service.adelete_route( - route_name - ) - query_route = lambda self, route_name, query_body, headers=None, stream=False, stream_response_path=None: self.route_service.query_route( - route_name=route_name, query_body=query_body, headers=headers, stream=stream, stream_response_path=stream_response_path - ) - aquery_route = lambda self, route_name, query_body, headers=None, stream=False, stream_response_path=None: self.route_service.aquery_route( - route_name, query_body, headers, stream, stream_response_path - ) - query_llama = lambda self, route_name, query_body: self.route_service.query_llama( - route_name, query_body - ) - aquery_llama = lambda self, route_name, query_body: self.route_service.aquery_llama( - route_name, query_body - ) + def create_route(self, route): + return self.route_service.create_route(route) - # Secret methods - create_secret = lambda self, secret: self.secret_service.create_secret(secret) - acreate_secret = lambda self, secret: self.secret_service.acreate_secret(secret) - get_secret = lambda self, secret_name, provider_name: self.secret_service.get_secret(secret_name, provider_name) - aget_secret = lambda self, secret_name, provider_name: self.secret_service.aget_secret(secret_name, provider_name) - list_secrets = lambda self: self.secret_service.list_secrets() - alist_secrets = lambda self: self.secret_service.alist_secrets() - update_secret = lambda self, secret: self.secret_service.update_secret(secret) - aupdate_secret = lambda self, secret: self.secret_service.aupdate_secret(secret) - delete_secret = lambda self, secret_name, provider_name: self.secret_service.delete_secret( - secret_name, provider_name - ) - adelete_secret = lambda self, secret_name, provider_name: self.secret_service.adelete_secret( - secret_name, provider_name - ) + def acreate_route(self, route): + return self.route_service.acreate_route(route) - # Template methods - create_template = lambda self, template: self.template_service.create_template( - template - ) - acreate_template = lambda self, template: self.template_service.acreate_template( - template - ) - get_template = lambda self, template_name: self.template_service.get_template( - template_name - ) - aget_template = lambda self, template_name: self.template_service.aget_template( - template_name - ) - list_templates = lambda self: self.template_service.list_templates() - alist_templates = lambda self: self.template_service.alist_templates() - update_template = lambda self, template: self.template_service.update_template( - template - ) - aupdate_template = lambda self, template: self.template_service.aupdate_template( - template - ) - delete_template = lambda self, template_name: self.template_service.delete_template( - template_name - ) - adelete_template = ( - lambda self, template_name: self.template_service.adelete_template( - template_name + def get_route(self, route_name): + return self.route_service.get_route(route_name) + + def aget_route(self, route_name): + return self.route_service.aget_route(route_name) + + def list_routes(self): + return self.route_service.list_routes() + + def alist_routes(self): + return self.route_service.alist_routes() + + def update_route(self, route): + return self.route_service.update_route(route) + + def delete_route(self, route_name): + return self.route_service.delete_route(route_name) + + def adelete_route(self, route_name): + return self.route_service.adelete_route(route_name) + + def query_route( + self, + route_name, + query_body, + headers=None, + stream=False, + stream_response_path=None, + ): + return self.route_service.query_route( + route_name=route_name, + query_body=query_body, + headers=headers, + stream=stream, + stream_response_path=stream_response_path, ) - ) - reload_data_protection = ( - lambda self, strategy_name: self.template_service.reload_data_protection( - strategy_name + + def aquery_route( + self, + route_name, + query_body, + headers=None, + stream=False, + stream_response_path=None, + ): + return self.route_service.aquery_route( + route_name, query_body, headers, stream, stream_response_path ) - ) - areload_data_protection = ( - lambda self, strategy_name: self.template_service.areload_data_protection( - strategy_name + + def query_unified_endpoint( + self, + provider_name, + endpoint_type, + query_body, + headers=None, + query_params=None, + deployment=None, + model_id=None, + stream_response_path=None, + ): + return self.route_service.query_unified_endpoint( + provider_name, + endpoint_type, + query_body, + headers, + query_params, + deployment, + model_id, + stream_response_path, ) - ) - ## Traces methods - get_traces = lambda self: self.trace_service.get_traces() - aget_traces = lambda self: self.trace_service.aget_traces() + def aquery_unified_endpoint( + self, + provider_name, + endpoint_type, + query_body, + headers=None, + query_params=None, + deployment=None, + model_id=None, + stream_response_path=None, + ): + return self.route_service.aquery_unified_endpoint( + provider_name, + endpoint_type, + query_body, + headers, + query_params, + deployment, + model_id, + stream_response_path, + ) + + # Secret methods + def create_secret(self, secret): + return self.secret_service.create_secret(secret) + + def acreate_secret(self, secret): + return self.secret_service.acreate_secret(secret) + + def get_secret(self, secret_name, provider_name): + return self.secret_service.get_secret(secret_name, provider_name) + + def aget_secret(self, secret_name, provider_name): + return self.secret_service.aget_secret(secret_name, provider_name) + + def list_secrets(self): + return self.secret_service.list_secrets() + + def alist_secrets(self): + return self.secret_service.alist_secrets() + + def update_secret(self, secret): + return self.secret_service.update_secret(secret) + + def aupdate_secret(self, secret): + return self.secret_service.aupdate_secret(secret) + + def delete_secret(self, secret_name, provider_name): + return self.secret_service.delete_secret(secret_name, provider_name) + + def adelete_secret(self, secret_name, provider_name): + return self.secret_service.adelete_secret(secret_name, provider_name) + + # Template methods + def create_template(self, template): + return self.template_service.create_template(template) + + def acreate_template(self, template): + return self.template_service.acreate_template(template) + + def get_template(self, template_name): + return self.template_service.get_template(template_name) + + def aget_template(self, template_name): + return self.template_service.aget_template(template_name) + + def list_templates(self): + return self.template_service.list_templates() + + def alist_templates(self): + return self.template_service.alist_templates() + + def update_template(self, template): + return self.template_service.update_template(template) + + def aupdate_template(self, template): + return self.template_service.aupdate_template(template) + + def delete_template(self, template_name): + return self.template_service.delete_template(template_name) + + def adelete_template(self, template_name): + return self.template_service.adelete_template(template_name) + + def reload_data_protection(self, strategy_name): + return self.template_service.reload_data_protection(strategy_name) + + def areload_data_protection(self, strategy_name): + return self.template_service.areload_data_protection(strategy_name) + + # Guardrails methods + def apply_trustsafety(self, text, config=None): + return self.guardrails_service.apply_trustsafety(text, config) + + def apply_promptinjectiondetection(self, text, config=None): + return self.guardrails_service.apply_promptinjectiondetection(text, config) + + def apply_guardrails(self, text, guardrails): + return self.guardrails_service.apply_guardrails(text, guardrails) + + def list_guardrails(self): + return self.guardrails_service.list_guardrails() + + # Traces methods + def get_traces(self): + return self.trace_service.get_traces() # Archive methods def get_last_n_chronicle_records(self, archive_name: str, n: int) -> Dict[str, Any]: @@ -397,3 +1501,131 @@ async def aget_last_n_chronicle_records( ) response = await self._send_request_async(request) return response + + def _construct_azure_openai_endpoint( + self, + base_url: str, + provider_name: str, + deployment: str, + endpoint_type: Optional[str], + ) -> str: + """Construct Azure OpenAI endpoint URL.""" + if not endpoint_type: + raise ValueError("Endpoint type is required for Azure OpenAI") + + azure_deployment_url = f"{base_url}/{provider_name}/deployments/{deployment}" + + endpoint_mapping = { + "chat": f"{azure_deployment_url}/chat/completions", + "completion": f"{azure_deployment_url}/completions", + "embeddings": f"{azure_deployment_url}/embeddings", + } + + if endpoint_type not in endpoint_mapping: + raise ValueError(f"Invalid Azure OpenAI endpoint type: {endpoint_type}") + + return endpoint_mapping[endpoint_type] + + def _construct_bedrock_endpoint( + self, base_url: str, model_id: str, endpoint_type: Optional[str] + ) -> str: + """Construct Bedrock endpoint URL.""" + if not endpoint_type: + raise ValueError("Endpoint type is required for Bedrock") + + endpoint_mapping = { + "invoke": f"{base_url}/model/{model_id}/invoke", + "converse": f"{base_url}/model/{model_id}/converse", + "invoke_stream": f"{base_url}/model/{model_id}/invoke-with-response-stream", + "converse_stream": f"{base_url}/model/{model_id}/converse-stream", + } + + if endpoint_type not in endpoint_mapping: + raise ValueError(f"Invalid Bedrock endpoint type: {endpoint_type}") + + return endpoint_mapping[endpoint_type] + + def _construct_anthropic_endpoint( + self, base_url: str, endpoint_type: Optional[str] + ) -> str: + """Construct Anthropic endpoint URL.""" + if not endpoint_type: + raise ValueError("Endpoint type is required for Anthropic") + + endpoint_mapping = { + "messages": f"{base_url}/model/messages", + "complete": f"{base_url}/model/complete", + } + + if endpoint_type not in endpoint_mapping: + raise ValueError(f"Invalid Anthropic endpoint type: {endpoint_type}") + + return endpoint_mapping[endpoint_type] + + def _construct_openai_compatible_endpoint( + self, base_url: str, provider_name: str, endpoint_type: Optional[str] + ) -> str: + """Construct OpenAI compatible endpoint URL.""" + if not endpoint_type: + raise ValueError( + "Endpoint type is required for OpenAI compatible endpoints" + ) + + endpoint_mapping = { + "chat": f"{base_url}/{provider_name}/chat/completions", + "completion": f"{base_url}/{provider_name}/completions", + "embeddings": f"{base_url}/{provider_name}/embeddings", + } + + if endpoint_type not in endpoint_mapping: + raise ValueError( + f"Invalid OpenAI compatible endpoint type: {endpoint_type}" + ) + + return endpoint_mapping[endpoint_type] + + def construct_endpoint_url(self, request_model: Dict[str, Any]) -> str: + """ + Constructs the endpoint URL based on the request model. + + :param request_model: The request model containing endpoint details. + :return: The constructed endpoint URL. + """ + provider_name = request_model.get("provider_name") + endpoint_type = request_model.get("endpoint_type") + deployment = request_model.get("deployment") + model_id = request_model.get("model_id") + + if not provider_name: + raise ValueError("Provider name is not specified in the request model.") + + base_url = self.base_url + + # Handle Azure OpenAI endpoints + if provider_name == "azureopenai" and deployment: + return self._construct_azure_openai_endpoint( + base_url, provider_name, deployment, endpoint_type + ) + + # Handle Bedrock endpoints + elif provider_name == "bedrock" and model_id: + return self._construct_bedrock_endpoint(base_url, model_id, endpoint_type) + + # Handle Anthropic endpoints + elif provider_name == "anthropic": + return self._construct_anthropic_endpoint(base_url, endpoint_type) + + # Handle OpenAI compatible endpoints + else: + return self._construct_openai_compatible_endpoint( + base_url, provider_name, endpoint_type + ) + + def set_headers(self, headers: Dict[str, str]) -> None: + """ + Set or update headers for the client. + + Args: + headers (Dict[str, str]): A dictionary of headers to set or update. + """ + self._headers.update(headers) diff --git a/javelin_sdk/models.py b/javelin_sdk/models.py index 6a71f45..5f5f356 100644 --- a/javelin_sdk/models.py +++ b/javelin_sdk/models.py @@ -1,58 +1,66 @@ +from datetime import datetime from enum import Enum, auto from typing import Any, Dict, List, Optional -from datetime import datetime - -from typing import Dict, List, Optional - - - -from pydantic import BaseModel, Field, field_validator - from javelin_sdk.exceptions import UnauthorizedError +from pydantic import BaseModel, Field, field_validator class GatewayConfig(BaseModel): buid: Optional[str] = Field( default=None, - description="Business Unit ID (BUID) uniquely identifies the business unit associated with this gateway configuration", + description=( + "Business Unit ID (BUID) uniquely identifies the business unit " + "associated with this gateway configuration" + ), ) base_url: Optional[str] = Field( default=None, - description="The foundational URL where all API requests are directed. It acts as the root from which endpoint paths are extended", + description=( + "The foundational URL where all API requests are directed. " + "It acts as the root from which endpoint paths are extended" + ), ) api_key: Optional[str] = Field( default=None, - description="The API key used for authenticating requests to the API endpoints specified by the base_url", + description=( + "The API key used for authenticating requests to the API endpoints " + "specified by the base_url" + ), ) organization_id: Optional[str] = Field( default=None, description="Unique identifier of the organization" ) system_namespace: Optional[str] = Field( default=None, - description="A unique namespace within the system to prevent naming conflicts and to organize resources logically", + description=( + "A unique namespace within the system to prevent naming conflicts " + "and to organize resources logically" + ), ) class Gateway(BaseModel): - gateway_id: str = Field( + gateway_id: Optional[str] = Field( default=None, description="Unique identifier for the gateway" ) - name: str = Field(default=None, description="Name of the gateway") - type: str = Field( + name: Optional[str] = Field(default=None, description="Name of the gateway") + type: Optional[str] = Field( default=None, description="The type of this gateway (e.g., development, staging, production)", ) enabled: Optional[bool] = Field( default=True, description="Whether the gateway is enabled" ) - config: GatewayConfig = Field( + config: Optional[GatewayConfig] = Field( default=None, description="Configuration for the gateway" ) class Gateways(BaseModel): - gateways: List[Gateway] = Field(default=[], description="List of gateways") + gateways: List[Gateway] = Field( + default_factory=list, description="List of gateways" + ) class Budget(BaseModel): @@ -96,6 +104,7 @@ class PromptSafety(BaseModel): default=None, description="List of content types" ) + class SecurityFilters(BaseModel): enabled: Optional[bool] = Field( default=None, description="Whether security filters are enabled" @@ -120,51 +129,55 @@ class ContentFilter(BaseModel): ) -class RouteConfig(BaseModel): - rate_limit: Optional[int] = Field( - default=None, description="Rate limit for the route" +class ArchivePolicy(BaseModel): + enabled: Optional[bool] = Field( + default=None, description="Whether archiving is enabled" ) - owner: Optional[str] = Field(default=None, description="Owner of the route") - organization: Optional[str] = Field( - default=None, description="Organization associated with the route" + retention: Optional[int] = Field(default=None, description="Data retention period") + + +class Policy(BaseModel): + dlp: Optional[Dlp] = Field(default=None, description="DLP configuration") + archive: Optional[ArchivePolicy] = Field( + default=None, description="Archive policy configuration" ) - archive: Optional[bool] = Field( - default=None, description="Whether archiving is enabled" + enabled: Optional[bool] = Field( + default=None, description="Whether the policy is enabled" ) + prompt_safety: Optional[PromptSafety] = Field( + default=None, description="Prompt Safety Description" + ) + content_filter: Optional[ContentFilter] = Field( + default=None, description="Content Filter Description" + ) + security_filters: Optional[SecurityFilters] = Field( + default=None, description="Security Filters Description" + ) + + +class RouteConfig(BaseModel): + policy: Optional[Policy] = Field(default=None, description="Policy configuration") retries: Optional[int] = Field( default=None, description="Number of retries for the route" ) - llm_cache: bool = Field(False, description="Whether LLM cache is enabled") - role_to_assume: Optional[str] = Field( - None, description="Role to assume for the route" + rate_limit: Optional[int] = Field( + default=None, description="Rate limit for the route" ) - enable_telemetry: Optional[bool] = Field( - None, description="Whether telemetry is enabled" + unified_endpoint: Optional[bool] = Field( + default=None, description="Whether unified endpoint is enabled" ) - retention: Optional[int] = Field(default=None, description="Data retention period") request_chain: Optional[Dict[str, Any]] = Field( None, description="Request chain configuration" ) response_chain: Optional[Dict[str, Any]] = Field( None, description="Response chain configuration" ) - budget: Optional[Budget] = Field(default=None, description="Budget configuration") - dlp: Optional[Dlp] = Field(default=None, description="DLP configuration") - content_filter: Optional[ContentFilter] = Field( - default=None, description="Content Filter Description" - ) - prompt_safety: Optional[PromptSafety] = Field( - default=None, description="Prompt Safety Description" - ) - security_filters: Optional[SecurityFilters] = Field( - default=None, description="Security Filters Description" - ) class Model(BaseModel): - name: str = Field(default=None, description="Name of the model") - provider: str = Field(default=None, description="Provider of the model") - suffix: str = Field(default=None, description="Suffix for the model") + name: Optional[str] = Field(default=None, description="Name of the model") + provider: Optional[str] = Field(default=None, description="Provider of the model") + suffix: Optional[str] = Field(default=None, description="Suffix for the model") weight: Optional[int] = Field(default=None, description="Weight of the model") virtual_secret_name: Optional[str] = Field(None, description="Virtual secret name") fallback_enabled: Optional[bool] = Field( @@ -174,19 +187,23 @@ class Model(BaseModel): class Route(BaseModel): - name: str = Field(default=None, description="Name of the route") - type: str = Field( + name: Optional[str] = Field(default=None, description="Name of the route") + type: Optional[str] = Field( default=None, description="Type of the route chat, completion, etc" ) enabled: Optional[bool] = Field( default=True, description="Whether the route is enabled" ) - models: List[Model] = Field(default=[], description="List of models for the route") - config: RouteConfig = Field(default=None, description="Configuration for the route") + models: List[Model] = Field( + default_factory=list, description="List of models for the route" + ) + config: Optional[RouteConfig] = Field( + default=None, description="Configuration for the route" + ) class Routes(BaseModel): - routes: List[Route] = Field(default=[], description="List of routes") + routes: List[Route] = Field(default_factory=list, description="List of routes") class ArrayHandling(str, Enum): @@ -219,10 +236,10 @@ class TransformRule(BaseModel): class ModelSpec(BaseModel): input_rules: List[TransformRule] = Field( - default=[], description="Rules for input transformation" + default_factory=list, description="Rules for input transformation" ) output_rules: List[TransformRule] = Field( - default=[], description="Rules for output transformation" + default_factory=list, description="Rules for output transformation" ) response_body_path: str = Field( default="delta.text", description="Path to extract text from streaming response" @@ -240,7 +257,7 @@ class ModelSpec(BaseModel): default={}, description="Output schema for validation" ) supported_features: List[str] = Field( - default=[], description="List of supported features" + default_factory=list, description="List of supported features" ) max_tokens: Optional[int] = Field( default=None, description="Maximum tokens supported" @@ -254,7 +271,7 @@ class ModelSpec(BaseModel): class ProviderConfig(BaseModel): - api_base: str = Field(default=None, description="Base URL of the API") + api_base: Optional[str] = Field(default=None, description="Base URL of the API") api_type: Optional[str] = Field(default=None, description="Type of the API") api_version: Optional[str] = Field(default=None, description="Version of the API") deployment_name: Optional[str] = Field( @@ -272,26 +289,31 @@ class Config: class Provider(BaseModel): - name: str = Field(default=None, description="Name of the Provider") - type: str = Field(default=None, description="Type of the Provider") + name: Optional[str] = Field(default=None, description="Name of the Provider") + type: Optional[str] = Field(default=None, description="Type of the Provider") enabled: Optional[bool] = Field( default=True, description="Whether the provider is enabled" ) vault_enabled: Optional[bool] = Field( default=True, description="Whether the secrets vault is enabled" ) - config: ProviderConfig = Field( + config: Optional[ProviderConfig] = Field( default=None, description="Configuration for the provider" ) - api_keys: Optional[List[Dict[str, Any]]] = Field(default=None, description='API keys associated with the provider') + api_keys: Optional[List[Dict[str, Any]]] = Field( + default=None, description="API keys associated with the provider" + ) + class Providers(BaseModel): - providers: List[Provider] = Field(default=[], description="List of providers") + providers: List[Provider] = Field( + default_factory=list, description="List of providers" + ) class InfoType(BaseModel): - name: str = Field(default=None, description="Name of the infoType") + name: Optional[str] = Field(default=None, description="Name of the infoType") description: Optional[str] = Field( default=None, description="Description of the InfoType" ) @@ -303,15 +325,15 @@ class InfoType(BaseModel): class Transformation(BaseModel): - method: str = Field( + method: Optional[str] = Field( default=None, description="Method of the transformation Mask, Redact, Replace, etc", ) class TemplateConfig(BaseModel): - infoTypes: Optional[List[InfoType]] = Field( - default=[], description="List of InfoTypes" + infoTypes: List[InfoType] = Field( + default_factory=list, description="List of InfoTypes" ) transformation: Optional[Transformation] = Field( default=None, description="Transformation to be used" @@ -331,43 +353,60 @@ class TemplateConfig(BaseModel): class TemplateModel(BaseModel): - name: str = Field(default=None, description="Name of the model") - provider: str = Field(default=None, description="Provider of the model") - suffix: str = Field(default=None, description="Suffix for the model") + name: Optional[str] = Field(default=None, description="Name of the model") + provider: Optional[str] = Field(default=None, description="Provider of the model") + suffix: Optional[str] = Field(default=None, description="Suffix for the model") class Template(BaseModel): - name: str = Field(default=None, description="Name of the Template") - description: str = Field(default=None, description="Description of the Template") - type: str = Field(default=None, description="Type of the Template") + name: Optional[str] = Field(default=None, description="Name of the Template") + description: Optional[str] = Field( + default=None, description="Description of the Template" + ) + type: Optional[str] = Field(default=None, description="Type of the Template") enabled: Optional[bool] = Field( default=True, description="Whether the template is enabled" ) models: List[TemplateModel] = Field( - default=[], description="List of models for the template" + default_factory=list, description="List of models for the template" ) - config: TemplateConfig = Field( + config: Optional[TemplateConfig] = Field( default=None, description="Configuration for the template" ) class Templates(BaseModel): - templates: List[Template] = Field(default=[], description="List of templates") + templates: List[Template] = Field( + default_factory=list, description="List of templates" + ) + + +class SecretType(str, Enum): + AWS = "aws" + KUBERNETES = "kubernetes" class Secret(BaseModel): - api_key: str = Field(default=None, description="Key of the Secret") - api_key_secret_name: str = Field(default=None, description="Name of the Secret") - api_key_secret_key: str = Field(default=None, description="API Key of the Secret") - api_key_secret_key_javelin: str = Field( + api_key: Optional[str] = Field(default=None, description="Key of the Secret") + api_key_secret_name: Optional[str] = Field( + default=None, description="Name of the Secret" + ) + api_key_secret_key: Optional[str] = Field( + default=None, description="API Key of the Secret" + ) + api_key_secret_key_javelin: Optional[str] = Field( default=None, description="Virtual API Key of the Secret" ) - provider_name: str = Field(default=None, description="Provider Name of the Secret") - query_param_key: str = Field( + provider_name: Optional[str] = Field( + default=None, description="Provider Name of the Secret" + ) + query_param_key: Optional[str] = Field( default=None, description="Query Param Key of the Secret" ) - header_key: str = Field(default=None, description="Header Key of the Secret") - group: str = Field(default=None, description="Group of the Secret") + header_key: Optional[str] = Field( + default=None, description="Header Key of the Secret" + ) + group: Optional[str] = Field(default=None, description="Group of the Secret") enabled: Optional[bool] = Field( default=True, description="Whether the secret is enabled" ) @@ -392,7 +431,7 @@ def masked(self): class Secrets(BaseModel): - secrets: List[Secret] = Field(default=[], description="List of secrets") + secrets: List[Secret] = Field(default_factory=list, description="List of secrets") class Message(BaseModel): @@ -439,6 +478,12 @@ class JavelinConfig(BaseModel): default=None, description="API key for the LLM provider" ) api_version: Optional[str] = Field(default=None, description="API version") + default_headers: Optional[Dict[str, str]] = Field( + default=None, description="Default headers" + ) + timeout: Optional[float] = Field( + default=None, description="Request timeout in seconds" + ) @field_validator("javelin_api_key") @classmethod @@ -478,7 +523,11 @@ def __init__( archive: Optional[str] = "", query_params: Optional[Dict[str, Any]] = None, is_transformation_rules: bool = False, + is_model_specs: bool = False, is_reload: bool = False, + univ_model_config: Optional[Dict[str, Any]] = None, + guardrail: Optional[str] = None, + list_guardrails: bool = False, ): self.method = method self.gateway = gateway @@ -493,12 +542,11 @@ def __init__( self.archive = archive self.query_params = query_params self.is_transformation_rules = is_transformation_rules + self.is_model_specs = is_model_specs self.is_reload = is_reload - - -class Message(BaseModel): - role: str - content: str + self.univ_model_config = univ_model_config + self.guardrail = guardrail + self.list_guardrails = list_guardrails class ChatCompletion(BaseModel): @@ -519,19 +567,18 @@ class ModelConfig(BaseModel): class Config: protected_namespaces = () # This resolves the warning - virtual_secret_key: Optional[str] = Field(default=None, description='Virtual secret name') - fallback_enabled: Optional[bool] = Field(default=None, description='Whether fallback is enabled') - suffix: Optional[str] = Field(default=None, description='Suffix for the model') - weight: Optional[int] = Field(default=None, description='Weight of the model') - fallback_codes: Optional[List[int]] = Field(default=None, description='Fallback codes') + virtual_secret_key: Optional[str] = Field( + default=None, description="Virtual secret name" + ) + fallback_enabled: Optional[bool] = Field( + default=None, description="Whether fallback is enabled" + ) + suffix: Optional[str] = Field(default=None, description="Suffix for the model") + weight: Optional[int] = Field(default=None, description="Weight of the model") + fallback_codes: Optional[List[int]] = Field( + default=None, description="Fallback codes" + ) -class JavelinConfig(BaseModel): - base_url: str = Field(default="https://api-dev.javelin.live") - javelin_api_key: str - javelin_virtualapikey: Optional[str] = None - llm_api_key: Optional[str] = None - api_version: Optional[str] = None - timeout: Optional[float] = None class RemoteModelSpec(BaseModel): provider: str @@ -552,7 +599,7 @@ def to_model_spec(self) -> ModelSpec: class EndpointType(str, Enum): UNKNOWN = "unknown" CHAT = "chat" - COMPLETION = "completion" + COMPLETION = "completion" EMBED = "embed" INVOKE = "invoke" CONVERSE = "converse" @@ -561,17 +608,36 @@ class EndpointType(str, Enum): CONVERSE_STREAM = "converse_stream" ALL = "all" +class UnivModelConfig: + def __init__( + self, + provider_name: str, + endpoint_type: str, + deployment: Optional[str] = None, + arn: Optional[str] = None, + api_version: Optional[str] = None, + model_id: Optional[str] = None, + ): + self.provider_name = provider_name + self.endpoint_type = endpoint_type + self.deployment = deployment + self.arn = arn + self.api_version = api_version + self.model_id = model_id + + +# AISPM models -#aispm models class TimeRange(BaseModel): - start_time: str # Change from datetime to str - end_time: str # Change from datetime to str + start_time: str + end_time: str + class BaseResponse(BaseModel): message: Optional[str] = None -# Customer Models + class Customer(BaseModel): name: str description: Optional[str] @@ -579,19 +645,25 @@ class Customer(BaseModel): security_interval: str = "1m" initial_scan: str = "24h" + class CustomerResponse(Customer): status: str - created_at: datetime + created_at: datetime modified_at: datetime -# Cloud Config Models + class BaseCloudConfig(BaseModel): cloud_account_name: str team: str + class AWSConfig(BaseCloudConfig): - role_arn: str - region: Optional[str] = None # Make region optional + # Support either role-based auth or access-key auth (backend-dependent) + role_arn: Optional[str] = None + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + region: Optional[str] = None + class AzureConfig(BaseCloudConfig): subscription_id: str @@ -600,18 +672,20 @@ class AzureConfig(BaseCloudConfig): client_secret: str location: str + class GCPConfig(BaseCloudConfig): project_id: str service_account_key: str + class CloudConfigResponse(BaseModel): - name: Optional[str] = Field(None, alias='cloud_account_name') + name: Optional[str] = Field(None, alias="cloud_account_name") provider: str status: str created_at: datetime modified_at: datetime -# Usage Models + class ModelMetrics(BaseModel): latency_avg_ms: float cost_per_request: float @@ -624,6 +698,7 @@ class ModelMetrics(BaseModel): request_count: int token_count: int + class CloudAccountUsage(BaseModel): region_count: int regions: List[str] @@ -631,36 +706,41 @@ class CloudAccountUsage(BaseModel): models: List[str] model_metrics: ModelMetrics + class UsageResponse(BaseModel): - cloud_provider: Dict[str, Any] # Change to allow any structure - time_range: Optional[TimeRange] = None + cloud_provider: Dict[str, Any] + time_range: Optional[TimeRange] = None + -# Alert Models class AlertSeverity(str, Enum): CRITICAL = "CRITICAL" HIGH = "HIGH" - MEDIUM = "MEDIUM" + MEDIUM = "MEDIUM" LOW = "LOW" + class AlertState(str, Enum): ALARM = "ALARM" OK = "OK" INSUFFICIENT_DATA = "INSUFFICIENT_DATA" + class AlertScope(str, Enum): GLOBAL = "GLOBAL" MODEL = "MODEL" REGION = "REGION" + class AlertMetrics(BaseModel): total_alerts: int active_alerts: int resolved_alerts: int - critical_alerts: int + critical_alerts: int high_alerts: int medium_alerts: int low_alerts: int + class Alert(BaseModel): title: str state: AlertState @@ -671,6 +751,7 @@ class Alert(BaseModel): model_id: Optional[str] detected_at: datetime + class CloudProviderAlerts(BaseModel): cloud_account_count: int cloud_accounts: List[str] @@ -681,6 +762,7 @@ class CloudProviderAlerts(BaseModel): alert_metrics: AlertMetrics alerts: List[Alert] + class AlertResponse(BaseModel): cloud_provider: Dict[str, CloudProviderAlerts] time_range: TimeRange From 15d8d32ca4583e5978c9ba57488a7aa7f5fdbc3e Mon Sep 17 00:00:00 2001 From: Darshana Date: Mon, 22 Dec 2025 17:31:46 +0530 Subject: [PATCH 11/11] fix: build and lint errors --- .github/workflows/pr-check.yml | 16 +- javelin_cli/_internal/commands.py | 15 +- javelin_cli/cli.py | 1 - javelin_sdk/client.py | 9 +- javelin_sdk/services/aispm_service.py | 474 ++++++++++++++------------ pyproject.toml | 3 +- 6 files changed, 267 insertions(+), 251 deletions(-) diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 4f06d1d..4130e94 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -162,21 +162,21 @@ jobs: with: persist-credentials: false - - name: Setting up the Package Version + - name: Validate Package Version env: PY_VER_FILE: "pyproject.toml" - RELEASE_NAME: "v1.1.1" shell: bash run: |- - export RELEASE_VERSION=$(echo ${{ env.RELEASE_NAME }} | sed 's|^v||g') - if cat ${{ env.PY_VER_FILE }} | grep 'version = "RELEASE_VERSION"' ; then - sed -i "s|^version = \"RELEASE_VERSION\"|version = \"${RELEASE_VERSION}\"|g" ${{ env.PY_VER_FILE }} - cat ${file} - else + if ! grep -q '^version = ' ${{ env.PY_VER_FILE }} ; then + echo "Version entry not found in the ${{ env.PY_VER_FILE }} file...!" + exit 1 + fi + VERSION=$(grep '^version = ' ${{ env.PY_VER_FILE }} | sed 's/.*version = "\(.*\)".*/\1/') + if [ -z "$VERSION" ]; then echo "Version entry format is wrong in the ${{ env.PY_VER_FILE }} file...!" - cat ${file} exit 1 fi + echo "Package version: $VERSION" - name: Setup Python uses: actions/setup-python@v5 diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index b97c464..d7fbf1d 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -9,7 +9,6 @@ ) from javelin_sdk.models import ( AWSConfig, - AlertResponse, Gateway, GatewayConfig, JavelinConfig, @@ -24,7 +23,6 @@ Template, TemplateConfig, AzureConfig, - UsageResponse, ) from pydantic import ValidationError @@ -72,7 +70,8 @@ def get_javelin_client_aispm(): role_arn = selected_gateway.get("role_arn") - # Extract account_id from role ARN if still not found (format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME) + # Extract account_id from role ARN if still not found + # Format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME if role_arn and not account_id: try: parts = role_arn.split(":") @@ -98,8 +97,6 @@ def get_javelin_client_aispm(): client._aispm_userrole = "org:superadmin" return client -) -from pydantic import ValidationError def get_javelin_client(): @@ -204,8 +201,8 @@ def configure_aws(args): client = get_javelin_client_aispm() config = json.loads(args.config) configs = [AWSConfig(**config)] - result = client.aispm.configure_aws(configs) - print(f"AWS configuration created successfully.") + client.aispm.configure_aws(configs) + print("AWS configuration created successfully.") except Exception as e: print(f"Error configuring AWS: {e}") @@ -257,8 +254,8 @@ def configure_azure(args): client = get_javelin_client_aispm() config = json.loads(args.config) configs = [AzureConfig(**config)] - result = client.aispm.configure_azure(configs) - print(f"Azure configuration created successfully.") + client.aispm.configure_azure(configs) + print("Azure configuration created successfully.") except Exception as e: print(f"Error configuring Azure: {e}") diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index bf8eeb0..f5b984f 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -49,7 +49,6 @@ ) - def check_permissions(): """Check if user has permissions""" home_dir = Path.home() diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index 4d5e4ce..fba84de 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -115,7 +115,6 @@ def __init__(self, config: JavelinConfig) -> None: self.aispm = AISPMService(self) - @property def client(self): if self._client is None: @@ -966,7 +965,7 @@ def _prepare_request(self, request: Request) -> tuple: ) headers = {**self._headers, **(request.headers or {})} - + # For AISPM requests: if account-id auth is used, do not send API key. if ( request.route @@ -974,15 +973,13 @@ def _prepare_request(self, request: Request) -> tuple: and "x-javelin-accountid" in headers ): headers.pop("x-javelin-apikey", None) - + return url, headers def _send_request_sync(self, request: Request) -> httpx.Response: - response = self._core_send_request(self.client, request) + response = self._core_send_request(self.client, request) return response - - async def _send_request_async(self, request: Request) -> httpx.Response: return await self._core_send_request(self.aclient, request) diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py index 05ed670..7cbf705 100644 --- a/javelin_sdk/services/aispm_service.py +++ b/javelin_sdk/services/aispm_service.py @@ -1,232 +1,256 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from httpx import Response import json from javelin_sdk.models import ( - Customer, CustomerResponse, - AWSConfig, AzureConfig, GCPConfig, CloudConfigResponse, - UsageResponse, AlertResponse, TimeRange, - HttpMethod, Request + Customer, + CustomerResponse, + AWSConfig, + AzureConfig, + GCPConfig, + CloudConfigResponse, + UsageResponse, + AlertResponse, + HttpMethod, + Request, ) + class AISPMService: - def __init__(self, client): - self.client = client - - def _get_aispm_headers(self) -> Dict[str, str]: - """Get headers for AISPM requests, including account_id if available.""" - headers = {} - # Check if account_id is stored in client (set by get_javelin_client_aispm) - account_id = getattr(self.client, '_aispm_account_id', None) - if account_id: - headers["x-javelin-accountid"] = account_id - headers["x-javelin-user"] = getattr(self.client, '_aispm_user', "test-user") - headers["x-javelin-userrole"] = getattr(self.client, '_aispm_userrole', "org:superadmin") - return headers - - def _handle_response(self, response: Response) -> None: - if response.status_code >= 400: - try: - error_data = response.json() - # Handle different error response formats - error = error_data.get("error") or error_data.get("message") or str(error_data) - except: - error = f"HTTP {response.status_code}: {response.text}" - raise Exception(f"API error: {error}") - - # Customer Methods - def create_customer(self, customer: Customer) -> CustomerResponse: - request = Request( - method=HttpMethod.POST, - route="v1/admin/aispm/customer", - data=customer.dict(), - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return CustomerResponse(**response.json()) - - - - def get_customer(self) -> CustomerResponse: - request = Request( - method=HttpMethod.GET, - route="v1/admin/aispm/customer", - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - response_data = response.json() - # Check if response indicates failure (even with 200 status) - if isinstance(response_data, dict) and response_data.get("success") is False: - error_msg = response_data.get("message") or response_data.get("error") or "Request failed" - raise Exception(f"API error: {error_msg}") - return CustomerResponse(**response_data) - - def update_customer(self, customer: Customer) -> CustomerResponse: - request = Request( - method=HttpMethod.PUT, - route="v1/admin/aispm/customer", - data=customer.dict() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return CustomerResponse(**response.json()) - - # Cloud Config Methods - def configure_aws(self, configs: List[AWSConfig]) -> List[CloudConfigResponse]: - request = Request( - method=HttpMethod.POST, - route="v1/admin/aispm/config/aws", - data=[config.dict() for config in configs], - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return [CloudConfigResponse(**config) for config in response.json()] - - def configure_azure(self, configs: List[AzureConfig]) -> List[CloudConfigResponse]: - request = Request( - method=HttpMethod.POST, - route="v1/admin/aispm/config/azure", - data=[config.dict() for config in configs], - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return [CloudConfigResponse(**config) for config in response.json()] - - - - - def get_aws_configs(self) -> Dict: - """ - Retrieves AWS configurations. - """ - request = Request( - method=HttpMethod.GET, - route="v1/admin/aispm/config/aws", - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return response.json() - - def configure_gcp(self, configs: List[GCPConfig]) -> List[CloudConfigResponse]: - request = Request( - method=HttpMethod.POST, - route="v1/admin/aispm/config/gcp", - data=[config.dict() for config in configs], - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return [CloudConfigResponse(**config) for config in response.json()] - - # Usage Methods - def get_usage(self, - provider: Optional[str] = None, - cloud_account: Optional[str] = None, - model: Optional[str] = None, - region: Optional[str] = None) -> UsageResponse: - - route = "v1/admin/aispm/usage" - if provider: - route += f"/{provider}" - if cloud_account: - route += f"/{cloud_account}" - - params = {} - if model: - params["model"] = model - if region: - params["region"] = region - - request = Request( - method=HttpMethod.GET, - route=route, - query_params=params, - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return UsageResponse(**response.json()) - - # Alert Methods - def get_alerts(self, - provider: Optional[str] = None, - cloud_account: Optional[str] = None, - model: Optional[str] = None, - region: Optional[str] = None) -> AlertResponse: - - route = "v1/admin/aispm/alerts" - if provider: - route += f"/{provider}" - if cloud_account: - route += f"/{cloud_account}" - - params = {} - if model: - params["model"] = model - if region: - params["region"] = region - - request = Request( - method=HttpMethod.GET, - route=route, - query_params=params, - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return AlertResponse(**response.json()) - - # Helpers - def _validate_provider(self, provider: str) -> None: - valid_providers = ["aws", "azure", "gcp", "openai"] - if provider.lower() not in valid_providers: - raise ValueError(f"Invalid provider. Must be one of: {valid_providers}") - - def _construct_error(self, response: Response) -> Dict: - try: - error = response.json() - return error.get("error", str(response.content)) - except json.JSONDecodeError: - return str(response.content) - - def delete_aws_config(self, name: str) -> None: - """ - Deletes an AWS configuration by name. - - Args: - name (str): The name of the AWS configuration to delete - - Raises: - Exception: If the API request fails - """ - request = Request( - method=HttpMethod.DELETE, - route=f"v1/admin/aispm/config/aws/{name}", - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - - def get_azure_config(self) -> Dict: - """ - Retrieves Azure configurations. - - Returns: - Dict: The Azure configuration data - - Raises: - Exception: If the API request fails - """ - request = Request( - method=HttpMethod.GET, - route="v1/admin/aispm/config/azure", - headers=self._get_aispm_headers() - ) - response = self.client._send_request_sync(request) - self._handle_response(response) - return response.json() \ No newline at end of file + def __init__(self, client): + self.client = client + + def _get_aispm_headers(self) -> Dict[str, str]: + """Get headers for AISPM requests, including account_id if available.""" + headers = {} + # Check if account_id is stored in client (set by get_javelin_client_aispm) + account_id = getattr(self.client, "_aispm_account_id", None) + if account_id: + headers["x-javelin-accountid"] = account_id + headers["x-javelin-user"] = getattr( + self.client, "_aispm_user", "test-user" + ) + headers["x-javelin-userrole"] = getattr( + self.client, "_aispm_userrole", "org:superadmin" + ) + return headers + + def _handle_response(self, response: Response) -> None: + if response.status_code >= 400: + try: + error_data = response.json() + # Handle different error response formats + error = ( + error_data.get("error") + or error_data.get("message") + or str(error_data) + ) + except Exception: + error = f"HTTP {response.status_code}: {response.text}" + raise Exception(f"API error: {error}") + + # Customer Methods + def create_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/customer", + data=customer.dict(), + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + def get_customer(self) -> CustomerResponse: + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/customer", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + response_data = response.json() + # Check if response indicates failure (even with 200 status) + if isinstance(response_data, dict) and response_data.get("success") is False: + error_msg = ( + response_data.get("message") + or response_data.get("error") + or "Request failed" + ) + raise Exception(f"API error: {error_msg}") + return CustomerResponse(**response_data) + + def update_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.PUT, + route="v1/admin/aispm/customer", + data=customer.dict(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + # Cloud Config Methods + def configure_aws( + self, configs: List[AWSConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/aws", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def configure_azure( + self, configs: List[AzureConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/azure", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def get_aws_configs(self) -> Dict: + """ + Retrieves AWS configurations. + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/aws", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() + + def configure_gcp( + self, configs: List[GCPConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/gcp", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + # Usage Methods + def get_usage( + self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None, + ) -> UsageResponse: + route = "v1/admin/aispm/usage" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params, + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return UsageResponse(**response.json()) + + # Alert Methods + def get_alerts( + self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None, + ) -> AlertResponse: + route = "v1/admin/aispm/alerts" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params, + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return AlertResponse(**response.json()) + + # Helpers + def _validate_provider(self, provider: str) -> None: + valid_providers = ["aws", "azure", "gcp", "openai"] + if provider.lower() not in valid_providers: + raise ValueError( + f"Invalid provider. Must be one of: {valid_providers}" + ) + + def _construct_error(self, response: Response) -> Dict: + try: + error = response.json() + return error.get("error", str(response.content)) + except json.JSONDecodeError: + return str(response.content) + + def delete_aws_config(self, name: str) -> None: + """ + Deletes an AWS configuration by name. + + Args: + name (str): The name of the AWS configuration to delete + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.DELETE, + route=f"v1/admin/aispm/config/aws/{name}", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + + def get_azure_config(self) -> Dict: + """ + Retrieves Azure configurations. + + Returns: + Dict: The Azure configuration data + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/azure", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() diff --git a/pyproject.toml b/pyproject.toml index 13c5ce0..bd7c7f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,12 +5,11 @@ description = "Python client for Javelin" authors = ["Sharath Rajasekar "] readme = "README.md" license = "Apache-2.0" -homepage = "https://dev.highflame.dev/sign-in" +homepage = "https://getjavelin.com" packages = [ { include = "javelin_cli" }, { include = "javelin_sdk" }, ] -homepage = "https://getjavelin.com" [tool.poetry.scripts] javelin = "javelin_cli.cli:main"