Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/opengradient/workflow_models/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Utility functions for the models module."""

import logging
from typing import Callable

from opengradient.client.alpha import Alpha

from .constants import BLOCK_EXPLORER_URL
from .types import WorkflowModelOutput

logger = logging.getLogger(__name__)


def create_block_explorer_link_smart_contract(transaction_hash: str) -> str:
"""Create block explorer link for smart contract."""
Expand All @@ -23,10 +26,15 @@ def create_block_explorer_link_transaction(transaction_hash: str) -> str:
def read_workflow_wrapper(alpha: Alpha, contract_address: str, format_function: Callable[..., str]) -> WorkflowModelOutput:
"""
Wrapper function for reading from models through workflows.

Args:
alpha (Alpha): The alpha namespace from an initialized OpenGradient client (client.alpha).
contract_address (str): Smart contract address of the workflow
format_function (Callable): Function for formatting the result returned by read_workflow

Raises:
KeyError: If the workflow result is missing an expected output key.
RuntimeError: If reading or formatting the workflow result fails.
"""
try:
result = alpha.read_workflow_result(contract_address)
Expand All @@ -38,5 +46,10 @@ def read_workflow_wrapper(alpha: Alpha, contract_address: str, format_function:
result=formatted_result,
block_explorer_link=block_explorer_link,
)
except KeyError as e:
raise KeyError(
f"Workflow at {contract_address} is missing expected output key {e}. "
f"Available keys: {list(result.numbers.keys()) if hasattr(result, 'numbers') else 'unknown'}"
) from e
except Exception as e:
raise RuntimeError(f"Error reading from workflow with address {contract_address}: {e!s}")
raise RuntimeError(f"Error reading from workflow with address {contract_address}: {e!s}") from e
35 changes: 28 additions & 7 deletions src/opengradient/workflow_models/workflow_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,35 @@
from .utils import read_workflow_wrapper


def _extract_number(result, key: str) -> float:
"""Extract a numeric value from a workflow result by key.

Args:
result: The ModelOutput returned by read_workflow_result.
key: The expected key in result.numbers.

Returns:
The extracted float value.

Raises:
KeyError: If the key is not present in result.numbers.
"""
if not hasattr(result, "numbers") or key not in result.numbers:
available = list(result.numbers.keys()) if hasattr(result, "numbers") else []
raise KeyError(
f"Expected key '{key}' not found in workflow output. Available keys: {available}"
)
return float(result.numbers[key].item())


def read_eth_usdt_one_hour_volatility_forecast(alpha: Alpha) -> WorkflowModelOutput:
"""
Read from the ETH/USDT one hour volatility forecast model workflow on the OpenGradient network.

More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-1hr-volatility-ethusdt.
"""
return read_workflow_wrapper(
alpha, contract_address=ETH_USDT_1_HOUR_VOLATILITY_ADDRESS, format_function=lambda x: format(float(x.numbers["Y"].item()), ".10%")
alpha, contract_address=ETH_USDT_1_HOUR_VOLATILITY_ADDRESS, format_function=lambda x: format(_extract_number(x, "Y"), ".10%")
)


Expand All @@ -35,7 +56,7 @@ def read_btc_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=BTC_1_HOUR_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"),
)


Expand All @@ -48,7 +69,7 @@ def read_eth_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=ETH_1_HOUR_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"),
)


Expand All @@ -61,7 +82,7 @@ def read_sol_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=SOL_1_HOUR_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"),
)


Expand All @@ -74,7 +95,7 @@ def read_sui_1_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=SUI_1_HOUR_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "regression_output"), ".10%"),
)


Expand All @@ -87,7 +108,7 @@ def read_sui_usdt_30_min_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=SUI_30_MINUTE_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "destandardized_prediction"), ".10%"),
)


Expand All @@ -100,5 +121,5 @@ def read_sui_usdt_6_hour_price_forecast(alpha: Alpha) -> WorkflowModelOutput:
return read_workflow_wrapper(
alpha,
contract_address=SUI_6_HOUR_PRICE_FORECAST_ADDRESS,
format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"),
format_function=lambda x: format(_extract_number(x, "destandardized_prediction"), ".10%"),
)
70 changes: 70 additions & 0 deletions tests/workflow_models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Tests for workflow_models error handling and key validation."""

from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from opengradient.workflow_models.utils import read_workflow_wrapper
from opengradient.workflow_models.workflow_models import _extract_number


class TestExtractNumber:
"""Tests for _extract_number helper."""

def test_extracts_valid_key(self):
result = MagicMock()
result.numbers = {"Y": np.float32(0.05)}
assert _extract_number(result, "Y") == float(np.float32(0.05))

def test_raises_on_missing_key(self):
result = MagicMock()
result.numbers = {"X": np.float32(0.05)}
with pytest.raises(KeyError, match="Expected key 'Y' not found"):
_extract_number(result, "Y")

def test_raises_on_missing_numbers_attr(self):
result = MagicMock(spec=[]) # no attributes
with pytest.raises(KeyError, match="Expected key 'Y' not found"):
_extract_number(result, "Y")

def test_reports_available_keys(self):
result = MagicMock()
result.numbers = {"A": np.float32(1.0), "B": np.float32(2.0)}
with pytest.raises(KeyError, match="Available keys:"):
_extract_number(result, "missing")


class TestReadWorkflowWrapper:
"""Tests for read_workflow_wrapper error handling."""

def test_keyerror_preserved_not_converted_to_runtime(self):
"""KeyError from missing output key should propagate as KeyError, not RuntimeError."""
mock_alpha = MagicMock()
mock_alpha.read_workflow_result.return_value = MagicMock()

def bad_format(result):
raise KeyError("regression_output")

with pytest.raises(KeyError, match="regression_output"):
read_workflow_wrapper(mock_alpha, "0xabc", bad_format)

def test_runtime_error_chains_original(self):
"""Non-KeyError exceptions should be wrapped in RuntimeError with __cause__ set."""
mock_alpha = MagicMock()
mock_alpha.read_workflow_result.side_effect = ConnectionError("network down")

with pytest.raises(RuntimeError, match="Error reading from workflow") as exc_info:
read_workflow_wrapper(mock_alpha, "0xabc", lambda x: str(x))

assert exc_info.value.__cause__ is not None
assert isinstance(exc_info.value.__cause__, ConnectionError)

def test_success_returns_workflow_output(self):
"""Successful read returns a WorkflowModelOutput."""
mock_alpha = MagicMock()
mock_alpha.read_workflow_result.return_value = "mock_result"

output = read_workflow_wrapper(mock_alpha, "0xabc", lambda x: "formatted")
assert output.result == "formatted"
assert "0xabc" in output.block_explorer_link