Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/opengradient/client/_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def convert_to_model_output(event_data: AttributeDict) -> Dict[str, np.ndarray]:
We need to reshape each output array using the shape parameter in order to get the array
back into its original shape.
"""
if not isinstance(event_data, (AttributeDict, dict)):
raise TypeError(f"event_data must be a dict-like object, got {type(event_data).__name__}")

output_dict = {}
output = event_data.get("output", {})

Expand Down
2 changes: 2 additions & 0 deletions src/opengradient/client/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(
):
self._blockchain = Web3(Web3.HTTPProvider(rpc_url))
self._wallet_account: LocalAccount = self._blockchain.eth.account.from_key(private_key)
if not Web3.is_address(inference_contract_address):
raise ValueError(f"Invalid Ethereum address for inference_contract_address: {inference_contract_address!r}")
self._inference_hub_contract_address = inference_contract_address
self._api_url = api_url
self._inference_abi: Optional[dict] = None
Expand Down
15 changes: 8 additions & 7 deletions src/opengradient/client/model_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00")
json_response = response.json()
created_name = json_response.get("name")
if not created_name:
raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
raise RuntimeError(f"Model creation response missing 'name'. Full response: {json_response}")

# Create the initial version for the newly created model.
# Pass `version` as release notes (e.g. "1.00") since the server assigns
Expand All @@ -137,7 +137,7 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals
dict: The server response containing version details.

Raises:
Exception: If the version creation fails.
RuntimeError: If the version creation fails or the response is unexpected.
"""
url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions"
headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"}
Expand All @@ -163,12 +163,10 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals
return {"versionString": "Unknown", "note": "Version ID not provided in response"}
return {"versionString": version_string}
else:
raise Exception(f"Unexpected response type: {type(json_response)}")
raise RuntimeError(f"Unexpected response type: {type(json_response)}")

except requests.RequestException as e:
raise Exception(f"Version creation failed: {str(e)}")
except Exception:
raise
raise RuntimeError(f"Version creation failed: {str(e)}")

def upload(self, model_path: str, model_name: str, version: str) -> FileUploadResult:
"""
Expand Down Expand Up @@ -207,7 +205,10 @@ def upload(self, model_path: str, model_name: str, version: str) -> FileUploadRe
elif response.status_code == 500:
raise RuntimeError(f"Internal server error occurred (status_code=500)")
else:
error_message = response.json().get("detail", "Unknown error occurred")
try:
error_message = response.json().get("detail", "Unknown error occurred")
except ValueError:
error_message = response.text or "Unknown error occurred"
raise RuntimeError(f"Upload failed: {error_message} (status_code={response.status_code})")

except requests.RequestException as e:
Expand Down
110 changes: 110 additions & 0 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import pytest

from opengradient.client.alpha import Alpha
from opengradient.client.llm import LLM
from opengradient.client.model_hub import ModelHub
from opengradient.client._conversions import convert_to_model_output
from opengradient.types import (
StreamChunk,
x402SettlementMode,
Expand Down Expand Up @@ -190,3 +192,111 @@ def test_settlement_modes_values(self):
assert x402SettlementMode.PRIVATE == "private"
assert x402SettlementMode.BATCH_HASHED == "batch"
assert x402SettlementMode.INDIVIDUAL_FULL == "individual"


# --- Fix #3: response.json() guard in ModelHub.upload ---


class TestModelHubUploadErrorHandling:
"""upload() must not crash with JSONDecodeError when the server returns HTML."""

def _make_hub(self):
with (
patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}),
patch("opengradient.client.model_hub.firebase") as mock_firebase,
):
mock_auth = MagicMock()
mock_auth.sign_in_with_email_and_password.return_value = {
"idToken": "tok",
"email": "u@t.com",
"expiresIn": "3600",
"refreshToken": "rt",
}
mock_firebase.initialize_app.return_value.auth.return_value = mock_auth
return ModelHub(email="u@t.com", password="pw")

def test_html_error_response_raises_runtime_error_not_json_error(self, tmp_path):
"""When the server returns HTML (e.g. a WAF page), upload raises RuntimeError not JSONDecodeError."""
dummy_file = tmp_path / "model.onnx"
dummy_file.write_bytes(b"dummy")

hub = self._make_hub()

html_response = MagicMock()
html_response.status_code = 403
html_response.json.side_effect = ValueError("No JSON")
html_response.text = "<html>Forbidden</html>"

with patch("opengradient.client.model_hub.requests.post", return_value=html_response):
with patch("opengradient.client.model_hub.MultipartEncoder"):
with pytest.raises(RuntimeError, match="Upload failed"):
hub.upload(str(dummy_file), "my-model", "1.0")

def test_json_error_response_raises_runtime_error(self, tmp_path):
"""When the server returns JSON error, upload raises RuntimeError with detail message."""
dummy_file = tmp_path / "model.onnx"
dummy_file.write_bytes(b"dummy")

hub = self._make_hub()

json_response = MagicMock()
json_response.status_code = 422
json_response.json.return_value = {"detail": "Unprocessable entity"}
json_response.text = '{"detail": "Unprocessable entity"}'

with patch("opengradient.client.model_hub.requests.post", return_value=json_response):
with patch("opengradient.client.model_hub.MultipartEncoder"):
with pytest.raises(RuntimeError, match="Unprocessable entity"):
hub.upload(str(dummy_file), "my-model", "1.0")


# --- Fix #5: event_data type guard in convert_to_model_output ---


class TestConvertToModelOutputGuard:
"""convert_to_model_output must raise TypeError for non-dict input."""

def test_none_input_raises_type_error(self):
"""Passing None raises TypeError with a clear message."""
with pytest.raises(TypeError, match="event_data must be a dict-like object"):
convert_to_model_output(None)

def test_string_input_raises_type_error(self):
"""Passing a string raises TypeError."""
with pytest.raises(TypeError, match="event_data must be a dict-like object"):
convert_to_model_output("not a dict")

def test_valid_empty_dict_returns_empty(self):
"""An empty dict returns an empty output dict without crashing."""
result = convert_to_model_output({})
assert result == {}


# --- Fix #16: contract address validation in Alpha constructor ---


class TestAlphaAddressValidation:
"""Alpha constructor must reject invalid Ethereum addresses immediately."""

def _make_alpha(self, address):
with patch("opengradient.client.alpha.Web3") as mock_web3_cls:
mock_w3 = MagicMock()
mock_web3_cls.return_value = mock_w3
mock_web3_cls.HTTPProvider.return_value = MagicMock()
mock_web3_cls.is_address.side_effect = lambda a: a.startswith("0x") and len(a) == 42
mock_w3.eth.account.from_key.return_value = MagicMock(address="0xDEAD")
return Alpha(private_key="0x" + "a" * 64, inference_contract_address=address)

def test_invalid_address_raises_value_error(self):
"""A clearly wrong address raises ValueError at construction time."""
with pytest.raises(ValueError, match="Invalid Ethereum address"):
self._make_alpha("not-an-address")

def test_valid_address_does_not_raise(self):
"""A valid checksummed address does not raise."""
self._make_alpha("0x" + "b" * 40)

def test_empty_address_raises_value_error(self):
"""An empty string raises ValueError."""
with pytest.raises(ValueError, match="Invalid Ethereum address"):
self._make_alpha("")