diff --git a/src/opengradient/client/_conversions.py b/src/opengradient/client/_conversions.py index 495f663..48cc717 100644 --- a/src/opengradient/client/_conversions.py +++ b/src/opengradient/client/_conversions.py @@ -122,6 +122,9 @@ def convert_to_model_output(event_data: AttributeDict) -> Dict[str, np.ndarray]: We need to reshape each output array using the shape parameter in order to get the array back into its original shape. """ + if not isinstance(event_data, (AttributeDict, dict)): + raise TypeError(f"event_data must be a dict-like object, got {type(event_data).__name__}") + output_dict = {} output = event_data.get("output", {}) diff --git a/src/opengradient/client/alpha.py b/src/opengradient/client/alpha.py index d2957c2..e83b38b 100644 --- a/src/opengradient/client/alpha.py +++ b/src/opengradient/client/alpha.py @@ -56,6 +56,8 @@ def __init__( ): self._blockchain = Web3(Web3.HTTPProvider(rpc_url)) self._wallet_account: LocalAccount = self._blockchain.eth.account.from_key(private_key) + if not Web3.is_address(inference_contract_address): + raise ValueError(f"Invalid Ethereum address for inference_contract_address: {inference_contract_address!r}") self._inference_hub_contract_address = inference_contract_address self._api_url = api_url self._inference_abi: Optional[dict] = None diff --git a/src/opengradient/client/model_hub.py b/src/opengradient/client/model_hub.py index 2ca85ce..c594599 100644 --- a/src/opengradient/client/model_hub.py +++ b/src/opengradient/client/model_hub.py @@ -113,7 +113,7 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00") json_response = response.json() created_name = json_response.get("name") if not created_name: - raise Exception(f"Model creation response missing 'name'. Full response: {json_response}") + raise RuntimeError(f"Model creation response missing 'name'. Full response: {json_response}") # Create the initial version for the newly created model. # Pass `version` as release notes (e.g. "1.00") since the server assigns @@ -137,7 +137,7 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals dict: The server response containing version details. Raises: - Exception: If the version creation fails. + RuntimeError: If the version creation fails or the response is unexpected. """ url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions" headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"} @@ -163,12 +163,10 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals return {"versionString": "Unknown", "note": "Version ID not provided in response"} return {"versionString": version_string} else: - raise Exception(f"Unexpected response type: {type(json_response)}") + raise RuntimeError(f"Unexpected response type: {type(json_response)}") except requests.RequestException as e: - raise Exception(f"Version creation failed: {str(e)}") - except Exception: - raise + raise RuntimeError(f"Version creation failed: {str(e)}") def upload(self, model_path: str, model_name: str, version: str) -> FileUploadResult: """ @@ -207,7 +205,10 @@ def upload(self, model_path: str, model_name: str, version: str) -> FileUploadRe elif response.status_code == 500: raise RuntimeError(f"Internal server error occurred (status_code=500)") else: - error_message = response.json().get("detail", "Unknown error occurred") + try: + error_message = response.json().get("detail", "Unknown error occurred") + except ValueError: + error_message = response.text or "Unknown error occurred" raise RuntimeError(f"Upload failed: {error_message} (status_code={response.status_code})") except requests.RequestException as e: diff --git a/tests/client_test.py b/tests/client_test.py index 6829fc9..a1b5524 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -3,8 +3,10 @@ import pytest +from opengradient.client.alpha import Alpha from opengradient.client.llm import LLM from opengradient.client.model_hub import ModelHub +from opengradient.client._conversions import convert_to_model_output from opengradient.types import ( StreamChunk, x402SettlementMode, @@ -190,3 +192,111 @@ def test_settlement_modes_values(self): assert x402SettlementMode.PRIVATE == "private" assert x402SettlementMode.BATCH_HASHED == "batch" assert x402SettlementMode.INDIVIDUAL_FULL == "individual" + + +# --- Fix #3: response.json() guard in ModelHub.upload --- + + +class TestModelHubUploadErrorHandling: + """upload() must not crash with JSONDecodeError when the server returns HTML.""" + + def _make_hub(self): + with ( + patch("opengradient.client.model_hub._FIREBASE_CONFIG", {"apiKey": "fake"}), + patch("opengradient.client.model_hub.firebase") as mock_firebase, + ): + mock_auth = MagicMock() + mock_auth.sign_in_with_email_and_password.return_value = { + "idToken": "tok", + "email": "u@t.com", + "expiresIn": "3600", + "refreshToken": "rt", + } + mock_firebase.initialize_app.return_value.auth.return_value = mock_auth + return ModelHub(email="u@t.com", password="pw") + + def test_html_error_response_raises_runtime_error_not_json_error(self, tmp_path): + """When the server returns HTML (e.g. a WAF page), upload raises RuntimeError not JSONDecodeError.""" + dummy_file = tmp_path / "model.onnx" + dummy_file.write_bytes(b"dummy") + + hub = self._make_hub() + + html_response = MagicMock() + html_response.status_code = 403 + html_response.json.side_effect = ValueError("No JSON") + html_response.text = "Forbidden" + + with patch("opengradient.client.model_hub.requests.post", return_value=html_response): + with patch("opengradient.client.model_hub.MultipartEncoder"): + with pytest.raises(RuntimeError, match="Upload failed"): + hub.upload(str(dummy_file), "my-model", "1.0") + + def test_json_error_response_raises_runtime_error(self, tmp_path): + """When the server returns JSON error, upload raises RuntimeError with detail message.""" + dummy_file = tmp_path / "model.onnx" + dummy_file.write_bytes(b"dummy") + + hub = self._make_hub() + + json_response = MagicMock() + json_response.status_code = 422 + json_response.json.return_value = {"detail": "Unprocessable entity"} + json_response.text = '{"detail": "Unprocessable entity"}' + + with patch("opengradient.client.model_hub.requests.post", return_value=json_response): + with patch("opengradient.client.model_hub.MultipartEncoder"): + with pytest.raises(RuntimeError, match="Unprocessable entity"): + hub.upload(str(dummy_file), "my-model", "1.0") + + +# --- Fix #5: event_data type guard in convert_to_model_output --- + + +class TestConvertToModelOutputGuard: + """convert_to_model_output must raise TypeError for non-dict input.""" + + def test_none_input_raises_type_error(self): + """Passing None raises TypeError with a clear message.""" + with pytest.raises(TypeError, match="event_data must be a dict-like object"): + convert_to_model_output(None) + + def test_string_input_raises_type_error(self): + """Passing a string raises TypeError.""" + with pytest.raises(TypeError, match="event_data must be a dict-like object"): + convert_to_model_output("not a dict") + + def test_valid_empty_dict_returns_empty(self): + """An empty dict returns an empty output dict without crashing.""" + result = convert_to_model_output({}) + assert result == {} + + +# --- Fix #16: contract address validation in Alpha constructor --- + + +class TestAlphaAddressValidation: + """Alpha constructor must reject invalid Ethereum addresses immediately.""" + + def _make_alpha(self, address): + with patch("opengradient.client.alpha.Web3") as mock_web3_cls: + mock_w3 = MagicMock() + mock_web3_cls.return_value = mock_w3 + mock_web3_cls.HTTPProvider.return_value = MagicMock() + mock_web3_cls.is_address.side_effect = lambda a: a.startswith("0x") and len(a) == 42 + mock_w3.eth.account.from_key.return_value = MagicMock(address="0xDEAD") + return Alpha(private_key="0x" + "a" * 64, inference_contract_address=address) + + def test_invalid_address_raises_value_error(self): + """A clearly wrong address raises ValueError at construction time.""" + with pytest.raises(ValueError, match="Invalid Ethereum address"): + self._make_alpha("not-an-address") + + def test_valid_address_does_not_raise(self): + """A valid checksummed address does not raise.""" + self._make_alpha("0x" + "b" * 40) + + def test_empty_address_raises_value_error(self): + """An empty string raises ValueError.""" + with pytest.raises(ValueError, match="Invalid Ethereum address"): + self._make_alpha("")