diff --git a/src/opengradient/client/alpha.py b/src/opengradient/client/alpha.py index d2957c2..afb4e59 100644 --- a/src/opengradient/client/alpha.py +++ b/src/opengradient/client/alpha.py @@ -1,465 +1,466 @@ -""" -Alpha Testnet features for OpenGradient SDK. - -This module contains features that are only available on the Alpha Testnet, -including on-chain ONNX model inference, workflow management, and ML model execution. -""" - -import base64 -import json -import urllib.parse -from typing import Dict, List, Optional, Union - -import numpy as np -import requests -from eth_account.account import LocalAccount -from web3 import Web3 -from web3.exceptions import ContractLogicError -from web3.logs import DISCARD - -from ..types import HistoricalInputQuery, InferenceMode, InferenceResult, ModelOutput, SchedulerParams -from ._conversions import convert_array_to_model_output, convert_to_model_input, convert_to_model_output # type: ignore[attr-defined] -from ._utils import get_abi, get_bin, run_with_retry - -DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" -DEFAULT_API_URL = "https://sdk-devnet.opengradient.ai" -DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE" -DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6" - -# How much time we wait for txn to be included in chain -INFERENCE_TX_TIMEOUT = 120 -REGULAR_TX_TIMEOUT = 30 -HTTP_REQUEST_TIMEOUT = 30 # seconds - -PRECOMPILE_CONTRACT_ADDRESS = "0x00000000000000000000000000000000000000F4" - - -class Alpha: - """ - Alpha Testnet features namespace. - - This class provides access to features that are only available on the Alpha Testnet, - including on-chain ONNX model inference, workflow deployment, and execution. - - Usage: - alpha = og.Alpha(private_key="0x...") - result = alpha.infer(model_cid, InferenceMode.VANILLA, model_input) - result = alpha.new_workflow(model_cid, input_query, input_tensor_name) - """ - - def __init__( - self, - private_key: str, - rpc_url: str = DEFAULT_RPC_URL, - inference_contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, - api_url: str = DEFAULT_API_URL, - ): - self._blockchain = Web3(Web3.HTTPProvider(rpc_url)) - self._wallet_account: LocalAccount = self._blockchain.eth.account.from_key(private_key) - self._inference_hub_contract_address = inference_contract_address - self._api_url = api_url - self._inference_abi: Optional[dict] = None - self._precompile_abi: Optional[dict] = None - - @property - def inference_abi(self) -> dict: - if self._inference_abi is None: - self._inference_abi = get_abi("inference.abi") - return self._inference_abi - - @property - def precompile_abi(self) -> dict: - if self._precompile_abi is None: - self._precompile_abi = get_abi("InferencePrecompile.abi") - return self._precompile_abi - - def infer( - self, - model_cid: str, - inference_mode: InferenceMode, - model_input: Dict[str, Union[str, int, float, List, np.ndarray]], - max_retries: Optional[int] = None, - ) -> InferenceResult: - """ - Perform inference on a model. - - Args: - model_cid (str): The unique content identifier for the model from IPFS. - inference_mode (InferenceMode): The inference mode. - model_input (Dict[str, Union[str, int, float, List, np.ndarray]]): The input data for the model. - max_retries (int, optional): Maximum number of retry attempts. Defaults to 5. - - Returns: - InferenceResult (InferenceResult): A dataclass object containing the transaction hash and model output. - transaction_hash (str): Blockchain hash for the transaction - model_output (Dict[str, np.ndarray]): Output of the ONNX model - - Raises: - RuntimeError: If the inference fails. - """ - - def execute_transaction(): - contract = self._blockchain.eth.contract( - address=Web3.to_checksum_address(self._inference_hub_contract_address), abi=self.inference_abi - ) - precompile_contract = self._blockchain.eth.contract( - address=Web3.to_checksum_address(PRECOMPILE_CONTRACT_ADDRESS), abi=self.precompile_abi - ) - - inference_mode_uint8 = inference_mode.value - converted_model_input = convert_to_model_input(model_input) - - run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input) - - tx_hash, tx_receipt = self._send_tx_with_revert_handling(run_function) - parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD) - if len(parsed_logs) < 1: - raise RuntimeError("InferenceResult event not found in transaction logs") - - # TODO: This should return a ModelOutput class object - model_output = convert_to_model_output(parsed_logs[0]["args"]) - if len(model_output) == 0: - # check inference directly from node - parsed_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD) - inference_id = parsed_logs[0]["args"]["inferenceID"] - inference_result = self._get_inference_result_from_node(inference_id, inference_mode) - model_output = convert_to_model_output(inference_result) - - return InferenceResult(tx_hash.hex(), model_output) - - result: InferenceResult = run_with_retry(execute_transaction, max_retries) - return result - - def _send_tx_with_revert_handling(self, run_function): - """ - Execute a blockchain transaction with revert error. - - Args: - run_function: Function that executes the transaction - - Returns: - tx_hash: Transaction hash - tx_receipt: Transaction receipt - - Raises: - Exception: If transaction fails or gas estimation fails - """ - nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending") - try: - estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address}) - except ContractLogicError as e: - try: - run_function.call({"from": self._wallet_account.address}) - - except ContractLogicError as call_err: - raise ContractLogicError(f"simulation failed with revert reason: {call_err.args[0]}") - - raise ContractLogicError(f"simulation failed with no revert reason. Reason: {e}") - - gas_limit = int(estimated_gas * 3) - - transaction = run_function.build_transaction( - { - "from": self._wallet_account.address, - "nonce": nonce, - "gas": gas_limit, - "gasPrice": self._blockchain.eth.gas_price, - } - ) - - signed_tx = self._wallet_account.sign_transaction(transaction) # type: ignore[arg-type] - tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction) - tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) - - if tx_receipt["status"] == 0: - try: - run_function.call({"from": self._wallet_account.address}) - - except ContractLogicError as call_err: - raise ContractLogicError(f"Transaction failed with revert reason: {call_err.args[0]}") - - raise ContractLogicError(f"Transaction failed with no revert reason. Receipt: {tx_receipt}") - - return tx_hash, tx_receipt - - def _get_inference_result_from_node(self, inference_id: str, inference_mode: InferenceMode) -> Optional[Dict]: - """ - Get the inference result from node. - - Args: - inference_id (str): Inference id for a inference request - - Returns: - Dict: The inference result as returned by the node - - Raises: - RuntimeError: If the request fails or returns an error - """ - try: - encoded_id = urllib.parse.quote(inference_id, safe="") - url = f"{self._api_url}/artela-network/artela-rollkit/inference/tx/{encoded_id}" - - response = requests.get(url, timeout=HTTP_REQUEST_TIMEOUT) - if response.status_code == 200: - resp = response.json() - inference_result = resp.get("inference_results", {}) - if inference_result: - decoded_bytes = base64.b64decode(inference_result[0]) - decoded_string = decoded_bytes.decode("utf-8") - output = json.loads(decoded_string).get("InferenceResult", {}) - if output is None: - raise RuntimeError("Missing InferenceResult in inference output") - - match inference_mode: - case InferenceMode.VANILLA: - if "VanillaResult" not in output: - raise RuntimeError("Missing VanillaResult in inference output") - if "model_output" not in output["VanillaResult"]: - raise RuntimeError("Missing model_output in VanillaResult") - return {"output": output["VanillaResult"]["model_output"]} - - case InferenceMode.TEE: - if "TeeNodeResult" not in output: - raise RuntimeError("Missing TeeNodeResult in inference output") - if "Response" not in output["TeeNodeResult"]: - raise RuntimeError("Missing Response in TeeNodeResult") - if "VanillaResponse" in output["TeeNodeResult"]["Response"]: - if "model_output" not in output["TeeNodeResult"]["Response"]["VanillaResponse"]: - raise RuntimeError("Missing model_output in VanillaResponse") - return {"output": output["TeeNodeResult"]["Response"]["VanillaResponse"]["model_output"]} - - else: - raise RuntimeError("Missing VanillaResponse in TeeNodeResult Response") - - case InferenceMode.ZKML: - if "ZkmlResult" not in output: - raise RuntimeError("Missing ZkmlResult in inference output") - if "model_output" not in output["ZkmlResult"]: - raise RuntimeError("Missing model_output in ZkmlResult") - return {"output": output["ZkmlResult"]["model_output"]} - - case _: - raise ValueError(f"Invalid inference mode: {inference_mode}") - else: - return None - - else: - raise RuntimeError(f"Failed to get inference result: HTTP {response.status_code}") - - except requests.RequestException as e: - raise RuntimeError(f"Failed to get inference result: {str(e)}") - except (RuntimeError, ValueError): - raise - except Exception as e: - raise RuntimeError(f"Failed to get inference result: {str(e)}") - - def new_workflow( - self, - model_cid: str, - input_query: HistoricalInputQuery, - input_tensor_name: str, - scheduler_params: Optional[SchedulerParams] = None, - ) -> str: - """ - Deploy a new workflow contract with the specified parameters. - - This function deploys a new workflow contract on OpenGradient that connects - an AI model with its required input data. When executed, the workflow will fetch - the specified model, evaluate the input query to get data, and perform inference. - - The workflow can be set to execute manually or automatically via a scheduler. - - Args: - model_cid (str): CID of the model to be executed from the Model Hub - input_query (HistoricalInputQuery): Input definition for the model inference, - will be evaluated at runtime for each inference - input_tensor_name (str): Name of the input tensor expected by the model - scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution: - - frequency: Execution frequency in seconds - - duration_hours: How long the schedule should live for - - Returns: - str: Deployed contract address. If scheduler_params was provided, the workflow - will be automatically executed according to the specified schedule. - - Raises: - Exception: If transaction fails or gas estimation fails - """ - # Get contract ABI and bytecode - abi = get_abi("PriceHistoryInference.abi") - bytecode = get_bin("PriceHistoryInference.bin") - - def deploy_transaction(): - contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode) - query_tuple = input_query.to_abi_format() - constructor_args = [model_cid, input_tensor_name, query_tuple] - - try: - # Estimate gas needed - estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address}) - gas_limit = int(estimated_gas * 1.2) - except Exception as e: - print(f"Gas estimation failed: {str(e)}") - gas_limit = 5000000 # Conservative fallback - print(f"Using fallback gas limit: {gas_limit}") - - transaction = contract.constructor(*constructor_args).build_transaction( - { - "from": self._wallet_account.address, - "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"), - "gas": gas_limit, - "gasPrice": self._blockchain.eth.gas_price, - "chainId": self._blockchain.eth.chain_id, - } - ) - - signed_txn = self._wallet_account.sign_transaction(transaction) - tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction) - - tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) - - if tx_receipt["status"] == 0: - raise Exception(f"Contract deployment failed, transaction hash: {tx_hash.hex()}") - - return tx_receipt.contractAddress - - contract_address: str = run_with_retry(deploy_transaction) - - if scheduler_params: - self._register_with_scheduler(contract_address, scheduler_params) - - return contract_address - - def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None: - """ - Register the deployed workflow contract with the scheduler for automated execution. - - Args: - contract_address (str): Address of the deployed workflow contract - scheduler_params (SchedulerParams): Scheduler configuration containing: - - frequency: Execution frequency in seconds - - duration_hours: How long to run in hours - - end_time: Unix timestamp when scheduling should end - - Raises: - Exception: If registration with scheduler fails. The workflow contract will - still be deployed and can be executed manually. - """ - scheduler_abi = get_abi("WorkflowScheduler.abi") - - # Scheduler contract address - scheduler_address = DEFAULT_SCHEDULER_ADDRESS - scheduler_contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(scheduler_address), abi=scheduler_abi) - - try: - # Register the workflow with the scheduler - scheduler_tx = scheduler_contract.functions.registerTask( - contract_address, scheduler_params.end_time, scheduler_params.frequency - ).build_transaction( - { - "from": self._wallet_account.address, - "gas": 300000, - "gasPrice": self._blockchain.eth.gas_price, - "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"), - "chainId": self._blockchain.eth.chain_id, - } - ) - - signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx) # type: ignore[arg-type] - scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction) - self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT) - except Exception as e: - print(f"Error registering contract with scheduler: {str(e)}") - print(" The workflow contract is still deployed and can be executed manually.") - - def read_workflow_result(self, contract_address: str) -> ModelOutput: - """ - Reads the latest inference result from a deployed workflow contract. - - Args: - contract_address (str): Address of the deployed workflow contract - - Returns: - ModelOutput: The inference result from the contract - - Raises: - ContractLogicError: If the transaction fails - Web3Error: If there are issues with the web3 connection or contract interaction - """ - # Get the contract interface - contract = self._blockchain.eth.contract( - address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") - ) - - # Get the result - result = contract.functions.getInferenceResult().call() - - output: ModelOutput = convert_array_to_model_output(result) - return output - - def run_workflow(self, contract_address: str) -> ModelOutput: - """ - Triggers the run() function on a deployed workflow contract and returns the result. - - Args: - contract_address (str): Address of the deployed workflow contract - - Returns: - ModelOutput: The inference result from the contract - - Raises: - ContractLogicError: If the transaction fails - Web3Error: If there are issues with the web3 connection or contract interaction - """ - # Get the contract interface - contract = self._blockchain.eth.contract( - address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") - ) - - # Call run() function - nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending") - - run_function = contract.functions.run() - transaction = run_function.build_transaction( - { - "from": self._wallet_account.address, - "nonce": nonce, - "gas": 30000000, - "gasPrice": self._blockchain.eth.gas_price, - "chainId": self._blockchain.eth.chain_id, - } - ) - - signed_txn = self._wallet_account.sign_transaction(transaction) # type: ignore[arg-type] - tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction) - tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) - - if tx_receipt["status"] == 0: - raise ContractLogicError(f"Run transaction failed. Receipt: {tx_receipt}") - - # Get the inference result from the contract - result = contract.functions.getInferenceResult().call() - - run_output: ModelOutput = convert_array_to_model_output(result) - return run_output - - def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]: - """ - Gets historical inference results from a workflow contract. - - Retrieves the specified number of most recent inference results from the contract's - storage, with the most recent result first. - - Args: - contract_address (str): Address of the deployed workflow contract - num_results (int): Number of historical results to retrieve - - Returns: - List[ModelOutput]: List of historical inference results - """ - contract = self._blockchain.eth.contract( - address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") - ) - - results = contract.functions.getLastInferenceResults(num_results).call() - return [convert_array_to_model_output(result) for result in results] +""" +Alpha Testnet features for OpenGradient SDK. + +This module contains features that are only available on the Alpha Testnet, +including on-chain ONNX model inference, workflow management, and ML model execution. +""" + +import base64 +import json +import logging +import urllib.parse +from typing import Dict, List, Optional, Union + +import numpy as np +import requests +from eth_account.account import LocalAccount +from web3 import Web3 +from web3.exceptions import ContractLogicError +from web3.logs import DISCARD + +from ..types import HistoricalInputQuery, InferenceMode, InferenceResult, ModelOutput, SchedulerParams +from ._conversions import convert_array_to_model_output, convert_to_model_input, convert_to_model_output # type: ignore[attr-defined] +from ._utils import get_abi, get_bin, run_with_retry + +logger = logging.getLogger(__name__) + +DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" +DEFAULT_API_URL = "https://sdk-devnet.opengradient.ai" +DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE" +DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6" + +# How much time we wait for txn to be included in chain +INFERENCE_TX_TIMEOUT = 120 +REGULAR_TX_TIMEOUT = 30 + +PRECOMPILE_CONTRACT_ADDRESS = "0x00000000000000000000000000000000000000F4" + + +class Alpha: + """ + Alpha Testnet features namespace. + + This class provides access to features that are only available on the Alpha Testnet, + including on-chain ONNX model inference, workflow deployment, and execution. + + Usage: + alpha = og.Alpha(private_key="0x...") + result = alpha.infer(model_cid, InferenceMode.VANILLA, model_input) + result = alpha.new_workflow(model_cid, input_query, input_tensor_name) + """ + + def __init__( + self, + private_key: str, + rpc_url: str = DEFAULT_RPC_URL, + inference_contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, + api_url: str = DEFAULT_API_URL, + ): + self._blockchain = Web3(Web3.HTTPProvider(rpc_url)) + self._wallet_account: LocalAccount = self._blockchain.eth.account.from_key(private_key) + self._inference_hub_contract_address = inference_contract_address + self._api_url = api_url + self._inference_abi: Optional[dict] = None + self._precompile_abi: Optional[dict] = None + + @property + def inference_abi(self) -> dict: + if self._inference_abi is None: + self._inference_abi = get_abi("inference.abi") + return self._inference_abi + + @property + def precompile_abi(self) -> dict: + if self._precompile_abi is None: + self._precompile_abi = get_abi("InferencePrecompile.abi") + return self._precompile_abi + + def infer( + self, + model_cid: str, + inference_mode: InferenceMode, + model_input: Dict[str, Union[str, int, float, List, np.ndarray]], + max_retries: Optional[int] = None, + ) -> InferenceResult: + """ + Perform inference on a model. + + Args: + model_cid (str): The unique content identifier for the model from IPFS. + inference_mode (InferenceMode): The inference mode. + model_input (Dict[str, Union[str, int, float, List, np.ndarray]]): The input data for the model. + max_retries (int, optional): Maximum number of retry attempts. Defaults to 5. + + Returns: + InferenceResult (InferenceResult): A dataclass object containing the transaction hash and model output. + transaction_hash (str): Blockchain hash for the transaction + model_output (Dict[str, np.ndarray]): Output of the ONNX model + + Raises: + RuntimeError: If the inference fails. + """ + + def execute_transaction(): + contract = self._blockchain.eth.contract( + address=Web3.to_checksum_address(self._inference_hub_contract_address), abi=self.inference_abi + ) + precompile_contract = self._blockchain.eth.contract( + address=Web3.to_checksum_address(PRECOMPILE_CONTRACT_ADDRESS), abi=self.precompile_abi + ) + + inference_mode_uint8 = inference_mode.value + converted_model_input = convert_to_model_input(model_input) + + run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input) + + tx_hash, tx_receipt = self._send_tx_with_revert_handling(run_function) + parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD) + if len(parsed_logs) < 1: + raise RuntimeError("InferenceResult event not found in transaction logs") + + # TODO: This should return a ModelOutput class object + model_output = convert_to_model_output(parsed_logs[0]["args"]) + if len(model_output) == 0: + # check inference directly from node + parsed_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD) + inference_id = parsed_logs[0]["args"]["inferenceID"] + inference_result = self._get_inference_result_from_node(inference_id, inference_mode) + model_output = convert_to_model_output(inference_result) + + return InferenceResult(tx_hash.hex(), model_output) + + result: InferenceResult = run_with_retry(execute_transaction, max_retries) + return result + + def _send_tx_with_revert_handling(self, run_function): + """ + Execute a blockchain transaction with revert error. + + Args: + run_function: Function that executes the transaction + + Returns: + tx_hash: Transaction hash + tx_receipt: Transaction receipt + + Raises: + Exception: If transaction fails or gas estimation fails + """ + nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending") + try: + estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address}) + except ContractLogicError as e: + try: + run_function.call({"from": self._wallet_account.address}) + + except ContractLogicError as call_err: + raise ContractLogicError(f"simulation failed with revert reason: {call_err.args[0]}") + + raise ContractLogicError(f"simulation failed with no revert reason. Reason: {e}") + + gas_limit = int(estimated_gas * 3) + + transaction = run_function.build_transaction( + { + "from": self._wallet_account.address, + "nonce": nonce, + "gas": gas_limit, + "gasPrice": self._blockchain.eth.gas_price, + } + ) + + signed_tx = self._wallet_account.sign_transaction(transaction) # type: ignore[arg-type] + tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction) + tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) + + if tx_receipt["status"] == 0: + try: + run_function.call({"from": self._wallet_account.address}) + + except ContractLogicError as call_err: + raise ContractLogicError(f"Transaction failed with revert reason: {call_err.args[0]}") + + raise ContractLogicError(f"Transaction failed with no revert reason. Receipt: {tx_receipt}") + + return tx_hash, tx_receipt + + def _get_inference_result_from_node(self, inference_id: str, inference_mode: InferenceMode) -> Optional[Dict]: + """ + Get the inference result from node. + + Args: + inference_id (str): Inference id for a inference request + + Returns: + Dict: The inference result as returned by the node + + Raises: + RuntimeError: If the request fails or returns an error + """ + try: + encoded_id = urllib.parse.quote(inference_id, safe="") + url = f"{self._api_url}/artela-network/artela-rollkit/inference/tx/{encoded_id}" + + response = requests.get(url) + if response.status_code == 200: + resp = response.json() + inference_result = resp.get("inference_results", {}) + if inference_result: + decoded_bytes = base64.b64decode(inference_result[0]) + decoded_string = decoded_bytes.decode("utf-8") + output = json.loads(decoded_string).get("InferenceResult", {}) + if output is None: + raise RuntimeError("Missing InferenceResult in inference output") + + match inference_mode: + case InferenceMode.VANILLA: + if "VanillaResult" not in output: + raise RuntimeError("Missing VanillaResult in inference output") + if "model_output" not in output["VanillaResult"]: + raise RuntimeError("Missing model_output in VanillaResult") + return {"output": output["VanillaResult"]["model_output"]} + + case InferenceMode.TEE: + if "TeeNodeResult" not in output: + raise RuntimeError("Missing TeeNodeResult in inference output") + if "Response" not in output["TeeNodeResult"]: + raise RuntimeError("Missing Response in TeeNodeResult") + if "VanillaResponse" in output["TeeNodeResult"]["Response"]: + if "model_output" not in output["TeeNodeResult"]["Response"]["VanillaResponse"]: + raise RuntimeError("Missing model_output in VanillaResponse") + return {"output": output["TeeNodeResult"]["Response"]["VanillaResponse"]["model_output"]} + + else: + raise RuntimeError("Missing VanillaResponse in TeeNodeResult Response") + + case InferenceMode.ZKML: + if "ZkmlResult" not in output: + raise RuntimeError("Missing ZkmlResult in inference output") + if "model_output" not in output["ZkmlResult"]: + raise RuntimeError("Missing model_output in ZkmlResult") + return {"output": output["ZkmlResult"]["model_output"]} + + case _: + raise ValueError(f"Invalid inference mode: {inference_mode}") + else: + return None + + else: + raise RuntimeError(f"Failed to get inference result: HTTP {response.status_code}") + + except requests.RequestException as e: + raise RuntimeError(f"Failed to get inference result: {str(e)}") + except (RuntimeError, ValueError): + raise + except Exception as e: + raise RuntimeError(f"Failed to get inference result: {str(e)}") + + def new_workflow( + self, + model_cid: str, + input_query: HistoricalInputQuery, + input_tensor_name: str, + scheduler_params: Optional[SchedulerParams] = None, + ) -> str: + """ + Deploy a new workflow contract with the specified parameters. + + This function deploys a new workflow contract on OpenGradient that connects + an AI model with its required input data. When executed, the workflow will fetch + the specified model, evaluate the input query to get data, and perform inference. + + The workflow can be set to execute manually or automatically via a scheduler. + + Args: + model_cid (str): CID of the model to be executed from the Model Hub + input_query (HistoricalInputQuery): Input definition for the model inference, + will be evaluated at runtime for each inference + input_tensor_name (str): Name of the input tensor expected by the model + scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution: + - frequency: Execution frequency in seconds + - duration_hours: How long the schedule should live for + + Returns: + str: Deployed contract address. If scheduler_params was provided, the workflow + will be automatically executed according to the specified schedule. + + Raises: + RuntimeError: If the deployment transaction fails. + Gas estimation failure does not raise; a fallback gas limit is used instead. + """ + # Get contract ABI and bytecode + abi = get_abi("PriceHistoryInference.abi") + bytecode = get_bin("PriceHistoryInference.bin") + + def deploy_transaction(): + contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode) + query_tuple = input_query.to_abi_format() + constructor_args = [model_cid, input_tensor_name, query_tuple] + + try: + # Estimate gas needed + estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address}) + gas_limit = int(estimated_gas * 1.2) + except Exception as e: + gas_limit = 5000000 # Conservative fallback + logger.warning("Gas estimation failed: %s; using fallback gas limit %d", e, gas_limit) + + transaction = contract.constructor(*constructor_args).build_transaction( + { + "from": self._wallet_account.address, + "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"), + "gas": gas_limit, + "gasPrice": self._blockchain.eth.gas_price, + "chainId": self._blockchain.eth.chain_id, + } + ) + + signed_txn = self._wallet_account.sign_transaction(transaction) + tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction) + + tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60) + + if tx_receipt["status"] == 0: + raise RuntimeError(f"Contract deployment failed, transaction hash: {tx_hash.hex()}") + + return tx_receipt.contractAddress + + contract_address: str = run_with_retry(deploy_transaction) + + if scheduler_params: + self._register_with_scheduler(contract_address, scheduler_params) + + return contract_address + + def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None: + """ + Register the deployed workflow contract with the scheduler for automated execution. + + Args: + contract_address (str): Address of the deployed workflow contract + scheduler_params (SchedulerParams): Scheduler configuration containing: + - frequency: Execution frequency in seconds + - duration_hours: How long to run in hours + - end_time: Unix timestamp when scheduling should end + + Note: + Scheduler registration failures are logged as warnings and do not raise. + The workflow contract is already deployed and can be executed manually. + """ + scheduler_abi = get_abi("WorkflowScheduler.abi") + + # Scheduler contract address + scheduler_address = DEFAULT_SCHEDULER_ADDRESS + scheduler_contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(scheduler_address), abi=scheduler_abi) + + try: + # Register the workflow with the scheduler + scheduler_tx = scheduler_contract.functions.registerTask( + contract_address, scheduler_params.end_time, scheduler_params.frequency + ).build_transaction( + { + "from": self._wallet_account.address, + "gas": 300000, + "gasPrice": self._blockchain.eth.gas_price, + "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"), + "chainId": self._blockchain.eth.chain_id, + } + ) + + signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx) # type: ignore[arg-type] + scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction) + self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT) + except Exception as e: + logger.warning("Error registering contract with scheduler: %s. The workflow contract is still deployed and can be executed manually.", e) + + def read_workflow_result(self, contract_address: str) -> ModelOutput: + """ + Reads the latest inference result from a deployed workflow contract. + + Args: + contract_address (str): Address of the deployed workflow contract + + Returns: + ModelOutput: The inference result from the contract + + Raises: + ContractLogicError: If the transaction fails + Web3Error: If there are issues with the web3 connection or contract interaction + """ + # Get the contract interface + contract = self._blockchain.eth.contract( + address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") + ) + + # Get the result + result = contract.functions.getInferenceResult().call() + + output: ModelOutput = convert_array_to_model_output(result) + return output + + def run_workflow(self, contract_address: str) -> ModelOutput: + """ + Triggers the run() function on a deployed workflow contract and returns the result. + + Args: + contract_address (str): Address of the deployed workflow contract + + Returns: + ModelOutput: The inference result from the contract + + Raises: + ContractLogicError: If the transaction fails + Web3Error: If there are issues with the web3 connection or contract interaction + """ + # Get the contract interface + contract = self._blockchain.eth.contract( + address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") + ) + + # Call run() function + nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending") + + run_function = contract.functions.run() + transaction = run_function.build_transaction( + { + "from": self._wallet_account.address, + "nonce": nonce, + "gas": 30000000, + "gasPrice": self._blockchain.eth.gas_price, + "chainId": self._blockchain.eth.chain_id, + } + ) + + signed_txn = self._wallet_account.sign_transaction(transaction) # type: ignore[arg-type] + tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction) + tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT) + + if tx_receipt["status"] == 0: + raise ContractLogicError(f"Run transaction failed. Receipt: {tx_receipt}") + + # Get the inference result from the contract + result = contract.functions.getInferenceResult().call() + + run_output: ModelOutput = convert_array_to_model_output(result) + return run_output + + def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]: + """ + Gets historical inference results from a workflow contract. + + Retrieves the specified number of most recent inference results from the contract's + storage, with the most recent result first. + + Args: + contract_address (str): Address of the deployed workflow contract + num_results (int): Number of historical results to retrieve + + Returns: + List[ModelOutput]: List of historical inference results + """ + contract = self._blockchain.eth.contract( + address=Web3.to_checksum_address(contract_address), abi=get_abi("PriceHistoryInference.abi") + ) + + results = contract.functions.getLastInferenceResults(num_results).call() + return [convert_array_to_model_output(result) for result in results] diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index ed54fd9..dbc2a13 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -1,508 +1,509 @@ -"""LLM chat and completion via TEE-verified execution with x402 payments.""" - -import json -import logging -from dataclasses import dataclass -from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union -import httpx -import asyncio - -from eth_account import Account -from eth_account.account import LocalAccount -from x402 import x402Client -from x402.mechanisms.evm import EthAccountSigner -from x402.mechanisms.evm.exact.register import register_exact_evm_client -from x402.mechanisms.evm.upto.register import register_upto_evm_client - -from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode -from .opg_token import Permit2ApprovalResult, ensure_opg_approval -from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface -from .tee_registry import TEERegistry - -logger = logging.getLogger(__name__) -T = TypeVar("T") - -DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" -DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" - -X402_PROCESSING_HASH_HEADER = "x-processing-hash" -X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" -BASE_TESTNET_NETWORK = "eip155:84532" - -_CHAT_ENDPOINT = "/v1/chat/completions" -_COMPLETION_ENDPOINT = "/v1/completions" -_REQUEST_TIMEOUT = 60 - - -@dataclass(frozen=True) -class _ChatParams: - """Bundles the common parameters for chat/completion requests.""" - - model: str - max_tokens: int - temperature: float - stop_sequence: Optional[List[str]] - tools: Optional[List[Dict]] - tool_choice: Optional[str] - x402_settlement_mode: x402SettlementMode - - -class LLM: - """ - LLM inference namespace. - - Provides access to large language model completions and chat via TEE - (Trusted Execution Environment) with x402 payment protocol support. - Supports both streaming and non-streaming responses. - - All request methods (``chat``, ``completion``) are async. - - Before making LLM requests, ensure your wallet has approved sufficient - OPG tokens for Permit2 spending by calling ``ensure_opg_approval``. - - Usage: - # Via on-chain registry (default) - llm = og.LLM(private_key="0x...") - - # Via hardcoded URL (development / self-hosted) - llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4") - - # Ensure sufficient OPG allowance (only sends tx when below threshold) - llm.ensure_opg_approval(min_allowance=5) - - result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...]) - result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello") - """ - - def __init__( - self, - private_key: str, - rpc_url: str = DEFAULT_RPC_URL, - tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, - ): - if not private_key: - raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") - self._wallet_account: LocalAccount = Account.from_key(private_key) - - x402_client = LLM._build_x402_client(private_key) - onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) - self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry) - - @classmethod - def from_url( - cls, - private_key: str, - llm_server_url: str, - ) -> "LLM": - """**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL. - - Intended for development and self-hosted TEE servers. TLS certificate - verification is disabled because these servers typically use self-signed - certificates. For production use, prefer the default constructor which - resolves TEEs from the on-chain registry. - - Args: - private_key: Ethereum private key for signing x402 payments. - llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``). - """ - instance = cls.__new__(cls) - if not private_key: - raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") - instance._wallet_account = Account.from_key(private_key) - x402_client = cls._build_x402_client(private_key) - instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url) - return instance - - @staticmethod - def _build_x402_client(private_key: str) -> x402Client: - """Build the x402 payment stack from a private key.""" - account = Account.from_key(private_key) - signer = EthAccountSigner(account) - client = x402Client() - register_exact_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) - return client - - # ── Lifecycle ─────────────────────────────────────────────────────── - - async def close(self) -> None: - """Cancel the background refresh loop and close the HTTP client.""" - await self._tee.close() - - # ── Request helpers ───────────────────────────────────────────────── - - def _headers(self, settlement_mode: x402SettlementMode) -> Dict[str, str]: - return { - "Content-Type": "application/json", - "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": settlement_mode.value, - } - - def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool = False) -> Dict: - payload: Dict = { - "model": params.model, - "messages": messages, - "max_tokens": params.max_tokens, - "temperature": params.temperature, - } - if stream: - payload["stream"] = True - if params.stop_sequence: - payload["stop"] = params.stop_sequence - if params.tools: - payload["tools"] = params.tools - payload["tool_choice"] = params.tool_choice or "auto" - return payload - - async def _call_with_tee_retry( - self, - operation_name: str, - call: Callable[[], Awaitable[T]], - ) -> T: - """Execute *call*; on connection failure, pick a new TEE and retry once. - - Only retries when the request never reached the server (no HTTP response). - Server-side errors (4xx/5xx) are not retried. - """ - self._tee.ensure_refresh_loop() - try: - return await call() - except httpx.HTTPStatusError: - raise - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning( - "Connection failure during %s; refreshing TEE and retrying once: %s", - operation_name, - exc, - ) - await self._tee.reconnect() - return await call() - - # ── Public API ────────────────────────────────────────────────────── - - def ensure_opg_approval( - self, - min_allowance: float, - approve_amount: Optional[float] = None, - ) -> Permit2ApprovalResult: - """Ensure the Permit2 allowance stays above a minimum threshold. - - Only sends a transaction when the current allowance drops below - ``min_allowance``. When approval is needed, approves ``approve_amount`` - (defaults to ``2 * min_allowance``) to create a buffer that survives - multiple service restarts without re-approving. - - Best for backend servers that call this on startup:: - - llm.ensure_opg_approval(min_allowance=5.0, approve_amount=100.0) - - Args: - min_allowance: The minimum acceptable allowance in OPG. Must be - at least 0.1 OPG. - approve_amount: The amount of OPG to approve when a transaction - is needed. Defaults to ``2 * min_allowance``. Must be - >= ``min_allowance``. - - Returns: - Permit2ApprovalResult: Contains ``allowance_before``, - ``allowance_after``, and ``tx_hash`` (None when no approval - was needed). - - Raises: - ValueError: If ``min_allowance`` is less than 0.1 or - ``approve_amount`` is less than ``min_allowance``. - RuntimeError: If the approval transaction fails. - """ - if min_allowance < 0.1: - raise ValueError("min_allowance must be at least 0.1.") - return ensure_opg_approval(self._wallet_account, min_allowance, approve_amount) - - async def completion( - self, - model: TEE_LLM, - prompt: str, - max_tokens: int = 100, - stop_sequence: Optional[List[str]] = None, - temperature: float = 0.0, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, - ) -> TextGenerationOutput: - """ - Perform inference on an LLM model using completions via TEE. - - Args: - model (TEE_LLM): The model to use (e.g., TEE_LLM.CLAUDE_HAIKU_4_5). - prompt (str): The input prompt for the LLM. - max_tokens (int): Maximum number of tokens for LLM output. Default is 100. - stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None. - temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0. - x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments. - - PRIVATE: Payment only, no input/output data on-chain (most privacy-preserving). - - BATCH_HASHED: Aggregates inferences into a Merkle tree with input/output hashes and signatures (default, most cost-efficient). - - INDIVIDUAL_FULL: Records input, output, timestamp, and verification on-chain (maximum auditability). - Defaults to BATCH_HASHED. - - Returns: - TextGenerationOutput: Generated text results including: - - Transaction hash ("external" for TEE providers) - - String of completion output - - Payment hash for x402 transactions - - Raises: - RuntimeError: If the inference fails. - """ - model_id = model.split("/")[1] - payload: Dict = { - "model": model_id, - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": temperature, - } - if stop_sequence: - payload["stop"] = stop_sequence - - async def _request() -> TextGenerationOutput: - tee = self._tee.get() - response = await tee.http_client.post( - tee.endpoint + _COMPLETION_ENDPOINT, - json=payload, - headers=self._headers(x402_settlement_mode), - timeout=_REQUEST_TIMEOUT, - ) - response.raise_for_status() - raw_body = await response.aread() - result = json.loads(raw_body.decode()) - return TextGenerationOutput( - transaction_hash="external", - completion_output=result.get("completion"), - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **tee.metadata(), - ) - - try: - return await self._call_with_tee_retry("completion", _request) - except RuntimeError: - raise - except Exception as e: - raise RuntimeError(f"TEE LLM completion failed: {e}") from e - - async def chat( - self, - model: TEE_LLM, - messages: List[Dict], - max_tokens: int = 100, - stop_sequence: Optional[List[str]] = None, - temperature: float = 0.0, - tools: Optional[List[Dict]] = None, - tool_choice: Optional[str] = None, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, - stream: bool = False, - ) -> Union[TextGenerationOutput, AsyncGenerator[StreamChunk, None]]: - """ - Perform inference on an LLM model using chat via TEE. - - Args: - model (TEE_LLM): The model to use (e.g., TEE_LLM.CLAUDE_HAIKU_4_5). - messages (List[Dict]): The messages that will be passed into the chat. - max_tokens (int): Maximum number of tokens for LLM output. Default is 100. - stop_sequence (List[str], optional): List of stop sequences for LLM. - temperature (float): Temperature for LLM inference, between 0 and 1. - tools (List[dict], optional): Set of tools for function calling. - tool_choice (str, optional): Sets a specific tool to choose. - x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments. - - PRIVATE: Payment only, no input/output data on-chain (most privacy-preserving). - - BATCH_HASHED: Aggregates inferences into a Merkle tree with input/output hashes and signatures (default, most cost-efficient). - - INDIVIDUAL_FULL: Records input, output, timestamp, and verification on-chain (maximum auditability). - Defaults to BATCH_HASHED. - stream (bool, optional): Whether to stream the response. Default is False. - - Returns: - Union[TextGenerationOutput, AsyncGenerator[StreamChunk, None]]: - - If stream=False: TextGenerationOutput with chat_output, transaction_hash, finish_reason, and payment_hash - - If stream=True: Async generator yielding StreamChunk objects - - Raises: - RuntimeError: If the inference fails. - """ - params = _ChatParams( - model=model.split("/")[1], - max_tokens=max_tokens, - temperature=temperature, - stop_sequence=stop_sequence, - tools=tools, - tool_choice=tool_choice, - x402_settlement_mode=x402_settlement_mode, - ) - - if not stream: - return await self._chat_request(params, messages) - - # The TEE streaming endpoint omits tool call content from SSE events. - # Fall back to non-streaming and emit a single final StreamChunk. - if tools: - return self._chat_tools_as_stream(params, messages) - - return self._chat_stream(params, messages) - - # ── Chat internals ────────────────────────────────────────────────── - - async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput: - """Non-streaming chat request.""" - payload = self._chat_payload(params, messages) - - async def _request() -> TextGenerationOutput: - tee = self._tee.get() - response = await tee.http_client.post( - tee.endpoint + _CHAT_ENDPOINT, - json=payload, - headers=self._headers(params.x402_settlement_mode), - timeout=_REQUEST_TIMEOUT, - ) - response.raise_for_status() - raw_body = await response.aread() - result = json.loads(raw_body.decode()) - - choices = result.get("choices") - if not choices: - raise RuntimeError(f"Invalid response: 'choices' missing or empty in {result}") - - message = choices[0].get("message", {}) - content = message.get("content") - if isinstance(content, list): - message["content"] = " ".join( - block.get("text", "") for block in content if isinstance(block, dict) and block.get("type") == "text" - ).strip() - - return TextGenerationOutput( - transaction_hash="external", - finish_reason=choices[0].get("finish_reason"), - chat_output=message, - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **tee.metadata(), - ) - - try: - return await self._call_with_tee_retry("chat", _request) - except RuntimeError: - raise - except Exception as e: - raise RuntimeError(f"TEE LLM chat failed: {e}") from e - - async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: - """Non-streaming fallback for tool-call requests wrapped as a single StreamChunk.""" - result = await self._chat_request(params, messages) - chat_output = result.chat_output or {} - yield StreamChunk( - choices=[ - StreamChoice( - delta=StreamDelta( - role=chat_output.get("role"), - content=chat_output.get("content"), - tool_calls=chat_output.get("tool_calls"), - ), - index=0, - finish_reason=result.finish_reason, - ) - ], - model=params.model, - is_final=True, - tee_signature=result.tee_signature, - tee_timestamp=result.tee_timestamp, - tee_id=result.tee_id, - tee_endpoint=result.tee_endpoint, - tee_payment_address=result.tee_payment_address, - ) - - async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: - """Async SSE streaming implementation.""" - self._tee.ensure_refresh_loop() - headers = self._headers(params.x402_settlement_mode) - payload = self._chat_payload(params, messages, stream=True) - - chunks_yielded = False - try: - tee = self._tee.get() - async with tee.http_client.stream( - "POST", - tee.endpoint + _CHAT_ENDPOINT, - json=payload, - headers=headers, - timeout=_REQUEST_TIMEOUT, - ) as response: - async for chunk in self._parse_sse_response(response, tee): - chunks_yielded = True - yield chunk - return - except httpx.HTTPStatusError: - raise - except asyncio.CancelledError: - raise - except Exception as exc: - if chunks_yielded: - raise - logger.warning( - "Connection failure during stream setup; refreshing TEE and retrying once: %s", - exc, - ) - - # Only reached if the first attempt failed before yielding any chunks. - # Re-resolve the TEE endpoint from the registry and retry once. - await self._tee.reconnect() - tee = self._tee.get() - - headers = self._headers(params.x402_settlement_mode) - async with tee.http_client.stream( - "POST", - tee.endpoint + _CHAT_ENDPOINT, - json=payload, - headers=headers, - timeout=_REQUEST_TIMEOUT, - ) as response: - async for chunk in self._parse_sse_response(response, tee): - yield chunk - - async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk, None]: - """Parse an SSE response stream into StreamChunk objects.""" - status_code = getattr(response, "status_code", None) - if status_code is not None and status_code >= 400: - body = await response.aread() - raise RuntimeError(f"TEE LLM streaming request failed with status {status_code}: {body.decode('utf-8', errors='replace')}") - - buffer = b"" - async for raw_chunk in response.aiter_raw(): - if not raw_chunk: - continue - - buffer += raw_chunk - while b"\n" in buffer: - line_bytes, buffer = buffer.split(b"\n", 1) - line = line_bytes.strip() - if not line: - continue - - try: - decoded = line.decode("utf-8") - except UnicodeDecodeError: - continue - - if not decoded.startswith("data: "): - continue - - data_str = decoded[6:].strip() - if data_str == "[DONE]": - return - - try: - data = json.loads(data_str) - except json.JSONDecodeError: - continue - - chunk = StreamChunk.from_sse_data(data) - if chunk.is_final: - chunk.tee_id = tee.tee_id - chunk.tee_endpoint = tee.endpoint - chunk.tee_payment_address = tee.payment_address - yield chunk +"""LLM chat and completion via TEE-verified execution with x402 payments.""" + +import json +import logging +from dataclasses import dataclass +from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union +import httpx +import asyncio + +from eth_account import Account +from eth_account.account import LocalAccount +from x402 import x402Client +from x402.mechanisms.evm import EthAccountSigner +from x402.mechanisms.evm.exact.register import register_exact_evm_client +from x402.mechanisms.evm.upto.register import register_upto_evm_client + +from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode +from .opg_token import Permit2ApprovalResult, ensure_opg_approval +from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface +from .tee_registry import TEERegistry + +logger = logging.getLogger(__name__) +T = TypeVar("T") + +DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" +DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" + +X402_PROCESSING_HASH_HEADER = "x-processing-hash" +X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" +BASE_TESTNET_NETWORK = "eip155:84532" + +_CHAT_ENDPOINT = "/v1/chat/completions" +_COMPLETION_ENDPOINT = "/v1/completions" +_REQUEST_TIMEOUT = 60 + + +@dataclass(frozen=True) +class _ChatParams: + """Bundles the common parameters for chat/completion requests.""" + + model: str + max_tokens: int + temperature: float + stop_sequence: Optional[List[str]] + tools: Optional[List[Dict]] + tool_choice: Optional[str] + x402_settlement_mode: x402SettlementMode + + +class LLM: + """ + LLM inference namespace. + + Provides access to large language model completions and chat via TEE + (Trusted Execution Environment) with x402 payment protocol support. + Supports both streaming and non-streaming responses. + + All request methods (``chat``, ``completion``) are async. + + Before making LLM requests, ensure your wallet has approved sufficient + OPG tokens for Permit2 spending by calling ``ensure_opg_approval``. + + Usage: + # Via on-chain registry (default) + llm = og.LLM(private_key="0x...") + + # Via hardcoded URL (development / self-hosted) + llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4") + + # Ensure sufficient OPG allowance (only sends tx when below threshold) + llm.ensure_opg_approval(min_allowance=5) + + result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...]) + result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello") + """ + + def __init__( + self, + private_key: str, + rpc_url: str = DEFAULT_RPC_URL, + tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, + ): + if not private_key: + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") + self._wallet_account: LocalAccount = Account.from_key(private_key) + + x402_client = LLM._build_x402_client(private_key) + onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) + self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry) + + @classmethod + def from_url( + cls, + private_key: str, + llm_server_url: str, + ) -> "LLM": + """**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL. + + Intended for development and self-hosted TEE servers. TLS certificate + verification is disabled because these servers typically use self-signed + certificates. For production use, prefer the default constructor which + resolves TEEs from the on-chain registry. + + Args: + private_key: Ethereum private key for signing x402 payments. + llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``). + """ + instance = cls.__new__(cls) + if not private_key: + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") + instance._wallet_account = Account.from_key(private_key) + x402_client = cls._build_x402_client(private_key) + instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url) + return instance + + @staticmethod + def _build_x402_client(private_key: str) -> x402Client: + """Build the x402 payment stack from a private key.""" + account = Account.from_key(private_key) + signer = EthAccountSigner(account) + client = x402Client() + register_exact_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + return client + + # ── Lifecycle ─────────────────────────────────────────────────────── + + async def close(self) -> None: + """Cancel the background refresh loop and close the HTTP client.""" + await self._tee.close() + + # ── Request helpers ───────────────────────────────────────────────── + + def _headers(self, settlement_mode: x402SettlementMode) -> Dict[str, str]: + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", + "X-SETTLEMENT-TYPE": settlement_mode.value, + } + + def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool = False) -> Dict: + payload: Dict = { + "model": params.model, + "messages": messages, + "max_tokens": params.max_tokens, + "temperature": params.temperature, + } + if stream: + payload["stream"] = True + if params.stop_sequence: + payload["stop"] = params.stop_sequence + if params.tools: + payload["tools"] = params.tools + payload["tool_choice"] = params.tool_choice or "auto" + return payload + + async def _call_with_tee_retry( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Execute *call*; on connection failure, pick a new TEE and retry once. + + Only retries when the request never reached the server (no HTTP response). + Server-side errors (4xx/5xx) are not retried. + """ + self._tee.ensure_refresh_loop() + try: + return await call() + except httpx.HTTPStatusError: + raise + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning( + "Connection failure during %s; refreshing TEE and retrying once: %s", + operation_name, + exc, + ) + await self._tee.reconnect() + return await call() + + # ── Public API ────────────────────────────────────────────────────── + + def ensure_opg_approval( + self, + min_allowance: float, + approve_amount: Optional[float] = None, + ) -> Permit2ApprovalResult: + """Ensure the Permit2 allowance stays above a minimum threshold. + + Only sends a transaction when the current allowance drops below + ``min_allowance``. When approval is needed, approves ``approve_amount`` + (defaults to ``2 * min_allowance``) to create a buffer that survives + multiple service restarts without re-approving. + + Best for backend servers that call this on startup:: + + llm.ensure_opg_approval(min_allowance=5.0, approve_amount=100.0) + + Args: + min_allowance: The minimum acceptable allowance in OPG. Must be + at least 0.1 OPG. + approve_amount: The amount of OPG to approve when a transaction + is needed. Defaults to ``2 * min_allowance``. Must be + >= ``min_allowance``. + + Returns: + Permit2ApprovalResult: Contains ``allowance_before``, + ``allowance_after``, and ``tx_hash`` (None when no approval + was needed). + + Raises: + ValueError: If ``min_allowance`` is less than 0.1 or + ``approve_amount`` is less than ``min_allowance``. + RuntimeError: If the approval transaction fails. + """ + if min_allowance < 0.1: + raise ValueError("min_allowance must be at least 0.1.") + return ensure_opg_approval(self._wallet_account, min_allowance, approve_amount) + + async def completion( + self, + model: TEE_LLM, + prompt: str, + max_tokens: int = 100, + stop_sequence: Optional[List[str]] = None, + temperature: float = 0.0, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + ) -> TextGenerationOutput: + """ + Perform inference on an LLM model using completions via TEE. + + Args: + model (TEE_LLM): The model to use (e.g., TEE_LLM.CLAUDE_HAIKU_4_5). + prompt (str): The input prompt for the LLM. + max_tokens (int): Maximum number of tokens for LLM output. Default is 100. + stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None. + temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0. + x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments. + - PRIVATE: Payment only, no input/output data on-chain (most privacy-preserving). + - BATCH_HASHED: Aggregates inferences into a Merkle tree with input/output hashes and signatures (default, most cost-efficient). + - INDIVIDUAL_FULL: Records input, output, timestamp, and verification on-chain (maximum auditability). + Defaults to BATCH_HASHED. + + Returns: + TextGenerationOutput: Generated text results including: + - Transaction hash ("external" for TEE providers) + - String of completion output + - Payment hash for x402 transactions + + Raises: + RuntimeError: If the inference fails. + """ + model_id = model.split("/")[1] + payload: Dict = { + "model": model_id, + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": temperature, + } + if stop_sequence: + payload["stop"] = stop_sequence + + async def _request() -> TextGenerationOutput: + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _COMPLETION_ENDPOINT, + json=payload, + headers=self._headers(x402_settlement_mode), + timeout=_REQUEST_TIMEOUT, + ) + response.raise_for_status() + raw_body = await response.aread() + result = json.loads(raw_body.decode()) + return TextGenerationOutput( + transaction_hash="external", + completion_output=result.get("completion"), + tee_signature=result.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp"), + **tee.metadata(), + ) + + try: + return await self._call_with_tee_retry("completion", _request) + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"TEE LLM completion failed: {e}") from e + + async def chat( + self, + model: TEE_LLM, + messages: List[Dict], + max_tokens: int = 100, + stop_sequence: Optional[List[str]] = None, + temperature: float = 0.0, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + stream: bool = False, + ) -> Union[TextGenerationOutput, AsyncGenerator[StreamChunk, None]]: + """ + Perform inference on an LLM model using chat via TEE. + + Args: + model (TEE_LLM): The model to use (e.g., TEE_LLM.CLAUDE_HAIKU_4_5). + messages (List[Dict]): The messages that will be passed into the chat. + max_tokens (int): Maximum number of tokens for LLM output. Default is 100. + stop_sequence (List[str], optional): List of stop sequences for LLM. + temperature (float): Temperature for LLM inference, between 0 and 1. + tools (List[dict], optional): Set of tools for function calling. + tool_choice (str, optional): Sets a specific tool to choose. + x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments. + - PRIVATE: Payment only, no input/output data on-chain (most privacy-preserving). + - BATCH_HASHED: Aggregates inferences into a Merkle tree with input/output hashes and signatures (default, most cost-efficient). + - INDIVIDUAL_FULL: Records input, output, timestamp, and verification on-chain (maximum auditability). + Defaults to BATCH_HASHED. + stream (bool, optional): Whether to stream the response. Default is False. + + Returns: + Union[TextGenerationOutput, AsyncGenerator[StreamChunk, None]]: + - If stream=False: TextGenerationOutput with chat_output, transaction_hash, finish_reason, and payment_hash + - If stream=True: Async generator yielding StreamChunk objects + + Raises: + RuntimeError: If the inference fails. + """ + params = _ChatParams( + model=model.split("/")[1], + max_tokens=max_tokens, + temperature=temperature, + stop_sequence=stop_sequence, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) + + if not stream: + return await self._chat_request(params, messages) + + # The TEE streaming endpoint omits tool call content from SSE events. + # Fall back to non-streaming and emit a single final StreamChunk. + if tools: + return self._chat_tools_as_stream(params, messages) + + return self._chat_stream(params, messages) + + # ── Chat internals ────────────────────────────────────────────────── + + async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput: + """Non-streaming chat request.""" + payload = self._chat_payload(params, messages) + + async def _request() -> TextGenerationOutput: + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _CHAT_ENDPOINT, + json=payload, + headers=self._headers(params.x402_settlement_mode), + timeout=_REQUEST_TIMEOUT, + ) + response.raise_for_status() + raw_body = await response.aread() + result = json.loads(raw_body.decode()) + + choices = result.get("choices") + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + raise RuntimeError(f"Invalid response: 'choices' missing or empty or malformed in {result}") + + message = choices[0].get("message", {}) + content = message.get("content") + if isinstance(content, list): + message["content"] = " ".join( + block.get("text", "") for block in content if isinstance(block, dict) and block.get("type") == "text" + ).strip() + + return TextGenerationOutput( + transaction_hash="external", + finish_reason=choices[0].get("finish_reason"), + chat_output=message, + tee_signature=result.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp"), + **tee.metadata(), + ) + + try: + return await self._call_with_tee_retry("chat", _request) + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"TEE LLM chat failed: {e}") from e + + async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: + """Non-streaming fallback for tool-call requests wrapped as a single StreamChunk.""" + result = await self._chat_request(params, messages) + chat_output = result.chat_output or {} + yield StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta( + role=chat_output.get("role"), + content=chat_output.get("content"), + tool_calls=chat_output.get("tool_calls"), + ), + index=0, + finish_reason=result.finish_reason, + ) + ], + model=params.model, + is_final=True, + tee_signature=result.tee_signature, + tee_timestamp=result.tee_timestamp, + tee_id=result.tee_id, + tee_endpoint=result.tee_endpoint, + tee_payment_address=result.tee_payment_address, + ) + + async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: + """Async SSE streaming implementation.""" + self._tee.ensure_refresh_loop() + headers = self._headers(params.x402_settlement_mode) + payload = self._chat_payload(params, messages, stream=True) + + chunks_yielded = False + try: + tee = self._tee.get() + async with tee.http_client.stream( + "POST", + tee.endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response, tee): + chunks_yielded = True + yield chunk + return + except httpx.HTTPStatusError: + raise + except asyncio.CancelledError: + raise + except Exception as exc: + if chunks_yielded: + raise + logger.warning( + "Connection failure during stream setup; refreshing TEE and retrying once: %s", + exc, + ) + + # Only reached if the first attempt failed before yielding any chunks. + # Re-resolve the TEE endpoint from the registry and retry once. + await self._tee.reconnect() + tee = self._tee.get() + + headers = self._headers(params.x402_settlement_mode) + async with tee.http_client.stream( + "POST", + tee.endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response, tee): + yield chunk + + async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk, None]: + """Parse an SSE response stream into StreamChunk objects.""" + status_code = getattr(response, "status_code", None) + if status_code is not None and status_code >= 400: + body = await response.aread() + raise RuntimeError(f"TEE LLM streaming request failed with status {status_code}: {body.decode('utf-8', errors='replace')}") + + buffer = b"" + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + + buffer += raw_chunk + while b"\n" in buffer: + line_bytes, buffer = buffer.split(b"\n", 1) + line = line_bytes.strip() + if not line: + continue + + try: + decoded = line.decode("utf-8") + except UnicodeDecodeError: + continue + + if not decoded.startswith("data: "): + continue + + data_str = decoded[6:].strip() + if data_str == "[DONE]": + return + + try: + data = json.loads(data_str) + except json.JSONDecodeError: + logger.warning("Skipping malformed SSE JSON: %r", data_str) + continue + + chunk = StreamChunk.from_sse_data(data) + if chunk.is_final: + chunk.tee_id = tee.tee_id + chunk.tee_endpoint = tee.endpoint + chunk.tee_payment_address = tee.payment_address + yield chunk diff --git a/src/opengradient/types.py b/src/opengradient/types.py index a59293f..0c05100 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -1,554 +1,558 @@ -""" -OpenGradient Specific Types -""" - -import time -from dataclasses import dataclass -from enum import Enum, IntEnum -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union - -import numpy as np - - -class x402SettlementMode(str, Enum): - """ - Settlement modes for x402 payment protocol transactions. - - These modes control how inference data is recorded on-chain for payment settlement - and auditability. Each mode offers different trade-offs between data completeness, - privacy, and transaction costs. - - Attributes: - PRIVATE: Payment-only settlement. - Only the payment is settled on-chain — no input or output hashes are posted. - Your inference data remains completely off-chain, ensuring maximum privacy. - Suitable when payment settlement is required without any on-chain record of execution. - CLI usage: --settlement-mode private - - BATCH_HASHED: Batch settlement with hashes (default). - Aggregates multiple inferences into a single settlement transaction - using a Merkle tree containing input hashes, output hashes, and signatures. - Most cost-efficient for high-volume applications. - CLI usage: --settlement-mode batch-hashed - - INDIVIDUAL_FULL: Individual settlement with full metadata. - Records input data, output data, timestamp, and verification on-chain. - Provides maximum transparency and auditability. - Higher gas costs due to larger data storage. - CLI usage: --settlement-mode individual-full - - Examples: - >>> from opengradient import x402SettlementMode - >>> mode = x402SettlementMode.PRIVATE - >>> print(mode.value) - 'private' - """ - - PRIVATE = "private" - BATCH_HASHED = "batch" - INDIVIDUAL_FULL = "individual" - - -class CandleOrder(IntEnum): - ASCENDING = 0 - DESCENDING = 1 - - -class CandleType(IntEnum): - HIGH = 0 - LOW = 1 - OPEN = 2 - CLOSE = 3 - VOLUME = 4 - - -@dataclass -class HistoricalInputQuery: - base: str - quote: str - total_candles: int - candle_duration_in_mins: int - order: CandleOrder - candle_types: List[CandleType] - - def to_abi_format(self) -> tuple: - """Convert to format expected by contract ABI""" - return ( - self.base, - self.quote, - self.total_candles, - self.candle_duration_in_mins, - int(self.order), - [int(ct) for ct in self.candle_types], - ) - - -@dataclass -class Number: - value: int - decimals: int - - -@dataclass -class NumberTensor: - """ - A container for numeric tensor data used as input for ONNX models. - - Attributes: - - name: Identifier for this tensor in the model. - - values: List of integer tuples representing the tensor data. - """ - - name: str - values: List[Tuple[int, int]] - - -@dataclass -class StringTensor: - """ - A container for string tensor data used as input for ONNX models. - - Attributes: - - name: Identifier for this tensor in the model. - - values: List of strings representing the tensor data. - """ - - name: str - values: List[str] - - -@dataclass -class ModelInput: - """ - A collection of tensor inputs required for ONNX model inference. - - Attributes: - - numbers: Collection of numeric tensors for the model. - - strings: Collection of string tensors for the model. - """ - - numbers: List[NumberTensor] - strings: List[StringTensor] - - -class InferenceMode(Enum): - """Enum for the different inference modes available for inference (VANILLA, ZKML, TEE)""" - - VANILLA = 0 - ZKML = 1 - TEE = 2 - - -@dataclass -class ModelOutput: - """ - Model output struct based on translations from smart contract. - """ - - numbers: Dict[str, np.ndarray] - strings: Dict[str, np.ndarray] - jsons: Dict[str, np.ndarray] # Converts to JSON dictionary - is_simulation_result: bool - - -@dataclass -class InferenceResult: - """ - Output for ML inference requests. - This class has two fields - transaction_hash (str): Blockchain hash for the transaction - model_output (Dict[str, np.ndarray]): Output of the ONNX model - """ - - transaction_hash: str - model_output: Dict[str, np.ndarray] - - -@dataclass -class StreamDelta: - """ - Represents a delta (incremental change) in a streaming response. - - Attributes: - content: Incremental text content (if any) - role: Message role (appears in first chunk) - tool_calls: Tool call information (if function calling is used) - """ - - content: Optional[str] = None - role: Optional[str] = None - tool_calls: Optional[List[Dict]] = None - - -@dataclass -class StreamChoice: - """ - Represents a choice in a streaming response. - - Attributes: - delta: The incremental changes in this chunk - index: Choice index (usually 0) - finish_reason: Reason for completion (appears in final chunk) - """ - - delta: StreamDelta - index: int = 0 - finish_reason: Optional[str] = None - - -@dataclass -class StreamUsage: - """ - Token usage information for a streaming response. - - Attributes: - prompt_tokens: Number of tokens in the prompt - completion_tokens: Number of tokens in the completion - total_tokens: Total tokens used - """ - - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -@dataclass -class StreamChunk: - """ - Represents a single chunk in a streaming LLM response. - - This follows the OpenAI streaming format but is provider-agnostic. - Each chunk contains incremental data, with the final chunk including - usage information. - - Attributes: - choices: List of streaming choices (usually contains one choice) - model: Model identifier - usage: Token usage information (only in final chunk) - is_final: Whether this is the final chunk (before [DONE]) - tee_signature: RSA-PSS signature over the response, present on the final chunk - tee_timestamp: ISO timestamp from the TEE at signing time, present on the final chunk - tee_id: On-chain TEE registry ID of the enclave that served this request (final chunk only) - tee_endpoint: Endpoint URL of the TEE that served this request (final chunk only) - tee_payment_address: Payment address registered for the TEE (final chunk only) - """ - - choices: List[StreamChoice] - model: str - usage: Optional[StreamUsage] = None - is_final: bool = False - tee_signature: Optional[str] = None - tee_timestamp: Optional[str] = None - tee_id: Optional[str] = None - tee_endpoint: Optional[str] = None - tee_payment_address: Optional[str] = None - - @classmethod - def from_sse_data(cls, data: Dict) -> "StreamChunk": - """ - Parse a StreamChunk from SSE data dictionary. - - Args: - data: Dictionary parsed from SSE data line - - Returns: - StreamChunk instance - """ - choices = [] - for choice_data in data.get("choices", []): - # The TEE proxy sometimes sends SSE events using the non-streaming "message" - # key instead of the standard streaming "delta" key. Fall back gracefully. - delta_data = choice_data.get("delta") or choice_data.get("message") or {} - delta = StreamDelta(content=delta_data.get("content"), role=delta_data.get("role"), tool_calls=delta_data.get("tool_calls")) - choice = StreamChoice(delta=delta, index=choice_data.get("index", 0), finish_reason=choice_data.get("finish_reason")) - choices.append(choice) - - usage = None - if "usage" in data: - usage_data = data["usage"] - usage = StreamUsage( - prompt_tokens=usage_data.get("prompt_tokens", 0), - completion_tokens=usage_data.get("completion_tokens", 0), - total_tokens=usage_data.get("total_tokens", 0), - ) - - is_final = any(c.finish_reason is not None for c in choices) or usage is not None - - return cls( - choices=choices, - model=data.get("model", "unknown"), - usage=usage, - is_final=is_final, - tee_signature=data.get("tee_signature"), - tee_timestamp=data.get("tee_timestamp"), - ) - - -@dataclass -class TextGenerationStream: - """ - Iterator over ``StreamChunk`` objects from a streaming chat response. - - Returned by ``**`opengradient.client.llm`**.LLM.chat`` when - ``stream=True``. Iterate over the stream to receive incremental - chunks as they arrive from the server. - - Each ``StreamChunk`` contains a list of ``StreamChoice`` objects. - Access the incremental text via ``chunk.choices[0].delta.content``. - The final chunk will have ``is_final=True`` and may include - ``usage`` and ``tee_signature`` / ``tee_timestamp`` fields. - - Usage: - stream = client.llm.chat(model=og.TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...], stream=True) - for chunk in stream: - if chunk.choices[0].delta.content: - print(chunk.choices[0].delta.content, end="") - """ - - _iterator: Union[Iterator[str], AsyncIterator[str]] - _is_async: bool = False - - def __iter__(self): - """Iterate over stream chunks.""" - return self - - def __aiter__(self): - """Return async iterator (required for async for loops).""" - if not self._is_async: - raise TypeError("Use __iter__ for sync iterators") - return self - - def __next__(self) -> StreamChunk: - """Get next stream chunk.""" - import json - - while True: - try: - line = next(self._iterator) # type: ignore[arg-type] - except StopIteration: - raise - - if not line or not line.strip(): - continue - - if not line.startswith("data: "): - continue - - data_str = line[6:] # Remove "data: " prefix - - if data_str.strip() == "[DONE]": - raise StopIteration - - try: - data = json.loads(data_str) - return StreamChunk.from_sse_data(data) - except json.JSONDecodeError: - # Skip malformed chunks - continue - - async def __anext__(self) -> StreamChunk: - """Get next stream chunk (async version).""" - import json - - if not self._is_async: - raise TypeError("Use __next__ for sync iterators") - - while True: - try: - line = await self._iterator.__anext__() # type: ignore[union-attr] - except StopAsyncIteration: - raise - - if not line or not line.strip(): - continue - - if not line.startswith("data: "): - continue - - data_str = line[6:] - - if data_str.strip() == "[DONE]": - raise StopAsyncIteration - - try: - data = json.loads(data_str) - return StreamChunk.from_sse_data(data) - except json.JSONDecodeError: - continue - - -@dataclass -class TextGenerationOutput: - """ - Output from a non-streaming ``chat()`` or ``completion()`` call. - - Returned by ``**`opengradient.client.llm`**.LLM.chat`` (when ``stream=False``) - and ``**`opengradient.client.llm`**.LLM.completion``. - - For **chat** requests the response is in ``chat_output``; for - **completion** requests it is in ``completion_output``. Only the - field that matches the request type will be populated. - - Every response includes a ``tee_signature`` and ``tee_timestamp`` - that can be used to cryptographically verify the inference was - performed inside a TEE enclave. - - Attributes: - transaction_hash: Blockchain transaction hash. Set to - ``"external"`` for TEE-routed providers. - finish_reason: Reason the model stopped generating - (e.g. ``"stop"``, ``"tool_call"``, ``"error"``). - Only populated for chat requests. - chat_output: Dictionary with the assistant message returned by - a chat request. Contains ``role``, ``content``, and - optionally ``tool_calls``. - completion_output: Raw text returned by a completion request. - payment_hash: Payment hash for the x402 transaction. - tee_signature: RSA-PSS signature over the response produced - by the TEE enclave. - tee_timestamp: ISO-8601 timestamp from the TEE at signing - time. - """ - - transaction_hash: str - """Blockchain transaction hash. Set to ``"external"`` for TEE-routed providers.""" - - finish_reason: Optional[str] = None - """Reason the model stopped generating (e.g. ``"stop"``, ``"tool_call"``, ``"error"``). Only populated for chat requests.""" - - chat_output: Optional[Dict] = None - """Dictionary with the assistant message returned by a chat request. Contains ``role``, ``content``, and optionally ``tool_calls``.""" - - completion_output: Optional[str] = None - """Raw text returned by a completion request.""" - - payment_hash: Optional[str] = None - """Payment hash for the x402 transaction.""" - - tee_signature: Optional[str] = None - """RSA-PSS signature over the response produced by the TEE enclave.""" - - tee_timestamp: Optional[str] = None - """ISO-8601 timestamp from the TEE at signing time.""" - - tee_id: Optional[str] = None - """On-chain TEE registry ID (keccak256 of the enclave's public key) of the TEE that served this request.""" - - tee_endpoint: Optional[str] = None - """Endpoint URL of the TEE that served this request, as registered on-chain.""" - - tee_payment_address: Optional[str] = None - """Payment address registered for the TEE that served this request.""" - - -@dataclass -class AbiFunction: - name: str - inputs: List[Union[str, "AbiFunction"]] - outputs: List[Union[str, "AbiFunction"]] - state_mutability: str - - -@dataclass -class Abi: - functions: List[AbiFunction] - - @classmethod - def from_json(cls, abi_json): - functions = [] - for item in abi_json: - if item["type"] == "function": - inputs = cls._parse_inputs_outputs(item["inputs"]) - outputs = cls._parse_inputs_outputs(item["outputs"]) - functions.append(AbiFunction(name=item["name"], inputs=inputs, outputs=outputs, state_mutability=item["stateMutability"])) - return cls(functions=functions) - - @staticmethod - def _parse_inputs_outputs(items): - result = [] - for item in items: - if "components" in item: - result.append( - AbiFunction(name=item["name"], inputs=Abi._parse_inputs_outputs(item["components"]), outputs=[], state_mutability="") - ) - else: - result.append(f"{item['name']}:{item['type']}") - return result - - -class TEE_LLM(str, Enum): - """ - Enum for LLM models available for TEE (Trusted Execution Environment) execution. - - TEE mode provides cryptographic verification that inference was performed - correctly in a secure enclave. Use this for applications requiring - auditability and tamper-proof AI inference. - - Usage: - # TEE-verified inference - result = client.llm.chat( - model=og.TEE_LLM.GPT_5, - messages=[{"role": "user", "content": "Hello"}], - ) - """ - - # OpenAI models via TEE - GPT_4_1_2025_04_14 = "openai/gpt-4.1-2025-04-14" - O4_MINI = "openai/o4-mini" - GPT_5 = "openai/gpt-5" - GPT_5_MINI = "openai/gpt-5-mini" - GPT_5_2 = "openai/gpt-5.2" - - # Anthropic models via TEE - CLAUDE_SONNET_4_5 = "anthropic/claude-sonnet-4-5" - CLAUDE_SONNET_4_6 = "anthropic/claude-sonnet-4-6" - CLAUDE_HAIKU_4_5 = "anthropic/claude-haiku-4-5" - CLAUDE_OPUS_4_5 = "anthropic/claude-opus-4-5" - CLAUDE_OPUS_4_6 = "anthropic/claude-opus-4-6" - - # Google models via TEE - GEMINI_2_5_FLASH = "google/gemini-2.5-flash" - GEMINI_2_5_PRO = "google/gemini-2.5-pro" - GEMINI_2_5_FLASH_LITE = "google/gemini-2.5-flash-lite" - GEMINI_3_PRO = "google/gemini-3-pro-preview" - GEMINI_3_FLASH = "google/gemini-3-flash-preview" - - # xAI Grok models via TEE - GROK_4 = "x-ai/grok-4" - GROK_4_FAST = "x-ai/grok-4-fast" - GROK_4_1_FAST = "x-ai/grok-4-1-fast" - GROK_4_1_FAST_NON_REASONING = "x-ai/grok-4-1-fast-non-reasoning" - - -@dataclass -class SchedulerParams: - frequency: int - duration_hours: int - - @property - def end_time(self) -> int: - return int(time.time()) + (self.duration_hours * 60 * 60) - - @staticmethod - def from_dict(data: Optional[Dict[str, int]]) -> Optional["SchedulerParams"]: - if data is None: - return None - return SchedulerParams(frequency=data.get("frequency", 600), duration_hours=data.get("duration_hours", 2)) - - -@dataclass -class ModelRepository: - name: str - initialVersion: str - - -@dataclass -class FileUploadResult: - modelCid: str - size: int +""" +OpenGradient Specific Types +""" + +import logging +import time +from dataclasses import dataclass +from enum import Enum, IntEnum +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np + +logger = logging.getLogger(__name__) + + +class x402SettlementMode(str, Enum): + """ + Settlement modes for x402 payment protocol transactions. + + These modes control how inference data is recorded on-chain for payment settlement + and auditability. Each mode offers different trade-offs between data completeness, + privacy, and transaction costs. + + Attributes: + PRIVATE: Payment-only settlement. + Only the payment is settled on-chain — no input or output hashes are posted. + Your inference data remains completely off-chain, ensuring maximum privacy. + Suitable when payment settlement is required without any on-chain record of execution. + CLI usage: --settlement-mode private + + BATCH_HASHED: Batch settlement with hashes (default). + Aggregates multiple inferences into a single settlement transaction + using a Merkle tree containing input hashes, output hashes, and signatures. + Most cost-efficient for high-volume applications. + CLI usage: --settlement-mode batch-hashed + + INDIVIDUAL_FULL: Individual settlement with full metadata. + Records input data, output data, timestamp, and verification on-chain. + Provides maximum transparency and auditability. + Higher gas costs due to larger data storage. + CLI usage: --settlement-mode individual-full + + Examples: + >>> from opengradient import x402SettlementMode + >>> mode = x402SettlementMode.PRIVATE + >>> print(mode.value) + 'private' + """ + + PRIVATE = "private" + BATCH_HASHED = "batch" + INDIVIDUAL_FULL = "individual" + + +class CandleOrder(IntEnum): + ASCENDING = 0 + DESCENDING = 1 + + +class CandleType(IntEnum): + HIGH = 0 + LOW = 1 + OPEN = 2 + CLOSE = 3 + VOLUME = 4 + + +@dataclass +class HistoricalInputQuery: + base: str + quote: str + total_candles: int + candle_duration_in_mins: int + order: CandleOrder + candle_types: List[CandleType] + + def to_abi_format(self) -> tuple: + """Convert to format expected by contract ABI""" + return ( + self.base, + self.quote, + self.total_candles, + self.candle_duration_in_mins, + int(self.order), + [int(ct) for ct in self.candle_types], + ) + + +@dataclass +class Number: + value: int + decimals: int + + +@dataclass +class NumberTensor: + """ + A container for numeric tensor data used as input for ONNX models. + + Attributes: + + name: Identifier for this tensor in the model. + + values: List of integer tuples representing the tensor data. + """ + + name: str + values: List[Tuple[int, int]] + + +@dataclass +class StringTensor: + """ + A container for string tensor data used as input for ONNX models. + + Attributes: + + name: Identifier for this tensor in the model. + + values: List of strings representing the tensor data. + """ + + name: str + values: List[str] + + +@dataclass +class ModelInput: + """ + A collection of tensor inputs required for ONNX model inference. + + Attributes: + + numbers: Collection of numeric tensors for the model. + + strings: Collection of string tensors for the model. + """ + + numbers: List[NumberTensor] + strings: List[StringTensor] + + +class InferenceMode(Enum): + """Enum for the different inference modes available for inference (VANILLA, ZKML, TEE)""" + + VANILLA = 0 + ZKML = 1 + TEE = 2 + + +@dataclass +class ModelOutput: + """ + Model output struct based on translations from smart contract. + """ + + numbers: Dict[str, np.ndarray] + strings: Dict[str, np.ndarray] + jsons: Dict[str, np.ndarray] # Converts to JSON dictionary + is_simulation_result: bool + + +@dataclass +class InferenceResult: + """ + Output for ML inference requests. + This class has two fields + transaction_hash (str): Blockchain hash for the transaction + model_output (Dict[str, np.ndarray]): Output of the ONNX model + """ + + transaction_hash: str + model_output: Dict[str, np.ndarray] + + +@dataclass +class StreamDelta: + """ + Represents a delta (incremental change) in a streaming response. + + Attributes: + content: Incremental text content (if any) + role: Message role (appears in first chunk) + tool_calls: Tool call information (if function calling is used) + """ + + content: Optional[str] = None + role: Optional[str] = None + tool_calls: Optional[List[Dict]] = None + + +@dataclass +class StreamChoice: + """ + Represents a choice in a streaming response. + + Attributes: + delta: The incremental changes in this chunk + index: Choice index (usually 0) + finish_reason: Reason for completion (appears in final chunk) + """ + + delta: StreamDelta + index: int = 0 + finish_reason: Optional[str] = None + + +@dataclass +class StreamUsage: + """ + Token usage information for a streaming response. + + Attributes: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total tokens used + """ + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +@dataclass +class StreamChunk: + """ + Represents a single chunk in a streaming LLM response. + + This follows the OpenAI streaming format but is provider-agnostic. + Each chunk contains incremental data, with the final chunk including + usage information. + + Attributes: + choices: List of streaming choices (usually contains one choice) + model: Model identifier + usage: Token usage information (only in final chunk) + is_final: Whether this is the final chunk (before [DONE]) + tee_signature: RSA-PSS signature over the response, present on the final chunk + tee_timestamp: ISO timestamp from the TEE at signing time, present on the final chunk + tee_id: On-chain TEE registry ID of the enclave that served this request (final chunk only) + tee_endpoint: Endpoint URL of the TEE that served this request (final chunk only) + tee_payment_address: Payment address registered for the TEE (final chunk only) + """ + + choices: List[StreamChoice] + model: str + usage: Optional[StreamUsage] = None + is_final: bool = False + tee_signature: Optional[str] = None + tee_timestamp: Optional[str] = None + tee_id: Optional[str] = None + tee_endpoint: Optional[str] = None + tee_payment_address: Optional[str] = None + + @classmethod + def from_sse_data(cls, data: Dict) -> "StreamChunk": + """ + Parse a StreamChunk from SSE data dictionary. + + Args: + data: Dictionary parsed from SSE data line + + Returns: + StreamChunk instance + """ + choices = [] + for choice_data in data.get("choices", []): + # The TEE proxy sometimes sends SSE events using the non-streaming "message" + # key instead of the standard streaming "delta" key. Fall back gracefully. + delta_data = choice_data.get("delta") or choice_data.get("message") or {} + delta = StreamDelta(content=delta_data.get("content"), role=delta_data.get("role"), tool_calls=delta_data.get("tool_calls")) + choice = StreamChoice(delta=delta, index=choice_data.get("index", 0), finish_reason=choice_data.get("finish_reason")) + choices.append(choice) + + usage = None + if "usage" in data: + usage_data = data["usage"] + usage = StreamUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + ) + + is_final = any(c.finish_reason is not None for c in choices) or usage is not None + + return cls( + choices=choices, + model=data.get("model", "unknown"), + usage=usage, + is_final=is_final, + tee_signature=data.get("tee_signature"), + tee_timestamp=data.get("tee_timestamp"), + ) + + +@dataclass +class TextGenerationStream: + """ + Iterator over ``StreamChunk`` objects from a streaming chat response. + + Returned by ``**`opengradient.client.llm`**.LLM.chat`` when + ``stream=True``. Iterate over the stream to receive incremental + chunks as they arrive from the server. + + Each ``StreamChunk`` contains a list of ``StreamChoice`` objects. + Access the incremental text via ``chunk.choices[0].delta.content``. + The final chunk will have ``is_final=True`` and may include + ``usage`` and ``tee_signature`` / ``tee_timestamp`` fields. + + Usage: + stream = client.llm.chat(model=og.TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...], stream=True) + for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") + """ + + _iterator: Union[Iterator[str], AsyncIterator[str]] + _is_async: bool = False + + def __iter__(self): + """Iterate over stream chunks.""" + return self + + def __aiter__(self): + """Return async iterator (required for async for loops).""" + if not self._is_async: + raise TypeError("Use __iter__ for sync iterators") + return self + + def __next__(self) -> StreamChunk: + """Get next stream chunk.""" + import json + + while True: + try: + line = next(self._iterator) # type: ignore[arg-type] + except StopIteration: + raise + + if not line or not line.strip(): + continue + + if not line.startswith("data: "): + continue + + data_str = line[6:] # Remove "data: " prefix + + if data_str.strip() == "[DONE]": + raise StopIteration + + try: + data = json.loads(data_str) + return StreamChunk.from_sse_data(data) + except json.JSONDecodeError: + logger.warning("Skipping malformed SSE JSON: %r", data_str) + continue + + async def __anext__(self) -> StreamChunk: + """Get next stream chunk (async version).""" + import json + + if not self._is_async: + raise TypeError("Use __next__ for sync iterators") + + while True: + try: + line = await self._iterator.__anext__() # type: ignore[union-attr] + except StopAsyncIteration: + raise + + if not line or not line.strip(): + continue + + if not line.startswith("data: "): + continue + + data_str = line[6:] + + if data_str.strip() == "[DONE]": + raise StopAsyncIteration + + try: + data = json.loads(data_str) + return StreamChunk.from_sse_data(data) + except json.JSONDecodeError: + logger.warning("Skipping malformed SSE JSON: %r", data_str) + continue + + +@dataclass +class TextGenerationOutput: + """ + Output from a non-streaming ``chat()`` or ``completion()`` call. + + Returned by ``**`opengradient.client.llm`**.LLM.chat`` (when ``stream=False``) + and ``**`opengradient.client.llm`**.LLM.completion``. + + For **chat** requests the response is in ``chat_output``; for + **completion** requests it is in ``completion_output``. Only the + field that matches the request type will be populated. + + Every response includes a ``tee_signature`` and ``tee_timestamp`` + that can be used to cryptographically verify the inference was + performed inside a TEE enclave. + + Attributes: + transaction_hash: Blockchain transaction hash. Set to + ``"external"`` for TEE-routed providers. + finish_reason: Reason the model stopped generating + (e.g. ``"stop"``, ``"tool_call"``, ``"error"``). + Only populated for chat requests. + chat_output: Dictionary with the assistant message returned by + a chat request. Contains ``role``, ``content``, and + optionally ``tool_calls``. + completion_output: Raw text returned by a completion request. + payment_hash: Payment hash for the x402 transaction. + tee_signature: RSA-PSS signature over the response produced + by the TEE enclave. + tee_timestamp: ISO-8601 timestamp from the TEE at signing + time. + """ + + transaction_hash: str + """Blockchain transaction hash. Set to ``"external"`` for TEE-routed providers.""" + + finish_reason: Optional[str] = None + """Reason the model stopped generating (e.g. ``"stop"``, ``"tool_call"``, ``"error"``). Only populated for chat requests.""" + + chat_output: Optional[Dict] = None + """Dictionary with the assistant message returned by a chat request. Contains ``role``, ``content``, and optionally ``tool_calls``.""" + + completion_output: Optional[str] = None + """Raw text returned by a completion request.""" + + payment_hash: Optional[str] = None + """Payment hash for the x402 transaction.""" + + tee_signature: Optional[str] = None + """RSA-PSS signature over the response produced by the TEE enclave.""" + + tee_timestamp: Optional[str] = None + """ISO-8601 timestamp from the TEE at signing time.""" + + tee_id: Optional[str] = None + """On-chain TEE registry ID (keccak256 of the enclave's public key) of the TEE that served this request.""" + + tee_endpoint: Optional[str] = None + """Endpoint URL of the TEE that served this request, as registered on-chain.""" + + tee_payment_address: Optional[str] = None + """Payment address registered for the TEE that served this request.""" + + +@dataclass +class AbiFunction: + name: str + inputs: List[Union[str, "AbiFunction"]] + outputs: List[Union[str, "AbiFunction"]] + state_mutability: str + + +@dataclass +class Abi: + functions: List[AbiFunction] + + @classmethod + def from_json(cls, abi_json): + functions = [] + for item in abi_json: + if item["type"] == "function": + inputs = cls._parse_inputs_outputs(item["inputs"]) + outputs = cls._parse_inputs_outputs(item["outputs"]) + functions.append(AbiFunction(name=item["name"], inputs=inputs, outputs=outputs, state_mutability=item["stateMutability"])) + return cls(functions=functions) + + @staticmethod + def _parse_inputs_outputs(items): + result = [] + for item in items: + if "components" in item: + result.append( + AbiFunction(name=item["name"], inputs=Abi._parse_inputs_outputs(item["components"]), outputs=[], state_mutability="") + ) + else: + result.append(f"{item['name']}:{item['type']}") + return result + + +class TEE_LLM(str, Enum): + """ + Enum for LLM models available for TEE (Trusted Execution Environment) execution. + + TEE mode provides cryptographic verification that inference was performed + correctly in a secure enclave. Use this for applications requiring + auditability and tamper-proof AI inference. + + Usage: + # TEE-verified inference + result = client.llm.chat( + model=og.TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hello"}], + ) + """ + + # OpenAI models via TEE + GPT_4_1_2025_04_14 = "openai/gpt-4.1-2025-04-14" + O4_MINI = "openai/o4-mini" + GPT_5 = "openai/gpt-5" + GPT_5_MINI = "openai/gpt-5-mini" + GPT_5_2 = "openai/gpt-5.2" + + # Anthropic models via TEE + CLAUDE_SONNET_4_5 = "anthropic/claude-sonnet-4-5" + CLAUDE_SONNET_4_6 = "anthropic/claude-sonnet-4-6" + CLAUDE_HAIKU_4_5 = "anthropic/claude-haiku-4-5" + CLAUDE_OPUS_4_5 = "anthropic/claude-opus-4-5" + CLAUDE_OPUS_4_6 = "anthropic/claude-opus-4-6" + + # Google models via TEE + GEMINI_2_5_FLASH = "google/gemini-2.5-flash" + GEMINI_2_5_PRO = "google/gemini-2.5-pro" + GEMINI_2_5_FLASH_LITE = "google/gemini-2.5-flash-lite" + GEMINI_3_PRO = "google/gemini-3-pro-preview" + GEMINI_3_FLASH = "google/gemini-3-flash-preview" + + # xAI Grok models via TEE + GROK_4 = "x-ai/grok-4" + GROK_4_FAST = "x-ai/grok-4-fast" + GROK_4_1_FAST = "x-ai/grok-4-1-fast" + GROK_4_1_FAST_NON_REASONING = "x-ai/grok-4-1-fast-non-reasoning" + + +@dataclass +class SchedulerParams: + frequency: int + duration_hours: int + + @property + def end_time(self) -> int: + return int(time.time()) + (self.duration_hours * 60 * 60) + + @staticmethod + def from_dict(data: Optional[Dict[str, int]]) -> Optional["SchedulerParams"]: + if data is None: + return None + return SchedulerParams(frequency=data.get("frequency", 600), duration_hours=data.get("duration_hours", 2)) + + +@dataclass +class ModelRepository: + name: str + initialVersion: str + + +@dataclass +class FileUploadResult: + modelCid: str + size: int diff --git a/tests/client_test.py b/tests/client_test.py index 6829fc9..0f1e66b 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -1,192 +1,349 @@ -import json -from unittest.mock import MagicMock, mock_open, patch - -import pytest - -from opengradient.client.llm import LLM -from opengradient.client.model_hub import ModelHub -from opengradient.types import ( - StreamChunk, - x402SettlementMode, -) - -FAKE_PRIVATE_KEY = "0x" + "a" * 64 - -# --- Fixtures --- - - -@pytest.fixture -def mock_tee_registry(): - """Mock the TEE registry so LLM.__init__ doesn't need a live registry.""" - with ( - patch("opengradient.client.llm.TEERegistry") as mock_tee_registry, - patch( - "opengradient.client.tee_connection.build_ssl_context_from_der", - return_value=MagicMock(), - ), - ): - mock_tee = MagicMock() - mock_tee.endpoint = "https://test.tee.server" - mock_tee.tls_cert_der = b"fake-der" - mock_tee.tee_id = "test-tee-id" - mock_tee.payment_address = "0xTestPaymentAddress" - mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee - yield mock_tee_registry - - -@pytest.fixture -def mock_web3(): - """Create a mock Web3 instance for Alpha.""" - with patch("opengradient.client.alpha.Web3") as mock: - mock_instance = MagicMock() - mock.return_value = mock_instance - mock.HTTPProvider.return_value = MagicMock() - - mock_instance.eth.account.from_key.return_value = MagicMock(address="0x1234567890abcdef1234567890abcdef12345678") - mock_instance.eth.get_transaction_count.return_value = 0 - mock_instance.eth.gas_price = 1000000000 - mock_instance.eth.contract.return_value = MagicMock() - - yield mock_instance - - -@pytest.fixture -def mock_abi_files(): - """Mock ABI file reads.""" - inference_abi = [{"type": "function", "name": "run", "inputs": [], "outputs": []}] - precompile_abi = [{"type": "function", "name": "infer", "inputs": [], "outputs": []}] - - def mock_file_open(path, *args, **kwargs): - if "inference.abi" in str(path): - return mock_open(read_data=json.dumps(inference_abi))() - elif "InferencePrecompile.abi" in str(path): - return mock_open(read_data=json.dumps(precompile_abi))() - return mock_open(read_data="{}")() - - with patch("builtins.open", side_effect=mock_file_open): - yield - - -# --- LLM Initialization Tests --- - - -class TestLLMInitialization: - def test_llm_initialization(self, mock_tee_registry): - """Test basic LLM initialization.""" - llm = LLM(private_key=FAKE_PRIVATE_KEY) - assert llm._tee.get().endpoint == "https://test.tee.server" - - def test_llm_initialization_custom_url(self, mock_tee_registry): - """Test LLM initialization with custom server URL.""" - custom_llm_url = "https://custom.llm.server" - llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) - assert llm._tee.get().endpoint == custom_llm_url - - -# --- ModelHub Authentication Tests --- - - -class TestAuthentication: - def test_login_to_hub_success(self): - """Test successful login to hub.""" - with ( - patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}), - patch("opengradient.client.model_hub.firebase") as mock_firebase, - ): - mock_auth = MagicMock() - mock_auth.sign_in_with_email_and_password.return_value = { - "idToken": "success_token", - "email": "user@test.com", - } - mock_firebase.initialize_app.return_value.auth.return_value = mock_auth - - hub = ModelHub(email="user@test.com", password="password123") - - mock_auth.sign_in_with_email_and_password.assert_called_once_with("user@test.com", "password123") - assert hub._hub_user["idToken"] == "success_token" - - def test_login_to_hub_failure(self): - """Test login failure raises exception.""" - with ( - patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}), - patch("opengradient.client.model_hub.firebase") as mock_firebase, - ): - mock_auth = MagicMock() - mock_auth.sign_in_with_email_and_password.side_effect = Exception("Invalid credentials") - mock_firebase.initialize_app.return_value.auth.return_value = mock_auth - - with pytest.raises(Exception, match="Invalid credentials"): - ModelHub(email="user@test.com", password="wrong_password") - - -# --- StreamChunk Tests --- - - -class TestStreamChunk: - def test_from_sse_data_basic(self): - """Test parsing basic SSE data.""" - data = { - "model": "gpt-4o", - "choices": [ - { - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": None, - } - ], - } - - chunk = StreamChunk.from_sse_data(data) - - assert chunk.model == "gpt-4o" - assert len(chunk.choices) == 1 - assert chunk.choices[0].delta.content == "Hello" - assert not chunk.is_final - - def test_from_sse_data_with_finish_reason(self): - """Test parsing SSE data with finish reason.""" - data = { - "model": "gpt-4o", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop", - } - ], - } - - chunk = StreamChunk.from_sse_data(data) - - assert chunk.is_final - assert chunk.choices[0].finish_reason == "stop" - - def test_from_sse_data_with_usage(self): - """Test parsing SSE data with usage info.""" - data = { - "model": "gpt-4o", - "choices": [], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - }, - } - - chunk = StreamChunk.from_sse_data(data) - - assert chunk.usage is not None - assert chunk.usage.prompt_tokens == 10 - assert chunk.usage.total_tokens == 30 - assert chunk.is_final - - -# --- x402 Settlement Mode Tests --- - - -class TestX402SettlementMode: - def test_settlement_modes_values(self): - """Test settlement mode enum values.""" - assert x402SettlementMode.PRIVATE == "private" - assert x402SettlementMode.BATCH_HASHED == "batch" - assert x402SettlementMode.INDIVIDUAL_FULL == "individual" +import json +import logging +from unittest.mock import AsyncMock, MagicMock, mock_open, patch + +import pytest + +from opengradient.client.alpha import Alpha +from opengradient.client.llm import LLM +from opengradient.client.model_hub import ModelHub +from opengradient.types import ( + StreamChunk, + TextGenerationStream, + x402SettlementMode, +) + +FAKE_PRIVATE_KEY = "0x" + "a" * 64 + +# --- Fixtures --- + + +@pytest.fixture +def mock_tee_registry(): + """Mock the TEE registry so LLM.__init__ doesn't need a live registry.""" + with ( + patch("opengradient.client.llm.TEERegistry") as mock_tee_registry, + patch( + "opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(), + ), + ): + mock_tee = MagicMock() + mock_tee.endpoint = "https://test.tee.server" + mock_tee.tls_cert_der = b"fake-der" + mock_tee.tee_id = "test-tee-id" + mock_tee.payment_address = "0xTestPaymentAddress" + mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee + yield mock_tee_registry + + +@pytest.fixture +def mock_web3(): + """Create a mock Web3 instance for Alpha.""" + with patch("opengradient.client.alpha.Web3") as mock: + mock_instance = MagicMock() + mock.return_value = mock_instance + mock.HTTPProvider.return_value = MagicMock() + + mock_instance.eth.account.from_key.return_value = MagicMock(address="0x1234567890abcdef1234567890abcdef12345678") + mock_instance.eth.get_transaction_count.return_value = 0 + mock_instance.eth.gas_price = 1000000000 + mock_instance.eth.contract.return_value = MagicMock() + + yield mock_instance + + +@pytest.fixture +def mock_abi_files(): + """Mock ABI file reads.""" + inference_abi = [{"type": "function", "name": "run", "inputs": [], "outputs": []}] + precompile_abi = [{"type": "function", "name": "infer", "inputs": [], "outputs": []}] + + def mock_file_open(path, *args, **kwargs): + if "inference.abi" in str(path): + return mock_open(read_data=json.dumps(inference_abi))() + elif "InferencePrecompile.abi" in str(path): + return mock_open(read_data=json.dumps(precompile_abi))() + return mock_open(read_data="{}")() + + with patch("builtins.open", side_effect=mock_file_open): + yield + + +# --- LLM Initialization Tests --- + + +class TestLLMInitialization: + def test_llm_initialization(self, mock_tee_registry): + """Test basic LLM initialization.""" + llm = LLM(private_key=FAKE_PRIVATE_KEY) + assert llm._tee.get().endpoint == "https://test.tee.server" + + def test_llm_initialization_custom_url(self, mock_tee_registry): + """Test LLM initialization with custom server URL.""" + custom_llm_url = "https://custom.llm.server" + llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) + assert llm._tee.get().endpoint == custom_llm_url + + +# --- ModelHub Authentication Tests --- + + +class TestAuthentication: + def test_login_to_hub_success(self): + """Test successful login to hub.""" + with ( + patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}), + patch("opengradient.client.model_hub.firebase") as mock_firebase, + ): + mock_auth = MagicMock() + mock_auth.sign_in_with_email_and_password.return_value = { + "idToken": "success_token", + "email": "user@test.com", + } + mock_firebase.initialize_app.return_value.auth.return_value = mock_auth + + hub = ModelHub(email="user@test.com", password="password123") + + mock_auth.sign_in_with_email_and_password.assert_called_once_with("user@test.com", "password123") + assert hub._hub_user["idToken"] == "success_token" + + def test_login_to_hub_failure(self): + """Test login failure raises exception.""" + with ( + patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}), + patch("opengradient.client.model_hub.firebase") as mock_firebase, + ): + mock_auth = MagicMock() + mock_auth.sign_in_with_email_and_password.side_effect = Exception("Invalid credentials") + mock_firebase.initialize_app.return_value.auth.return_value = mock_auth + + with pytest.raises(Exception, match="Invalid credentials"): + ModelHub(email="user@test.com", password="wrong_password") + + +# --- StreamChunk Tests --- + + +class TestStreamChunk: + def test_from_sse_data_basic(self): + """Test parsing basic SSE data.""" + data = { + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": None, + } + ], + } + + chunk = StreamChunk.from_sse_data(data) + + assert chunk.model == "gpt-4o" + assert len(chunk.choices) == 1 + assert chunk.choices[0].delta.content == "Hello" + assert not chunk.is_final + + def test_from_sse_data_with_finish_reason(self): + """Test parsing SSE data with finish reason.""" + data = { + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + + chunk = StreamChunk.from_sse_data(data) + + assert chunk.is_final + assert chunk.choices[0].finish_reason == "stop" + + def test_from_sse_data_with_usage(self): + """Test parsing SSE data with usage info.""" + data = { + "model": "gpt-4o", + "choices": [], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + } + + chunk = StreamChunk.from_sse_data(data) + + assert chunk.usage is not None + assert chunk.usage.prompt_tokens == 10 + assert chunk.usage.total_tokens == 30 + assert chunk.is_final + + +# --- x402 Settlement Mode Tests --- + + +class TestX402SettlementMode: + def test_settlement_modes_values(self): + """Test settlement mode enum values.""" + assert x402SettlementMode.PRIVATE == "private" + assert x402SettlementMode.BATCH_HASHED == "batch" + assert x402SettlementMode.INDIVIDUAL_FULL == "individual" + + +# --- Fix #4: choices[] guard --- + + +class TestChoicesGuard: + """_chat_request must raise RuntimeError for any malformed choices value.""" + + @pytest.mark.asyncio + async def test_choices_empty_list_raises(self, mock_tee_registry): + """choices = [] → RuntimeError.""" + llm = LLM(private_key="0x" + "a" * 64) + mock_response = AsyncMock() + mock_response.raise_for_status = MagicMock() + mock_response.aread = AsyncMock(return_value=json.dumps({"choices": []}).encode()) + + mock_tee = MagicMock() + mock_tee.http_client.post = AsyncMock(return_value=mock_response) + mock_tee.metadata.return_value = {"tee_id": None, "tee_endpoint": None, "tee_payment_address": None} + llm._tee.get = MagicMock(return_value=mock_tee) + llm._tee.ensure_refresh_loop = MagicMock() + + from opengradient.types import TEE_LLM + with pytest.raises(RuntimeError, match="choices"): + await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[{"role": "user", "content": "hi"}]) + + @pytest.mark.asyncio + async def test_choices_contains_none_raises(self, mock_tee_registry): + """choices = [None] passes the old 'not choices' guard but must now raise RuntimeError.""" + llm = LLM(private_key="0x" + "a" * 64) + mock_response = AsyncMock() + mock_response.raise_for_status = MagicMock() + mock_response.aread = AsyncMock(return_value=json.dumps({"choices": [None]}).encode()) + + mock_tee = MagicMock() + mock_tee.http_client.post = AsyncMock(return_value=mock_response) + mock_tee.metadata.return_value = {"tee_id": None, "tee_endpoint": None, "tee_payment_address": None} + llm._tee.get = MagicMock(return_value=mock_tee) + llm._tee.ensure_refresh_loop = MagicMock() + + from opengradient.types import TEE_LLM + with pytest.raises(RuntimeError, match="choices"): + await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[{"role": "user", "content": "hi"}]) + + +# --- Fix #6: SSE JSON logging --- + + +class TestSSEJsonLogging: + """Malformed SSE JSON must emit a warning log, not silently disappear.""" + + def test_malformed_sse_logs_warning_in_sync_stream(self): + """TextGenerationStream.__next__ logs a warning for broken JSON.""" + lines = iter(["data: { broken json \n", "data: [DONE]\n"]) + stream = TextGenerationStream(_iterator=lines, _is_async=False) + + with patch("opengradient.types.logger") as mock_logger: + try: + next(stream) + except StopIteration: + pass + mock_logger.warning.assert_called_once() + assert "Skipping malformed SSE JSON" in mock_logger.warning.call_args[0][0] + + def test_valid_sse_does_not_log_warning(self): + """Well-formed SSE chunks must not trigger any warning.""" + valid_data = json.dumps({ + "model": "test", + "choices": [{"index": 0, "delta": {"content": "hi"}, "finish_reason": None}], + }) + lines = iter([f"data: {valid_data}\n", "data: [DONE]\n"]) + stream = TextGenerationStream(_iterator=lines, _is_async=False) + + with patch("opengradient.types.logger") as mock_logger: + chunk = next(stream) + mock_logger.warning.assert_not_called() + assert chunk.choices[0].delta.content == "hi" + + +# --- Fix #13 & #14: Alpha exception type and logging --- + + +class TestAlphaErrorHandling: + """Alpha.new_workflow raises RuntimeError (not bare Exception) on deployment failure, + and uses logger instead of print() for non-fatal warnings.""" + + @pytest.fixture + def alpha(self): + with patch("opengradient.client.alpha.Web3") as mock_web3_cls: + mock_w3 = MagicMock() + mock_web3_cls.return_value = mock_w3 + mock_web3_cls.HTTPProvider.return_value = MagicMock() + mock_web3_cls.to_checksum_address.side_effect = lambda x: x + mock_w3.eth.account.from_key.return_value = MagicMock(address="0xDEAD") + mock_w3.eth.get_transaction_count.return_value = 1 + mock_w3.eth.gas_price = 1000000000 + mock_w3.eth.chain_id = 1 + yield Alpha(private_key="0x" + "a" * 64) + + def test_deployment_failure_raises_runtime_error(self, alpha): + """Status=0 receipt must raise RuntimeError, not bare Exception.""" + mock_contract = MagicMock() + mock_contract.constructor.return_value.estimate_gas.return_value = 100000 + mock_contract.constructor.return_value.build_transaction.return_value = {} + + fake_receipt = {"status": 0, "transactionHash": b"\xde\xad"} + alpha._blockchain.eth.contract.return_value = mock_contract + alpha._blockchain.eth.send_raw_transaction.return_value = b"\xde\xad" + alpha._blockchain.eth.wait_for_transaction_receipt.return_value = fake_receipt + + from opengradient.types import HistoricalInputQuery, CandleOrder, CandleType + query = HistoricalInputQuery("BTC", "USDT", 10, 60, CandleOrder.DESCENDING, [CandleType.CLOSE]) + + with patch("opengradient.client.alpha.get_abi", return_value=[]): + with patch("opengradient.client.alpha.get_bin", return_value="0x"): + with patch("opengradient.client.alpha.run_with_retry", side_effect=lambda fn, *a, **kw: fn()): + with pytest.raises(RuntimeError, match="Contract deployment failed"): + alpha.new_workflow("Qm123", query, "input") + + def test_gas_estimation_failure_logs_warning(self, alpha): + """Gas estimation failure must call logger.warning, not print().""" + mock_contract = MagicMock() + mock_contract.constructor.return_value.estimate_gas.side_effect = Exception("gas error") + mock_contract.constructor.return_value.build_transaction.return_value = {} + + fake_receipt = MagicMock() + fake_receipt.__getitem__ = lambda self, key: 1 if key == "status" else None + fake_receipt.contractAddress = "0xNEW" + alpha._blockchain.eth.contract.return_value = mock_contract + alpha._blockchain.eth.send_raw_transaction.return_value = b"\xca\xfe" + alpha._blockchain.eth.wait_for_transaction_receipt.return_value = fake_receipt + + from opengradient.types import HistoricalInputQuery, CandleOrder, CandleType + query = HistoricalInputQuery("BTC", "USDT", 10, 60, CandleOrder.DESCENDING, [CandleType.CLOSE]) + + with patch("opengradient.client.alpha.get_abi", return_value=[]): + with patch("opengradient.client.alpha.get_bin", return_value="0x"): + with patch("opengradient.client.alpha.run_with_retry", side_effect=lambda fn, *a, **kw: fn()): + with patch("opengradient.client.alpha.logger") as mock_logger: + alpha.new_workflow("Qm123", query, "input") + mock_logger.warning.assert_called_once() + assert "Gas estimation failed" in mock_logger.warning.call_args[0][0] + + def test_scheduler_failure_logs_warning(self, alpha): + """Scheduler registration failure must call logger.warning, not print().""" + alpha._blockchain.eth.contract.return_value.functions.registerTask.return_value.build_transaction.side_effect = Exception("scheduler error") + alpha._blockchain.eth.get_transaction_count.return_value = 1 + + from opengradient.types import SchedulerParams + with patch("opengradient.client.alpha.get_abi", return_value=[]): + with patch("opengradient.client.alpha.logger") as mock_logger: + alpha._register_with_scheduler("0xCONTRACT", SchedulerParams(frequency=60, duration_hours=1)) + mock_logger.warning.assert_called_once() + assert "scheduler" in mock_logger.warning.call_args[0][0].lower()