diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 4f06d1d..4130e94 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -162,21 +162,21 @@ jobs: with: persist-credentials: false - - name: Setting up the Package Version + - name: Validate Package Version env: PY_VER_FILE: "pyproject.toml" - RELEASE_NAME: "v1.1.1" shell: bash run: |- - export RELEASE_VERSION=$(echo ${{ env.RELEASE_NAME }} | sed 's|^v||g') - if cat ${{ env.PY_VER_FILE }} | grep 'version = "RELEASE_VERSION"' ; then - sed -i "s|^version = \"RELEASE_VERSION\"|version = \"${RELEASE_VERSION}\"|g" ${{ env.PY_VER_FILE }} - cat ${file} - else + if ! grep -q '^version = ' ${{ env.PY_VER_FILE }} ; then + echo "Version entry not found in the ${{ env.PY_VER_FILE }} file...!" + exit 1 + fi + VERSION=$(grep '^version = ' ${{ env.PY_VER_FILE }} | sed 's/.*version = "\(.*\)".*/\1/') + if [ -z "$VERSION" ]; then echo "Version entry format is wrong in the ${{ env.PY_VER_FILE }} file...!" - cat ${file} exit 1 fi + echo "Package version: $VERSION" - name: Setup Python uses: actions/setup-python@v5 diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 22129c3..d7fbf1d 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -8,9 +8,11 @@ UnauthorizedError, ) from javelin_sdk.models import ( + AWSConfig, Gateway, GatewayConfig, JavelinConfig, + Customer, Model, Provider, ProviderConfig, @@ -20,10 +22,83 @@ Secrets, Template, TemplateConfig, + AzureConfig, ) from pydantic import ValidationError +def get_javelin_client_aispm(): + # Path to cache.json file + home_dir = Path.home() + json_file_path = home_dir / ".javelin" / "cache.json" + + # Load cache.json + if not json_file_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {json_file_path}") + + with open(json_file_path, "r") as json_file: + cache_data = json.load(json_file) + + # Retrieve the list of gateways + gateways = ( + cache_data.get("memberships", {}) + .get("data", [{}])[0] + .get("organization", {}) + .get("public_metadata", {}) + .get("Gateways", []) + ) + if not gateways: + raise ValueError("No gateways found in the configuration.") + + # Automatically select the first gateway (index 0) + selected_gateway = gateways[0] + base_url = selected_gateway["base_url"] + + # Get organization metadata (where account_id might be stored) + organization = ( + cache_data.get("memberships", {}).get("data", [{}])[0].get("organization", {}) + ) + org_metadata = organization.get("public_metadata", {}) + + # Get account_id from multiple possible locations (in order of preference): + # 1. Gateway's account_id field + # 2. Organization's public_metadata account_id + # 3. Extract from role_arn if provided + account_id = selected_gateway.get("account_id") + if not account_id: + account_id = org_metadata.get("account_id") + + role_arn = selected_gateway.get("role_arn") + + # Extract account_id from role ARN if still not found + # Format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + if role_arn and not account_id: + try: + parts = role_arn.split(":") + if len(parts) >= 5 and parts[2] == "iam": + account_id = parts[4] + except (IndexError, AttributeError): + pass + + javelin_api_key = selected_gateway.get("api_key_value", "placeholder") + + # Initialize and return the JavelinClient + config = JavelinConfig( + base_url=base_url, + javelin_api_key=javelin_api_key, + ) + + client = JavelinClient(config) + + # Store account_id in client for AISPM service to use + if account_id: + client._aispm_account_id = account_id + client._aispm_user = "test-user" + client._aispm_userrole = "org:superadmin" + + return client + + def get_javelin_client(): # Path to cache.json file home_dir = Path.home() @@ -86,6 +161,133 @@ def get_javelin_client(): return JavelinClient(config) +def create_customer(args): + client = get_javelin_client_aispm() + customer = Customer( + name=args.name, + description=args.description, + metrics_interval=args.metrics_interval, + security_interval=args.security_interval, + ) + return client.aispm.create_customer(customer) + + +def get_customer(args): + """ + Gets customer details using the AISPM service. + """ + try: + client = get_javelin_client_aispm() + response = client.aispm.get_customer() + + # Pretty print the response for CLI output + formatted_response = { + "name": response.name, + "description": response.description, + "metrics_interval": response.metrics_interval, + "security_interval": response.security_interval, + "status": response.status, + "created_at": response.created_at.isoformat(), + "modified_at": response.modified_at.isoformat(), + } + + print(json.dumps(formatted_response, indent=2)) + except Exception as e: + print(f"Error getting customer: {e}") + + +def configure_aws(args): + try: + client = get_javelin_client_aispm() + config = json.loads(args.config) + configs = [AWSConfig(**config)] + client.aispm.configure_aws(configs) + print("AWS configuration created successfully.") + except Exception as e: + print(f"Error configuring AWS: {e}") + + +def get_aws_config(args): + """ + Gets AWS configurations using the AISPM service. + """ + try: + client = get_javelin_client_aispm() + response = client.aispm.get_aws_configs() + # Simply print the JSON response + print(json.dumps(response, indent=2)) + + except Exception as e: + print(f"Error getting AWS configurations: {e}") + + +# Add these functions to commands.py + + +def delete_aws_config(args): + """ + Deletes an AWS configuration. + """ + try: + client = get_javelin_client_aispm() + client.aispm.delete_aws_config(args.name) + print(f"AWS configuration '{args.name}' deleted successfully.") + except Exception as e: + print(f"Error deleting AWS config: {e}") + + +def get_azure_config(args): + """ + Gets Azure configurations using the AISPM service. + """ + try: + client = get_javelin_client_aispm() + response = client.aispm.get_azure_config() + # Format and print the response nicely + print(json.dumps(response, indent=2)) + except Exception as e: + print(f"Error getting Azure config: {e}") + + +def configure_azure(args): + try: + client = get_javelin_client_aispm() + config = json.loads(args.config) + configs = [AzureConfig(**config)] + client.aispm.configure_azure(configs) + print("Azure configuration created successfully.") + except Exception as e: + print(f"Error configuring Azure: {e}") + + +def get_usage(args): + try: + client = get_javelin_client_aispm() + usage = client.aispm.get_usage( + provider=args.provider, + cloud_account=args.account, + model=args.model, + region=args.region, + ) + print(json.dumps(usage.dict(), indent=2)) + except Exception as e: + print(f"Error getting usage: {e}") + + +def get_alerts(args): + try: + client = get_javelin_client_aispm() + alerts = client.aispm.get_alerts( + provider=args.provider, + cloud_account=args.account, + model=args.model, + region=args.region, + ) + print(json.dumps(alerts.dict(), indent=2)) + except Exception as e: + print(f"Error getting alerts: {e}") + + def create_gateway(args): try: client = get_javelin_client() diff --git a/javelin_cli/cli.py b/javelin_cli/cli.py index 1a30311..f5b984f 100644 --- a/javelin_cli/cli.py +++ b/javelin_cli/cli.py @@ -37,11 +37,20 @@ update_route, update_secret, update_template, + create_customer, + get_customer, + configure_aws, + configure_azure, + get_usage, + get_alerts, + get_aws_config, + get_azure_config, + delete_aws_config ) def check_permissions(): - """Check if user has superadmin permissions""" + """Check if user has permissions""" home_dir = Path.home() cache_file = home_dir / ".javelin" / "cache.json" @@ -95,6 +104,83 @@ def main(): help="Force re-authentication, overriding existing credentials", ) auth_parser.set_defaults(func=authenticate) + + # AISPM commands + aispm_parser = subparsers.add_parser("aispm", help="Manage AISPM functionality") + aispm_subparsers = aispm_parser.add_subparsers() + + # Customer commands + customer_parser = aispm_subparsers.add_parser("customer", help="Manage customers") + customer_subparsers = customer_parser.add_subparsers() + + customer_create = customer_subparsers.add_parser("create", help="Create customer") + customer_create.add_argument("--name", required=True, help="Customer name") + customer_create.add_argument("--description", help="Customer description") + customer_create.add_argument( + "--metrics-interval", default="5m", help="Metrics interval" + ) + customer_create.add_argument( + "--security-interval", default="1m", help="Security interval" + ) + customer_create.set_defaults(func=create_customer) + + customer_get = customer_subparsers.add_parser("get", help="Get customer details") + customer_get.set_defaults(func=get_customer) + + # Cloud config commands + config_parser = aispm_subparsers.add_parser( + "config", help="Manage cloud configurations" + ) + config_subparsers = config_parser.add_subparsers() + + aws_parser = config_subparsers.add_parser("aws", help="Configure AWS") + azure_parser = config_subparsers.add_parser("azure", help="Configure Azure") + + aws_subparsers = aws_parser.add_subparsers() + aws_get_parser = aws_subparsers.add_parser("get", help="Get AWS configuration") + aws_get_parser.set_defaults(func=get_aws_config) + + aws_create_parser = aws_subparsers.add_parser("create", help="Configure AWS") + aws_create_parser.add_argument( + "--config", type=str, required=True, help="AWS config JSON" + ) + aws_create_parser.set_defaults(func=configure_aws) + + aws_delete_parser = aws_subparsers.add_parser( + "delete", help="Delete AWS configuration" + ) + aws_delete_parser.add_argument( + "--name", type=str, required=True, help="Name of AWS configuration to delete" + ) + aws_delete_parser.set_defaults(func=delete_aws_config) + + azure_subparsers = azure_parser.add_subparsers(dest="azure_command") + azure_get_parser = azure_subparsers.add_parser( + "get", help="Get Azure configuration" + ) + azure_get_parser.set_defaults(func=get_azure_config) + + azure_create_parser = azure_subparsers.add_parser("create", help="Configure Azure") + azure_create_parser.add_argument( + "--config", type=str, required=True, help="Azure config JSON" + ) + azure_create_parser.set_defaults(func=configure_azure) + + # Usage metrics + usage_parser = aispm_subparsers.add_parser("usage", help="Get usage metrics") + usage_parser.add_argument("--provider", help="Cloud provider") + usage_parser.add_argument("--account", help="Cloud account name") + usage_parser.add_argument("--model", help="Model ID") + usage_parser.add_argument("--region", help="Region") + usage_parser.set_defaults(func=get_usage) + + # Alerts + alerts_parser = aispm_subparsers.add_parser("alerts", help="Get alerts") + alerts_parser.add_argument("--provider", help="Cloud provider") + alerts_parser.add_argument("--account", help="Cloud account name") + alerts_parser.add_argument("--model", help="Model ID") + alerts_parser.add_argument("--region", help="Region") + alerts_parser.set_defaults(func=get_alerts) # Gateway CRUD gateway_parser = subparsers.add_parser( "gateway", @@ -415,8 +501,7 @@ def authenticate(args): print("Use --force to re-authenticate and override existing cache.") return - default_url = "https://dev.javelin.live/" - + default_url = "https://dev.highflame.dev/" print(" O") print(" /|\\") print(" / \\ ========> Welcome to Javelin! 🚀") diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index c6d9315..fba84de 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -19,6 +19,7 @@ from javelin_sdk.services.secret_service import SecretService from javelin_sdk.services.template_service import TemplateService from javelin_sdk.services.trace_service import TraceService +from javelin_sdk.services.aispm_service import AISPMService from javelin_sdk.services.guardrails_service import GuardrailsService from javelin_sdk.tracing_setup import configure_span_exporter @@ -112,9 +113,13 @@ def __init__(self, config: JavelinConfig) -> None: self.original_methods = {} + self.aispm = AISPMService(self) + @property def client(self): if self._client is None: + # Don't set headers at client level - they'll be added per-request + # This allows us to exclude x-api-key for AISPM requests self._client = httpx.Client( base_url=self.base_url, headers=self._headers, @@ -125,8 +130,9 @@ def client(self): @property def aclient(self): if self._aclient is None: + # Don't set headers at client level - they'll be added per-request self._aclient = httpx.AsyncClient( - base_url=self.base_url, headers=self._headers, timeout=API_TIMEOUT + base_url=self.base_url, timeout=API_TIMEOUT ) return self._aclient @@ -957,11 +963,22 @@ def _prepare_request(self, request: Request) -> tuple: guardrail=request.guardrail, list_guardrails=request.list_guardrails, ) + headers = {**self._headers, **(request.headers or {})} + + # For AISPM requests: if account-id auth is used, do not send API key. + if ( + request.route + and request.route.startswith("v1/admin/aispm") + and "x-javelin-accountid" in headers + ): + headers.pop("x-javelin-apikey", None) + return url, headers def _send_request_sync(self, request: Request) -> httpx.Response: - return self._core_send_request(self.client, request) + response = self._core_send_request(self.client, request) + return response async def _send_request_async(self, request: Request) -> httpx.Response: return await self._core_send_request(self.aclient, request) @@ -999,6 +1016,14 @@ def _construct_url( guardrail: Optional[str] = None, list_guardrails: bool = False, ) -> str: + # Handle AISPM routes: they use the route directly with base_url + if route_name and route_name.startswith("v1/admin/aispm"): + url = f"{self.config.base_url.rstrip('/')}/{route_name}" + if query_params: + query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) + url += f"?{query_string}" + return url + url_parts = [self.base_url] # Determine the main URL path based on the primary resource type diff --git a/javelin_sdk/models.py b/javelin_sdk/models.py index 18bec8f..a070d6c 100644 --- a/javelin_sdk/models.py +++ b/javelin_sdk/models.py @@ -1,3 +1,4 @@ +from datetime import datetime from enum import Enum, auto from typing import Any, Dict, List, Optional @@ -624,3 +625,145 @@ def __init__( self.arn = arn self.api_version = api_version self.model_id = model_id + + +# AISPM models + + +class TimeRange(BaseModel): + start_time: str + end_time: str + + +class BaseResponse(BaseModel): + message: Optional[str] = None + + +class Customer(BaseModel): + name: str + description: Optional[str] + metrics_interval: str = "5m" + security_interval: str = "1m" + initial_scan: str = "24h" + + +class CustomerResponse(Customer): + status: str + created_at: datetime + modified_at: datetime + + +class BaseCloudConfig(BaseModel): + cloud_account_name: str + team: str + + +class AWSConfig(BaseCloudConfig): + # Support either role-based auth or access-key auth (backend-dependent) + role_arn: Optional[str] = None + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + region: Optional[str] = None + + +class AzureConfig(BaseCloudConfig): + subscription_id: str + tenant_id: str + client_id: str + client_secret: str + location: str + + +class GCPConfig(BaseCloudConfig): + project_id: str + service_account_key: str + + +class CloudConfigResponse(BaseModel): + name: Optional[str] = Field(None, alias="cloud_account_name") + provider: str + status: str + created_at: datetime + modified_at: datetime + + +class ModelMetrics(BaseModel): + latency_avg_ms: float + cost_per_request: float + tokens_per_request: float + attempt_count: int + failure_count: int + success_count: int + success_rate_pct: float + cost_total: float + request_count: int + token_count: int + + +class CloudAccountUsage(BaseModel): + region_count: int + regions: List[str] + model_count: int + models: List[str] + model_metrics: ModelMetrics + + +class UsageResponse(BaseModel): + cloud_provider: Dict[str, Any] + time_range: Optional[TimeRange] = None + + +class AlertSeverity(str, Enum): + CRITICAL = "CRITICAL" + HIGH = "HIGH" + MEDIUM = "MEDIUM" + LOW = "LOW" + + +class AlertState(str, Enum): + ALARM = "ALARM" + OK = "OK" + INSUFFICIENT_DATA = "INSUFFICIENT_DATA" + + +class AlertScope(str, Enum): + GLOBAL = "GLOBAL" + MODEL = "MODEL" + REGION = "REGION" + + +class AlertMetrics(BaseModel): + total_alerts: int + active_alerts: int + resolved_alerts: int + critical_alerts: int + high_alerts: int + medium_alerts: int + low_alerts: int + + +class Alert(BaseModel): + title: str + state: AlertState + state_reason: str + severity: AlertSeverity + scope: AlertScope + region: Optional[str] + model_id: Optional[str] + detected_at: datetime + + +class CloudProviderAlerts(BaseModel): + cloud_account_count: int + cloud_accounts: List[str] + region_count: int + regions: List[str] + model_count: int + models: List[str] + alert_metrics: AlertMetrics + alerts: List[Alert] + + +class AlertResponse(BaseModel): + cloud_provider: Dict[str, CloudProviderAlerts] + time_range: TimeRange diff --git a/javelin_sdk/services/aispm_service.py b/javelin_sdk/services/aispm_service.py new file mode 100644 index 0000000..7cbf705 --- /dev/null +++ b/javelin_sdk/services/aispm_service.py @@ -0,0 +1,256 @@ +from typing import Dict, List, Optional +from httpx import Response +import json + +from javelin_sdk.models import ( + Customer, + CustomerResponse, + AWSConfig, + AzureConfig, + GCPConfig, + CloudConfigResponse, + UsageResponse, + AlertResponse, + HttpMethod, + Request, +) + + +class AISPMService: + def __init__(self, client): + self.client = client + + def _get_aispm_headers(self) -> Dict[str, str]: + """Get headers for AISPM requests, including account_id if available.""" + headers = {} + # Check if account_id is stored in client (set by get_javelin_client_aispm) + account_id = getattr(self.client, "_aispm_account_id", None) + if account_id: + headers["x-javelin-accountid"] = account_id + headers["x-javelin-user"] = getattr( + self.client, "_aispm_user", "test-user" + ) + headers["x-javelin-userrole"] = getattr( + self.client, "_aispm_userrole", "org:superadmin" + ) + return headers + + def _handle_response(self, response: Response) -> None: + if response.status_code >= 400: + try: + error_data = response.json() + # Handle different error response formats + error = ( + error_data.get("error") + or error_data.get("message") + or str(error_data) + ) + except Exception: + error = f"HTTP {response.status_code}: {response.text}" + raise Exception(f"API error: {error}") + + # Customer Methods + def create_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/customer", + data=customer.dict(), + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + def get_customer(self) -> CustomerResponse: + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/customer", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + response_data = response.json() + # Check if response indicates failure (even with 200 status) + if isinstance(response_data, dict) and response_data.get("success") is False: + error_msg = ( + response_data.get("message") + or response_data.get("error") + or "Request failed" + ) + raise Exception(f"API error: {error_msg}") + return CustomerResponse(**response_data) + + def update_customer(self, customer: Customer) -> CustomerResponse: + request = Request( + method=HttpMethod.PUT, + route="v1/admin/aispm/customer", + data=customer.dict(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return CustomerResponse(**response.json()) + + # Cloud Config Methods + def configure_aws( + self, configs: List[AWSConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/aws", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def configure_azure( + self, configs: List[AzureConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/azure", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + def get_aws_configs(self) -> Dict: + """ + Retrieves AWS configurations. + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/aws", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() + + def configure_gcp( + self, configs: List[GCPConfig] + ) -> List[CloudConfigResponse]: + request = Request( + method=HttpMethod.POST, + route="v1/admin/aispm/config/gcp", + data=[config.dict() for config in configs], + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return [CloudConfigResponse(**config) for config in response.json()] + + # Usage Methods + def get_usage( + self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None, + ) -> UsageResponse: + route = "v1/admin/aispm/usage" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params, + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return UsageResponse(**response.json()) + + # Alert Methods + def get_alerts( + self, + provider: Optional[str] = None, + cloud_account: Optional[str] = None, + model: Optional[str] = None, + region: Optional[str] = None, + ) -> AlertResponse: + route = "v1/admin/aispm/alerts" + if provider: + route += f"/{provider}" + if cloud_account: + route += f"/{cloud_account}" + + params = {} + if model: + params["model"] = model + if region: + params["region"] = region + + request = Request( + method=HttpMethod.GET, + route=route, + query_params=params, + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return AlertResponse(**response.json()) + + # Helpers + def _validate_provider(self, provider: str) -> None: + valid_providers = ["aws", "azure", "gcp", "openai"] + if provider.lower() not in valid_providers: + raise ValueError( + f"Invalid provider. Must be one of: {valid_providers}" + ) + + def _construct_error(self, response: Response) -> Dict: + try: + error = response.json() + return error.get("error", str(response.content)) + except json.JSONDecodeError: + return str(response.content) + + def delete_aws_config(self, name: str) -> None: + """ + Deletes an AWS configuration by name. + + Args: + name (str): The name of the AWS configuration to delete + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.DELETE, + route=f"v1/admin/aispm/config/aws/{name}", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + + def get_azure_config(self) -> Dict: + """ + Retrieves Azure configurations. + + Returns: + Dict: The Azure configuration data + + Raises: + Exception: If the API request fails + """ + request = Request( + method=HttpMethod.GET, + route="v1/admin/aispm/config/azure", + headers=self._get_aispm_headers(), + ) + response = self.client._send_request_sync(request) + self._handle_response(response) + return response.json() diff --git a/pyproject.toml b/pyproject.toml index 223e5c9..bd7c7f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,15 @@ [tool.poetry] name = "javelin-sdk" -version = "RELEASE_VERSION" +version = "0.2.7-dev" description = "Python client for Javelin" authors = ["Sharath Rajasekar "] readme = "README.md" license = "Apache-2.0" +homepage = "https://getjavelin.com" packages = [ { include = "javelin_cli" }, { include = "javelin_sdk" }, ] -homepage = "https://getjavelin.com" [tool.poetry.scripts] javelin = "javelin_cli.cli:main"