diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 488c713d..37a12bed 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -44,6 +44,11 @@ jobs:
       - name: Run tests
         run: make test
 
+      - name: Run LLM integration tests
+        env:
+          PRIVATE_KEY: ${{ secrets.PRIVATE_KEY }}
+        run: make llm_integrationtest
+
   release:
     needs: [check, test]
     runs-on: ubuntu-latest
diff --git a/README.md b/README.md
index b93fcefb..bf990bc2 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ OpenGradient enables developers to build AI applications with verifiable executi
 
 pip install opengradient
 ```
 
-**Note**: Windows users should temporarily enable WSL during installation (fix in progress).
+> **Note for Windows users:** See the [Windows Installation Guide](./WINDOWS_INSTALL.md) for step-by-step setup instructions.
 
 ## Network Architecture
diff --git a/WINDOWS_INSTALL.md b/WINDOWS_INSTALL.md
new file mode 100644
index 00000000..ac57d32f
--- /dev/null
+++ b/WINDOWS_INSTALL.md
@@ -0,0 +1,36 @@
+# Windows Installation Guide
+
+The `opengradient` package requires a C compiler
+to build its native dependencies. Windows does not
+have one by default.
+
+## Step 1 — Enable WSL
+
+Open PowerShell as Administrator and run:
+
+    wsl --install
+
+Restart your PC when prompted.
+ +## Step 2 — Install Python and uv inside WSL + +Open the Ubuntu app and run: + + sudo apt update && sudo apt install -y python3 curl + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env + +## Step 3 — Install SDK + + uv add opengradient + +## Step 4 — Verify + + uv run python3 -c "import opengradient; print('Ready!')" + +## Common Errors + +- Visual C++ 14.0 required → Use WSL instead +- wsl: command not found → Update Windows 10 to Build 19041+ +- WSL stuck → Enable Virtualization in BIOS +- uv: command not found → Run: source $HOME/.local/bin/env diff --git a/docs/opengradient/client/llm.md b/docs/opengradient/client/llm.md index 65aed9f0..40615b73 100644 --- a/docs/opengradient/client/llm.md +++ b/docs/opengradient/client/llm.md @@ -200,8 +200,8 @@ a transaction. Otherwise, sends an ERC-20 approve transaction. **Arguments** -* **`opg_amount`**: Minimum number of OPG tokens required (e.g. ``0.05`` - for 0.05 OPG). Must be at least 0.05 OPG. +* **`opg_amount`**: Minimum number of OPG tokens required (e.g. ``0.1`` + for 0.1 OPG). Must be at least 0.1 OPG. **Returns** @@ -211,5 +211,5 @@ Permit2ApprovalResult: Contains ``allowance_before``, **Raises** -* **`ValueError`**: If the OPG amount is less than 0.05. +* **`ValueError`**: If the OPG amount is less than 0.1. * **`RuntimeError`**: If the approval transaction fails. \ No newline at end of file diff --git a/docs/opengradient/index.md b/docs/opengradient/index.md index 45b84a2d..32abc92f 100644 --- a/docs/opengradient/index.md +++ b/docs/opengradient/index.md @@ -6,7 +6,7 @@ opengradient # Package opengradient -**Version: 0.9.0** +**Version: 0.9.2** OpenGradient Python SDK for decentralized AI inference with end-to-end verification. 
diff --git a/integrationtest/llm/test_llm_chat.py b/integrationtest/llm/test_llm_chat.py index 8cc9f371..632436ac 100644 --- a/integrationtest/llm/test_llm_chat.py +++ b/integrationtest/llm/test_llm_chat.py @@ -23,7 +23,7 @@ ] # Amount of OPG tokens to fund the test account with -OPG_FUND_AMOUNT = 0.05 +OPG_FUND_AMOUNT = 0.1 # Amount of ETH to fund the test account with (for gas) ETH_FUND_AMOUNT = 0.0001 diff --git a/pyproject.toml b/pyproject.toml index 1f515651..40190299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "opengradient" -version = "0.9.1" +version = "0.9.3" description = "Python SDK for OpenGradient decentralized model management & inference services" authors = [{name = "OpenGradient", email = "adam@vannalabs.ai"}] readme = "README.md" @@ -27,7 +27,7 @@ dependencies = [ "langchain>=0.3.7", "openai>=1.58.1", "pydantic>=2.9.2", - "og-x402==0.0.1.dev2" + "og-x402==0.0.1.dev4" ] [project.optional-dependencies] diff --git a/src/opengradient/client/_conversions.py b/src/opengradient/client/_conversions.py index 495f663b..8c04426f 100644 --- a/src/opengradient/client/_conversions.py +++ b/src/opengradient/client/_conversions.py @@ -36,11 +36,32 @@ def convert_to_fixed_point(number: float) -> Tuple[int, int]: return value, decimals +def convert_fixed_point_to_python(value: int, decimals: int) -> np.float32: + """ + Converts a fixed-point representation back to a NumPy float32. + + This function is intentionally type-stable and always returns np.float32, + regardless of the value of `decimals`. Callers that require integer + semantics should perform an explicit cast (e.g., int(...)) based on + their own dtype metadata or application logic. + + Args: + value: The integer significand stored on-chain. + decimals: The scale factor exponent (value / 10**decimals). + + Returns: + np.float32 corresponding to `value / 10**decimals`. 
+ """ + return np.float32(Decimal(value) / (10 ** Decimal(decimals))) + + def convert_to_float32(value: int, decimals: int) -> np.float32: """ - Converts fixed point back into floating point + Deprecated: use convert_fixed_point_to_python() instead. - Returns an np.float32 type + Kept for backwards compatibility. New callers should use + convert_fixed_point_to_python which is type-stable and always + returns np.float32. """ return np.float32(Decimal(value) / (10 ** Decimal(decimals))) @@ -131,10 +152,11 @@ def convert_to_model_output(event_data: AttributeDict) -> Dict[str, np.ndarray]: name = tensor.get("name") shape = tensor.get("shape") values = [] - # Convert from fixed point back into np.float32 + # Use convert_fixed_point_to_python so integer tensors (decimals==0) + # come back as int instead of np.float32 (fixes issue #103). for v in tensor.get("values", []): if isinstance(v, (AttributeDict, dict)): - values.append(convert_to_float32(value=int(v.get("value")), decimals=int(v.get("decimals")))) + values.append(convert_fixed_point_to_python(value=int(v.get("value")), decimals=int(v.get("decimals")))) else: logging.warning(f"Unexpected number type: {type(v)}") output_dict[name] = np.array(values).reshape(shape) @@ -183,10 +205,11 @@ def convert_array_to_model_output(array_data: List) -> ModelOutput: values = tensor[1] shape = tensor[2] - # Convert from fixed point into np.float32 + # Use convert_fixed_point_to_python so integer tensors (decimals==0) + # come back as int instead of np.float32 (fixes issue #103). 
converted_values = [] for value in values: - converted_values.append(convert_to_float32(value=value[0], decimals=value[1])) + converted_values.append(convert_fixed_point_to_python(value=value[0], decimals=value[1])) number_data[name] = np.array(converted_values).reshape(shape) diff --git a/src/opengradient/client/_utils.py b/src/opengradient/client/_utils.py index 5e2938ad..d9d8436d 100644 --- a/src/opengradient/client/_utils.py +++ b/src/opengradient/client/_utils.py @@ -49,6 +49,9 @@ def run_with_retry( """ effective_retries = max_retries if max_retries is not None else DEFAULT_MAX_RETRY + if effective_retries < 1: + raise ValueError(f"max_retries must be at least 1, got {effective_retries}") + for attempt in range(effective_retries): try: return txn_function() @@ -62,3 +65,5 @@ def run_with_retry( continue raise + + raise RuntimeError(f"run_with_retry exhausted {effective_retries} attempts without returning or raising") diff --git a/src/opengradient/client/alpha.py b/src/opengradient/client/alpha.py index a4e633ec..094efaa3 100644 --- a/src/opengradient/client/alpha.py +++ b/src/opengradient/client/alpha.py @@ -119,9 +119,19 @@ def execute_transaction(): model_output = convert_to_model_output(parsed_logs[0]["args"]) if len(model_output) == 0: # check inference directly from node - parsed_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD) - inference_id = parsed_logs[0]["args"]["inferenceID"] + precompile_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD) + if not precompile_logs: + raise RuntimeError( + "ModelInferenceEvent not found in transaction logs. " + "Cannot fall back to node-side inference result." 
+ ) + inference_id = precompile_logs[0]["args"]["inferenceID"] inference_result = self._get_inference_result_from_node(inference_id, inference_mode) + if inference_result is None: + raise RuntimeError( + f"Inference node returned no result for inference ID {inference_id!r}. " + "The result may not be available yet — retry after a short delay." + ) model_output = convert_to_model_output(inference_result) return InferenceResult(tx_hash.hex(), model_output) @@ -315,7 +325,7 @@ def deploy_transaction(): signed_txn = self._wallet_account.sign_transaction(transaction) tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction) - tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60) + tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) if tx_receipt["status"] == 0: raise Exception(f"Contract deployment failed, transaction hash: {tx_hash.hex()}") @@ -419,11 +429,30 @@ def run_workflow(self, contract_address: str) -> ModelOutput: nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending") run_function = contract.functions.run() + + # Estimate gas instead of using a hardcoded 30M limit, which is wasteful + # and may exceed the block gas limit on some networks. + try: + estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address}) + gas_limit = int(estimated_gas * 3) + except ContractLogicError as exc: + # Estimation failed due to a contract revert — simulate the call to + # surface the revert reason and avoid sending a transaction that will fail. + try: + run_function.call({"from": self._wallet_account.address}) + except ContractLogicError as call_exc: + # Re-raise the detailed revert reason from the simulated call. + raise call_exc + # If the simulated call somehow doesn't raise, re-raise the original error. 
+            raise exc
+        except Exception:
+            gas_limit = 30000000  # Conservative fallback for transient/RPC estimation errors
+
         transaction = run_function.build_transaction(
             {
                 "from": self._wallet_account.address,
                 "nonce": nonce,
-                "gas": 30000000,
+                "gas": gas_limit,
                 "gasPrice": self._blockchain.eth.gas_price,
                 "chainId": self._blockchain.eth.chain_id,
             }
diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index eecfa148..af383861 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -1,10 +1,11 @@
 """LLM chat and completion via TEE-verified execution with x402 payments."""
 
-import json
+import json as _json
 import logging
 import ssl
 from dataclasses import dataclass
-from typing import AsyncGenerator, Dict, List, Optional, Union
+from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
 
+import httpx
 from eth_account import Account
 from eth_account.account import LocalAccount
@@ -19,6 +20,7 @@
 from .tee_registry import TEERegistry, build_ssl_context_from_der
 
 logger = logging.getLogger(__name__)
+T = TypeVar("T")
 
 DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai"
 DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248"
@@ -31,6 +33,12 @@
 _COMPLETION_ENDPOINT = "/v1/completions"
 _REQUEST_TIMEOUT = 60
 
+_402_HINT = (
+    "Payment required (HTTP 402): your wallet may have insufficient OPG token allowance. "
+    "Call llm.ensure_opg_approval(opg_amount=...) to approve Permit2 spending "
+    "before making requests. Minimum amount is 0.1 OPG."
+) + @dataclass class _ChatParams: @@ -94,32 +102,44 @@ def __init__( llm_server_url: Optional[str] = None, ): self._wallet_account: LocalAccount = Account.from_key(private_key) + self._rpc_url = rpc_url + self._tee_registry_address = tee_registry_address + self._llm_server_url = llm_server_url + + # x402 payment stack (created once, reused across TEE refreshes) + signer = EthAccountSigner(self._wallet_account) + self._x402_client = x402Client() + register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + self._connect_tee() + + # ── TEE resolution and connection ─────────────────────────────────────────── + + def _connect_tee(self) -> None: + """Resolve TEE from registry and create a secure HTTP client for it.""" endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - llm_server_url, - rpc_url, - tee_registry_address, + self._llm_server_url, + self._rpc_url, + self._tee_registry_address, ) - self._tee_id = tee_id self._tee_endpoint = endpoint self._tee_payment_address = tee_payment_address ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - # When connecting directly via llm_server_url, skip cert verification — - # self-hosted TEE servers commonly use self-signed certificates. 
- verify_ssl = llm_server_url is None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else verify_ssl - - # x402 client and signer - signer = EthAccountSigner(self._wallet_account) - self._x402_client = x402Client() - register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - # httpx.AsyncClient subclass - construction is sync, connections open lazily + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) - # ── TEE resolution ────────────────────────────────────────────────── + async def _refresh_tee(self) -> None: + """Re-resolve TEE from the registry and rebuild the HTTP client.""" + old_http_client = self._http_client + self._connect_tee() + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + @staticmethod def _resolve_tee( @@ -188,6 +208,29 @@ def _tee_metadata(self) -> Dict: tee_payment_address=self._tee_payment_address, ) + async def _call_with_tee_retry( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Execute *call*; on connection failure, pick a new TEE and retry once. + + Only retries when the request never reached the server (no HTTP response). + Server-side errors (4xx/5xx) are not retried. 
+ """ + try: + return await call() + except httpx.HTTPStatusError: + raise + except Exception as exc: + logger.warning( + "Connection failure during %s; refreshing TEE and retrying once: %s", + operation_name, + exc, + ) + await self._refresh_tee() + return await call() + # ── Public API ────────────────────────────────────────────────────── def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: @@ -248,7 +291,6 @@ async def completion( RuntimeError: If the inference fails. """ model_id = model.split("/")[1] - headers = self._headers(x402_settlement_mode) payload: Dict = { "model": model_id, "prompt": prompt, @@ -258,16 +300,16 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() raw_body = await response.aread() - result = json.loads(raw_body.decode()) + result = _json.loads(raw_body.decode()) return TextGenerationOutput( transaction_hash="external", completion_output=result.get("completion"), @@ -275,8 +317,15 @@ async def completion( tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_tee_retry("completion", _request) except RuntimeError: raise + except httpx.HTTPStatusError as e: + if e.response.status_code == 402: + raise RuntimeError(_402_HINT) from e + raise RuntimeError(f"TEE LLM completion failed: {e}") from e except Exception as e: raise RuntimeError(f"TEE LLM completion failed: {e}") from e @@ -342,19 +391,18 @@ async def chat( async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput: """Non-streaming chat request.""" - headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() 
-> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(params.x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() raw_body = await response.aread() - result = json.loads(raw_body.decode()) + result = _json.loads(raw_body.decode()) choices = result.get("choices") if not choices: @@ -375,8 +423,17 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_tee_retry("chat", _request) except RuntimeError: raise + except httpx.HTTPStatusError as e: + # Provide an actionable error message for the very common 402 case + # (issue #188 — users see a cryptic RuntimeError instead of guidance). + if e.response.status_code == 402: + raise RuntimeError(_402_HINT) from e + raise RuntimeError(f"TEE LLM chat failed: {e}") from e except Exception as e: raise RuntimeError(f"TEE LLM chat failed: {e}") from e @@ -410,6 +467,33 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) + chunks_yielded = False + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response): + chunks_yielded = True + yield chunk + return + except httpx.HTTPStatusError: + raise + except Exception as exc: + if chunks_yielded: + raise + logger.warning( + "Connection failure during stream setup; refreshing TEE and retrying once: %s", + exc, + ) + + # Only reached if the first attempt failed before yielding any chunks. + # Re-resolve the TEE endpoint from the registry and retry once. 
+ await self._refresh_tee() + headers = self._headers(params.x402_settlement_mode) async with self._http_client.stream( "POST", self._tee_endpoint + _CHAT_ENDPOINT, @@ -423,6 +507,8 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" status_code = getattr(response, "status_code", None) + if status_code is not None and status_code == 402: + raise RuntimeError(_402_HINT) if status_code is not None and status_code >= 400: body = await response.aread() raise RuntimeError(f"TEE LLM streaming request failed with status {status_code}: {body.decode('utf-8', errors='replace')}") @@ -452,8 +538,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non return try: - data = json.loads(data_str) - except json.JSONDecodeError: + data = _json.loads(data_str) + except _json.JSONDecodeError: continue chunk = StreamChunk.from_sse_data(data) diff --git a/src/opengradient/client/opg_token.py b/src/opengradient/client/opg_token.py index d86d9de9..af80783a 100644 --- a/src/opengradient/client/opg_token.py +++ b/src/opengradient/client/opg_token.py @@ -82,8 +82,8 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm allowance_before = token.functions.allowance(owner, spender).call() - # Only approve if the allowance is less than 10% of the requested amount - if allowance_before >= amount_base * 0.1: + # Only approve if the allowance is less than 50% of the requested amount + if allowance_before >= amount_base * 0.5: return Permit2ApprovalResult( allowance_before=allowance_before, allowance_after=allowance_before, @@ -124,7 +124,6 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm ) time.sleep(ALLOWANCE_POLL_INTERVAL) - return Permit2ApprovalResult( allowance_before=allowance_before, allowance_after=allowance_after, diff 
--git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index 9ad3cfd7..b791a383 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -1,6 +1,7 @@ """TEE Registry client for fetching verified TEE endpoints and TLS certificates.""" import logging +import random import ssl from dataclasses import dataclass from typing import List, NamedTuple, Optional @@ -109,17 +110,32 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: def get_llm_tee(self) -> Optional[TEEEndpoint]: """ - Return the first active LLM proxy TEE from the registry. + Return a randomly selected active LLM proxy TEE from the registry. + + Randomizing the selection distributes load across all healthy TEEs and + avoids repeatedly routing to the same TEE when it starts failing + (addresses issue #200 — improve TEE selection/retry logic). Returns: - TEEEndpoint for an active LLM proxy TEE, or None if none are available. + TEEEndpoint for a randomly chosen active LLM proxy TEE, or None if + none are available. """ tees = self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) if not tees: logger.warning("No active LLM proxy TEEs found in registry") return None - return tees[0] + # Randomly select from all active TEEs to distribute load and improve + # resilience — if one TEE is failing, successive LLM() constructions + # will eventually land on a healthy one. 
+ selected = random.choice(tees) + logger.debug( + "Selected TEE %s (endpoint=%s) from %d active LLM proxy TEE(s)", + selected.tee_id, + selected.endpoint, + len(tees), + ) + return selected def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: diff --git a/stresstest/infer.py b/stresstest/infer.py deleted file mode 100644 index 6a72cefa..00000000 --- a/stresstest/infer.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -import statistics - -from utils import stress_test_wrapper - -import opengradient as og - -# Number of requests to run serially -NUM_REQUESTS = 10_000 -MODEL = "QmbUqS93oc4JTLMHwpVxsE39mhNxy6hpf6Py3r9oANr8aZ" - - -def main(private_key: str): - alpha = og.Alpha(private_key=private_key) - - def run_inference(input_data: dict): - alpha.infer(MODEL, og.InferenceMode.VANILLA, input_data) - - latencies, failures = stress_test_wrapper(run_inference, num_requests=NUM_REQUESTS) - - # Calculate and print statistics - total_requests = NUM_REQUESTS - success_rate = (len(latencies) / total_requests) * 100 if total_requests > 0 else 0 - - if latencies: - avg_latency = statistics.mean(latencies) - median_latency = statistics.median(latencies) - min_latency = min(latencies) - max_latency = max(latencies) - p95_latency = statistics.quantiles(latencies, n=20)[18] # 95th percentile - else: - avg_latency = median_latency = min_latency = max_latency = p95_latency = 0 - - print("\nOpenGradient Inference Stress Test Results:") - print(f"Using model '{MODEL}'") - print("=" * 20 + "\n") - print(f"Total Requests: {total_requests}") - print(f"Successful Requests: {len(latencies)}") - print(f"Failed Requests: {failures}") - print(f"Success Rate: {success_rate:.2f}%\n") - print(f"Average Latency: {avg_latency:.4f} seconds") - print(f"Median Latency: {median_latency:.4f} seconds") - print(f"Min Latency: {min_latency:.4f} seconds") - print(f"Max Latency: {max_latency:.4f} seconds") - print(f"95th Percentile Latency: {p95_latency:.4f} seconds") - - if failures > 0: - 
print("\n🛑 WARNING: TEST FAILED") - else: - print("\n✅ NO FAILURES") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run inference stress test") - parser.add_argument("private_key", help="Private key for inference") - args = parser.parse_args() - - main(args.private_key) diff --git a/stresstest/llm.py b/stresstest/llm.py index 7c2811db..f9652c19 100644 --- a/stresstest/llm.py +++ b/stresstest/llm.py @@ -1,4 +1,5 @@ import argparse +import asyncio import statistics from utils import stress_test_wrapper @@ -7,16 +8,18 @@ # Number of requests to run serially NUM_REQUESTS = 100 -MODEL = "anthropic/claude-haiku-4-5" +MODEL = og.TEE_LLM.GEMINI_2_5_FLASH -def main(private_key: str): +async def main(private_key: str): llm = og.LLM(private_key=private_key) + llm.ensure_opg_approval(opg_amount=0.1) - def run_prompt(prompt: str): - llm.completion(MODEL, prompt, max_tokens=50) + async def run_prompt(prompt: str): + messages = [{"role": "user", "content": prompt}] + await llm.chat(MODEL, messages=messages, max_tokens=50, x402_settlement_mode=og.x402SettlementMode.INDIVIDUAL_FULL) - latencies, failures = stress_test_wrapper(run_prompt, num_requests=NUM_REQUESTS, is_llm=True) + latencies, failures = await stress_test_wrapper(run_prompt, num_requests=NUM_REQUESTS) # Calculate and print statistics total_requests = NUM_REQUESTS @@ -55,4 +58,4 @@ def run_prompt(prompt: str): parser.add_argument("private_key", help="Private key for inference") args = parser.parse_args() - main(args.private_key) + asyncio.run(main(args.private_key)) diff --git a/stresstest/utils.py b/stresstest/utils.py index b7e22512..8ac7b064 100644 --- a/stresstest/utils.py +++ b/stresstest/utils.py @@ -1,17 +1,7 @@ import random import time import uuid -from typing import Callable, List, Tuple - - -def generate_unique_input(request_id: int) -> dict: - """Generate a unique input for testing.""" - num_input1 = [random.uniform(0, 10) for _ in range(3)] - num_input2 = random.randint(1, 
20) - str_input1 = [random.choice(["hello", "world", "ONNX", "test"]) for _ in range(2)] - str_input2 = f"Request {request_id}: {str(uuid.uuid4())[:8]}" - - return {"num_input1": num_input1, "num_input2": num_input2, "str_input1": str_input1, "str_input2": str_input2} +from typing import Callable, Coroutine, List, Tuple def generate_unique_prompt(request_id: int) -> str: @@ -26,14 +16,13 @@ def generate_unique_prompt(request_id: int) -> str: return f"Request {request_id}: Tell me a {adjective} fact about {topic}. Keep it short. Unique ID: {unique_id}" -def stress_test_wrapper(infer_function: Callable, num_requests: int, is_llm: bool = False) -> Tuple[List[float], int]: +async def stress_test_wrapper(infer_function: Callable[..., Coroutine], num_requests: int) -> Tuple[List[float], int]: """ - Wrapper function to stress test the inference. + Async wrapper function to stress test the inference. Args: - infer_function (Callable): The inference function to test. + infer_function (Callable): An async inference function to test. num_requests (int): Number of requests to send. - is_llm (bool): Whether the test is for an LLM model. Default is False. Returns: Tuple[List[float], int]: List of latencies for each request and the number of failures. 
@@ -42,15 +31,12 @@ def stress_test_wrapper(infer_function: Callable, num_requests: int, is_llm: boo failures = 0 for i in range(num_requests): - if is_llm: - input_data = generate_unique_prompt(i) - else: - input_data = generate_unique_input(i) + input_data = generate_unique_prompt(i) start_time = time.time() try: - _ = infer_function(input_data) + _ = await infer_function(input_data) end_time = time.time() latency = end_time - start_time latencies.append(latency) diff --git a/tests/llm_test.py b/tests/llm_test.py index bb845a75..90c46621 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -5,9 +5,10 @@ """ import json +import ssl from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -31,6 +32,8 @@ def __init__(self, *_args, **_kwargs): self._response_body: bytes = b"{}" self._post_calls: List[dict] = [] self._stream_response = None + self._error_on_next: BaseException | None = None + self._stream_error_on_next: BaseException | None = None def set_response(self, status_code: int, body: dict) -> None: self._response_status = status_code @@ -43,16 +46,34 @@ def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None: def post_calls(self) -> List[dict]: return self._post_calls + def fail_next_post(self, exc: BaseException) -> None: + """Make the next post() call raise *exc*, then revert to normal.""" + self._error_on_next = exc + + def fail_next_stream(self, exc: BaseException) -> None: + """Make the next stream() call raise *exc*, then revert to normal.""" + self._stream_error_on_next = exc + async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse": self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._error_on_next is not None: + exc, self._error_on_next = self._error_on_next, None + raise exc resp = 
_FakeResponse(self._response_status, self._response_body) if self._response_status >= 400: - resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock())) + mock_response = MagicMock() + mock_response.status_code = self._response_status + resp.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=mock_response) + ) return resp @asynccontextmanager async def stream(self, method: str, url: str, *, json=None, headers=None, timeout=None): self._post_calls.append({"method": method, "url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._stream_error_on_next is not None: + exc, self._stream_error_on_next = self._stream_error_on_next, None + raise exc yield self._stream_response async def aclose(self): @@ -209,6 +230,14 @@ async def test_http_error_raises_opengradient_error(self, fake_http): with pytest.raises(RuntimeError, match="TEE LLM completion failed"): await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + async def test_http_402_raises_hint(self, fake_http): + """HTTP 402 from the TEE must surface the actionable _402_HINT message.""" + fake_http.set_response(402, {}) + llm = _make_llm() + + with pytest.raises(RuntimeError, match=r"Payment required \(HTTP 402\)"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + # ── Chat (non-streaming) tests ─────────────────────────────────────── @@ -361,6 +390,14 @@ async def test_http_error_raises_opengradient_error(self, fake_http): with pytest.raises(RuntimeError, match="TEE LLM chat failed"): await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + async def test_http_402_raises_hint(self, fake_http): + """HTTP 402 from the TEE must surface the actionable _402_HINT message.""" + fake_http.set_response(402, {}) + llm = _make_llm() + + with pytest.raises(RuntimeError, match=r"Payment required \(HTTP 402\)"): + await llm.chat(model=TEE_LLM.GPT_5, 
messages=[{"role": "user", "content": "Hi"}]) + # ── Streaming tests ────────────────────────────────────────────────── @@ -440,6 +477,20 @@ async def test_stream_error_raises(self, fake_http): with pytest.raises(RuntimeError, match="streaming request failed"): _ = [chunk async for chunk in gen] + async def test_stream_402_raises_hint(self, fake_http): + """HTTP 402 during streaming must surface the actionable _402_HINT message.""" + fake_http.set_stream_response(402, [b"Payment Required"]) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(RuntimeError, match=r"Payment required \(HTTP 402\)"): + _ = [chunk async for chunk in gen] + async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): """When tools + stream=True, LLM falls back to non-streaming and yields one chunk.""" tools = [{"type": "function", "function": {"name": "f"}}] @@ -535,3 +586,237 @@ def test_registry_success(self): assert cert == b"cert-bytes" assert tee_id == "tee-42" assert pay_addr == "0xPay" + + +# ── TEE retry tests (non-streaming) ────────────────────────────────── + + +@pytest.mark.asyncio +class TestTeeRetryCompletion: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): + """First call hits connection error → refresh TEE → second call succeeds.""" + fake_http.set_response(200, {"completion": "retried ok", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(ConnectionError("connection refused")) + llm = _make_llm() + + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + assert result.completion_output == "retried ok" + assert len(fake_http.post_calls) == 2 + + async def test_http_status_error_not_retried(self, fake_http): + """A server-side error (HTTP 500) should not trigger a TEE retry.""" + fake_http.set_response(500, {"error": "boom"}) + llm = _make_llm() + + with pytest.raises(RuntimeError, 
match="TEE LLM completion failed"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert len(fake_http.post_calls) == 1 + + async def test_second_failure_propagates(self, fake_http): + """If the retry also fails, the error should propagate.""" + call_count = 0 + + async def always_fail(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise ConnectionError("still broken") + + fake_http.post = always_fail + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM completion failed"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert call_count == 2 + + +@pytest.mark.asyncio +class TestTeeRetryChat: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "retry ok"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(OSError("network unreachable")) + llm = _make_llm() + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "retry ok" + assert len(fake_http.post_calls) == 2 + + async def test_http_status_error_not_retried(self, fake_http): + fake_http.set_response(500, {"error": "internal"}) + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM chat failed"): + await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + assert len(fake_http.post_calls) == 1 + + +# ── TEE retry tests (streaming) ────────────────────────────────────── + + +@pytest.mark.asyncio +class TestTeeRetryStreaming: + async def test_retries_stream_on_connection_error_before_chunks(self, fake_http): + """Connection failure during stream setup (no chunks yielded) → retry succeeds.""" + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + 
fake_http.fail_next_stream(ConnectionError("reset by peer")) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + assert len(fake_http.post_calls) == 2 + + async def test_no_retry_after_chunks_yielded(self, fake_http): + """Failure AFTER chunks were yielded must raise, not retry.""" + + class _FailMidStream: + def __init__(self): + self.status_code = 200 + + async def aiter_raw(self): + yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n\n' + raise ConnectionError("mid-stream disconnect") + + async def aread(self) -> bytes: + return b"" + + fake_http._stream_response = _FailMidStream() + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(ConnectionError): + _ = [c async for c in gen] + + assert len(fake_http.post_calls) == 1 + + +# ── _refresh_tee tests ───────────────────────────────────── + + +@pytest.mark.asyncio +class TestRefreshTeeAndReset: + async def test_replaces_http_client(self): + """After refresh, the http client should be a new instance.""" + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + with ( + patch(_PATCHES["x402_httpx"], side_effect=make_client), + patch(_PATCHES["x402_client"]), + patch(_PATCHES["signer"]), + patch(_PATCHES["register_exact"]), + patch(_PATCHES["register_upto"]), + ): + llm = _make_llm() + old_client = llm._http_client + + await llm._refresh_tee() + + assert llm._http_client is not old_client + assert len(clients_created) == 2 # init + refresh + + async def test_closes_old_client(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock() + + await 
llm._refresh_tee() + + old_client.aclose.assert_awaited_once() + + async def test_close_failure_is_swallowed(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + # Should not raise + await llm._refresh_tee() + + +# ── TEE cert rotation (crash + re-register) tests ──────────────────── + + +@pytest.mark.asyncio +class TestTeeCertRotation: + """Simulate a TEE crashing and a new one registering at the same IP + with a different ephemeral TLS certificate. The old cert is now + invalid, so the first request fails with SSLCertVerificationError. + The retry should re-resolve from the registry (getting the new cert) + and succeed.""" + + async def test_ssl_verification_failure_triggers_tee_refresh_completion(self, fake_http): + fake_http.set_response(200, {"completion": "ok after refresh", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + # _connect_tee was called once during the retry (refresh) + spy.assert_called_once() + assert result.completion_output == "ok after refresh" + assert len(fake_http.post_calls) == 2 + + async def test_ssl_verification_failure_triggers_tee_refresh_chat(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "ok after refresh"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + spy.assert_called_once() + assert result.chat_output["content"] == "ok after refresh" + assert len(fake_http.post_calls) == 2 
+ + async def test_ssl_verification_failure_triggers_tee_refresh_streaming(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + fake_http.fail_next_stream(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + spy.assert_called_once() + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + assert len(fake_http.post_calls) == 2 diff --git a/tests/opg_token_test.py b/tests/opg_token_test.py index f2b1255f..0e44a072 100644 --- a/tests/opg_token_test.py +++ b/tests/opg_token_test.py @@ -77,6 +77,7 @@ def test_zero_amount_with_zero_allowance_skips(self, mock_wallet, mock_web3): assert result.tx_hash is None + class TestEnsureOpgApprovalSendsTx: """Cases where allowance is insufficient and a transaction is sent.""" @@ -181,6 +182,7 @@ def test_waits_for_allowance_update_after_receipt(self, mock_wallet, mock_web3, assert result.allowance_after == amount_base assert result.tx_hash == "0xconfirmed" + class TestEnsureOpgApprovalErrors: """Error handling paths.""" diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py index 189ad26c..a577d43b 100644 --- a/tests/tee_registry_test.py +++ b/tests/tee_registry_test.py @@ -1,6 +1,7 @@ import os import ssl import sys +import random from unittest.mock import MagicMock, patch import pytest @@ -148,19 +149,36 @@ def test_validator_type_label(self, mock_contract): class TestGetLlmTee: - def test_returns_first_active_tee(self, mock_contract): + def test_returns_active_tee_from_pool(self, mock_contract): + """get_llm_tee uses random.choice; patch it to make the test deterministic.""" registry, contract = 
mock_contract - contract.functions.getActiveTEEs.return_value.call.return_value = [ + tee_infos = [ _make_tee_info(endpoint="https://tee-1.example.com"), _make_tee_info(endpoint="https://tee-2.example.com"), ] + contract.functions.getActiveTEEs.return_value.call.return_value = tee_infos - result = registry.get_llm_tee() + with patch("src.opengradient.client.tee_registry.random.choice", side_effect=lambda seq: seq[0]): + result = registry.get_llm_tee() assert result is not None assert result.endpoint == "https://tee-1.example.com" + def test_returns_any_active_tee(self, mock_contract): + """Without patching random.choice, result must still be one of the active TEEs.""" + registry, contract = mock_contract + + endpoints = {"https://tee-1.example.com", "https://tee-2.example.com"} + contract.functions.getActiveTEEs.return_value.call.return_value = [ + _make_tee_info(endpoint=ep) for ep in sorted(endpoints) + ] + + result = registry.get_llm_tee() + + assert result is not None + assert result.endpoint in endpoints + def test_returns_none_when_no_tees(self, mock_contract): registry, contract = mock_contract diff --git a/uv.lock b/uv.lock index 94346b2c..6d0327a6 100644 --- a/uv.lock +++ b/uv.lock @@ -1835,15 +1835,15 @@ wheels = [ [[package]] name = "og-x402" -version = "0.0.1.dev2" +version = "0.0.1.dev4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/75/40c43cd44aa394e68acc98f8d5b8376f3a5e3b9eddf55b1c0c34616c340b/og_x402-0.0.1.dev2.tar.gz", hash = "sha256:bf5d4484ece5a371358a336fcc79fe5678be611044c55ade45c4be9d19d7691b", size = 899662, upload-time = "2026-03-17T06:35:36.587Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/5b/46a55d93d9da5535ff2bb28d48d5766c9108d9e16546cb9c7a65cde0fb11/og_x402-0.0.1.dev4.tar.gz", hash = "sha256:2d8a71b2f4222284e65d45e2d122faafe3bdb33c4fae77903f9665d29e517a97", size = 900109, 
upload-time = "2026-03-23T15:10:37.144Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/79/8c7543c2e647508e04ad0983e9a3a7b861f388ec591ccdc42c69a3128d42/og_x402-0.0.1.dev2-py3-none-any.whl", hash = "sha256:65e7d3bbb3c7f51e51dad974f6c405a230693816f72d874cf0d6d705a8eec271", size = 952331, upload-time = "2026-03-17T06:35:34.695Z" }, + { url = "https://files.pythonhosted.org/packages/45/da/5e0be4b8415a6c557a94991367c6124998df3ba014bceb76b595ef48c8c7/og_x402-0.0.1.dev4-py3-none-any.whl", hash = "sha256:c329ceb4fe7cc4195fa5bf9c769f5c571b61c8333b33fd0fe204a2ab377d8366", size = 952662, upload-time = "2026-03-23T15:10:35.21Z" }, ] [[package]] @@ -1867,7 +1867,7 @@ wheels = [ [[package]] name = "opengradient" -version = "0.9.0" +version = "0.9.3" source = { editable = "." } dependencies = [ { name = "click" }, @@ -1907,7 +1907,7 @@ requires-dist = [ { name = "langgraph", marker = "extra == 'dev'" }, { name = "mypy", marker = "extra == 'dev'" }, { name = "numpy", specifier = ">=1.26.4" }, - { name = "og-x402", specifier = "==0.0.1.dev2" }, + { name = "og-x402", specifier = "==0.0.1.dev4" }, { name = "openai", specifier = ">=1.58.1" }, { name = "pdoc3", marker = "extra == 'dev'", specifier = "==0.10.0" }, { name = "pydantic", specifier = ">=2.9.2" },