From 0238a1fedb9651bbb80b3f7699641c32eb3cf0f9 Mon Sep 17 00:00:00 2001 From: crazywriter1 Date: Wed, 1 Apr 2026 21:51:15 +0300 Subject: [PATCH] fix: add key validation and preserve exception chains in workflow models _extract_number() helper validates key existence before access, preventing bare KeyError. read_workflow_wrapper() now catches KeyError separately and uses 'from e' to preserve the original exception chain. Signed-off-by: crazywriter1 --- src/opengradient/workflow_models/utils.py | 15 +++- .../workflow_models/workflow_models.py | 35 ++++++++-- tests/workflow_models_test.py | 70 +++++++++++++++++++ 3 files changed, 112 insertions(+), 8 deletions(-) create mode 100644 tests/workflow_models_test.py diff --git a/src/opengradient/workflow_models/utils.py b/src/opengradient/workflow_models/utils.py index b6db0333..093a37ad 100644 --- a/src/opengradient/workflow_models/utils.py +++ b/src/opengradient/workflow_models/utils.py @@ -1,5 +1,6 @@ """Utility functions for the models module.""" +import logging from typing import Callable from opengradient.client.alpha import Alpha @@ -7,6 +8,8 @@ from .constants import BLOCK_EXPLORER_URL from .types import WorkflowModelOutput +logger = logging.getLogger(__name__) + def create_block_explorer_link_smart_contract(transaction_hash: str) -> str: """Create block explorer link for smart contract.""" @@ -23,10 +26,15 @@ def create_block_explorer_link_transaction(transaction_hash: str) -> str: def read_workflow_wrapper(alpha: Alpha, contract_address: str, format_function: Callable[..., str]) -> WorkflowModelOutput: """ Wrapper function for reading from models through workflows. + Args: alpha (Alpha): The alpha namespace from an initialized OpenGradient client (client.alpha). contract_address (str): Smart contract address of the workflow format_function (Callable): Function for formatting the result returned by read_workflow + + Raises: + KeyError: If the workflow result is missing an expected output key. + RuntimeError: If reading or formatting the workflow result fails. """ try: result = alpha.read_workflow_result(contract_address) @@ -38,5 +46,10 @@ def read_workflow_wrapper(alpha: Alpha, contract_address: str, format_function: result=formatted_result, block_explorer_link=block_explorer_link, ) + except KeyError as e: + raise KeyError( + f"Workflow at {contract_address} is missing expected output key {e}. " + f"Available keys: {list(result.numbers.keys()) if hasattr(result, 'numbers') else 'unknown'}" + ) from e except Exception as e: - raise RuntimeError(f"Error reading from workflow with address {contract_address}: {e!s}") + raise RuntimeError(f"Error reading from workflow with address {contract_address}: {e!s}") from e diff --git a/src/opengradient/workflow_models/workflow_models.py b/src/opengradient/workflow_models/workflow_models.py index cb8b3a3c..614aaa2b 100644 --- a/src/opengradient/workflow_models/workflow_models.py +++ b/src/opengradient/workflow_models/workflow_models.py @@ -15,6 +15,27 @@ from .utils import read_workflow_wrapper +def _extract_number(result, key: str) -> float: + """Extract a numeric value from a workflow result by key. + + Args: + result: The ModelOutput returned by read_workflow_result. + key: The expected key in result.numbers. + + Returns: + The extracted float value. + + Raises: + KeyError: If the key is not present in result.numbers. + """ + if not hasattr(result, "numbers") or key not in result.numbers: + available = list(result.numbers.keys()) if hasattr(result, "numbers") else [] + raise KeyError( + f"Expected key '{key}' not found in workflow output. Available keys: {available}" + ) + return float(result.numbers[key].item()) + + def read_eth_usdt_one_hour_volatility_forecast(alpha: Alpha) -> WorkflowModelOutput: """ Read from the ETH/USDT one hour volatility forecast model workflow on the OpenGradient network. @@ -22,7 +43,7 @@ def read_eth_usdt_one_hour_volatility_forecast(alpha: Alpha) -> WorkflowModelOut More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-1hr-volatility-ethusdt. """ return read_workflow_wrapper( - alpha, contract_address=ETH_USDT_1_HOUR_VOLATILITY_ADDRESS, format_function=lambda x: format(float(x.numbers["Y"].item()), ".10%") + alpha, contract_address=ETH_USDT_1_HOUR_VOLATILITY_ADDRESS, format_function=lambda x: format(_extract_number(x, "Y"), ".10%") ) @@ -35,7 +56,7 @@ def read_btc_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=BTC_1_HOUR_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"), ) @@ -48,7 +69,7 @@ def read_eth_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=ETH_1_HOUR_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"), ) @@ -61,7 +82,7 @@ def read_sol_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=SOL_1_HOUR_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"), ) @@ -74,7 +95,7 @@ def read_sui_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=SUI_1_HOUR_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"), ) @@ -87,7 +108,7 @@ def read_sui_usdt_30_min_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=SUI_30_MINUTE_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "destandardized_prediction"), ".10%"), ) @@ -100,5 +121,5 @@ def read_sui_usdt_6_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput: return read_workflow_wrapper( alpha, contract_address=SUI_6_HOUR_PRICE_FORECAST_ADDRESS, - format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"), + format_function=lambda x: format(_extract_number(x, "destandardized_prediction"), ".10%"), ) diff --git a/tests/workflow_models_test.py b/tests/workflow_models_test.py new file mode 100644 index 00000000..7ba7cc07 --- /dev/null +++ b/tests/workflow_models_test.py @@ -0,0 +1,70 @@ +"""Tests for workflow_models error handling and key validation.""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from opengradient.workflow_models.utils import read_workflow_wrapper +from opengradient.workflow_models.workflow_models import _extract_number + + +class TestExtractNumber: + """Tests for _extract_number helper.""" + + def test_extracts_valid_key(self): + result = MagicMock() + result.numbers = {"Y": np.float32(0.05)} + assert _extract_number(result, "Y") == float(np.float32(0.05)) + + def test_raises_on_missing_key(self): + result = MagicMock() + result.numbers = {"X": np.float32(0.05)} + with pytest.raises(KeyError, match="Expected key 'Y' not found"): + _extract_number(result, "Y") + + def test_raises_on_missing_numbers_attr(self): + result = MagicMock(spec=[]) # no attributes + with pytest.raises(KeyError, match="Expected key 'Y' not found"): + _extract_number(result, "Y") + + def test_reports_available_keys(self): + result = MagicMock() + result.numbers = {"A": np.float32(1.0), "B": np.float32(2.0)} + with pytest.raises(KeyError, match="Available keys:"): + _extract_number(result, "missing") + + +class TestReadWorkflowWrapper: + """Tests for read_workflow_wrapper error handling.""" + + def test_keyerror_preserved_not_converted_to_runtime(self): + """KeyError from missing output key should propagate as KeyError, not RuntimeError.""" + mock_alpha = MagicMock() + mock_alpha.read_workflow_result.return_value = MagicMock() + + def bad_format(result): + raise KeyError("regression_output") + + with pytest.raises(KeyError, match="regression_output"): + read_workflow_wrapper(mock_alpha, "0xabc", bad_format) + + def test_runtime_error_chains_original(self): + """Non-KeyError exceptions should be wrapped in RuntimeError with __cause__ set.""" + mock_alpha = MagicMock() + mock_alpha.read_workflow_result.side_effect = ConnectionError("network down") + + with pytest.raises(RuntimeError, match="Error reading from workflow") as exc_info: + read_workflow_wrapper(mock_alpha, "0xabc", lambda x: str(x)) + + assert exc_info.value.__cause__ is not None + assert isinstance(exc_info.value.__cause__, ConnectionError) + + def test_success_returns_workflow_output(self): + """Successful read returns a WorkflowModelOutput.""" + mock_alpha = MagicMock() + mock_alpha.read_workflow_result.return_value = "mock_result" + + output = read_workflow_wrapper(mock_alpha, "0xabc", lambda x: "formatted") + assert output.result == "formatted" + assert "0xabc" in output.block_explorer_link