Skip to content

Commit 74293cc

Browse files
authored
Better exception catching (#69)
* Catch server exceptions
* Catch status and in-thread exceptions
* Add unit tests
1 parent e36bde0 commit 74293cc

File tree

2 files changed

+79
-14
lines changed

2 files changed

+79
-14
lines changed

centml/compiler/backend.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ def __init__(self, module: GraphModule, inputs: List[torch.Tensor]):
2424
self._inputs: List[torch.Tensor] = inputs
2525
self.compiled_forward_function: Optional[Callable[[torch.Tensor], tuple]] = None
2626
self.lock = th.Lock()
27-
self.child_thread = th.Thread(target=self.remote_compilation)
27+
self.child_thread = th.Thread(target=self.remote_compilation_starter)
2828

2929
self.serialized_model_dir: Optional[TemporaryDirectory] = None
3030
self.serialized_model_path: Optional[str] = None
3131
self.serialized_input_path: Optional[str] = None
3232

3333
try:
3434
self.child_thread.start()
35-
except Exception:
36-
logging.getLogger(__name__).exception("Remote compilation failed with the following exception: \n")
35+
except Exception as e:
36+
logging.getLogger(__name__).exception(f"Failed to start compilation thread\n{e}")
3737

3838
@property
3939
def module(self) -> Optional[GraphModule]:
@@ -108,7 +108,6 @@ def _compile_model(self, model_id: str):
108108
files={"model": model_file, "inputs": input_file},
109109
timeout=settings.TIMEOUT,
110110
)
111-
112111
if compile_response.status_code != HTTPStatus.OK:
113112
raise Exception(
114113
f"Compile model: request failed, exception from server:\n{compile_response.json().get('detail')}\n"
@@ -118,21 +117,30 @@ def _wait_for_status(self, model_id: str) -> bool:
118117
tries = 0
119118
while True:
120119
# get server compilation status
121-
status_response = requests.get(f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT)
122-
if status_response.status_code != HTTPStatus.OK:
123-
raise Exception(
124-
f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"
120+
status = None
121+
try:
122+
status_response = requests.get(
123+
f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT
125124
)
126-
status = status_response.json().get("status")
125+
if status_response.status_code != HTTPStatus.OK:
126+
raise Exception(
127+
f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"
128+
)
129+
status = status_response.json().get("status")
130+
except Exception as e:
131+
logging.getLogger(__name__).exception(f"Status check failed:\n{e}")
127132

128133
if status == CompilationStatus.DONE.value:
129134
return True
130135
elif status == CompilationStatus.COMPILING.value:
131136
pass
132137
elif status == CompilationStatus.NOT_FOUND.value:
133-
tries += 1
134138
logging.info("Submitting model to server for compilation.")
135-
self._compile_model(model_id)
139+
try:
140+
self._compile_model(model_id)
141+
except Exception as e:
142+
logging.getLogger(__name__).exception(f"Submitting compilation failed:\n{e}")
143+
tries += 1
136144
else:
137145
tries += 1
138146

@@ -141,6 +149,12 @@ def _wait_for_status(self, model_id: str) -> bool:
141149

142150
time.sleep(settings.COMPILING_SLEEP_TIME)
143151

152+
def remote_compilation_starter(self):
153+
try:
154+
self.remote_compilation()
155+
except Exception as e:
156+
logging.getLogger(__name__).exception(f"Compilation thread failed:\n{e}")
157+
144158
def remote_compilation(self):
145159
self._serialize_model_and_inputs()
146160

tests/test_backend.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,10 @@ def test_successful_download(self, mock_requests, mock_load, mock_open, mock_mak
119119

120120

121121
class TestWaitForStatus(SetUpGraphModule):
122+
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
122123
@patch("centml.compiler.backend.requests")
123-
def test_invalid_status(self, mock_requests):
124+
@patch("logging.Logger.exception")
125+
def test_invalid_status(self, mock_logger, mock_requests):
124126
mock_response = MagicMock()
125127
mock_response.status_code = HTTPStatus.BAD_REQUEST
126128
mock_requests.get.return_value = mock_response
@@ -129,8 +131,28 @@ def test_invalid_status(self, mock_requests):
129131
with self.assertRaises(Exception) as context:
130132
self.runner._wait_for_status(model_id)
131133

132-
mock_requests.get.assert_called_once()
133-
self.assertIn("Status check: request failed, exception from server", str(context.exception))
134+
mock_requests.get.assert_called()
135+
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
136+
assert len(mock_logger.call_args_list) == settings.MAX_RETRIES + 1
137+
print(mock_logger.call_args_list)
138+
assert mock_logger.call_args_list[0].startswith("Status check failed:")
139+
assert "Waiting for status: compilation failed too many times.\n" == str(context.exception)
140+
141+
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
142+
@patch("centml.compiler.backend.requests")
143+
@patch("logging.Logger.exception")
144+
def test_exception_in_status(self, mock_logger, mock_requests):
145+
exception_message = "Exiting early"
146+
mock_requests.get.side_effect = Exception(exception_message)
147+
148+
model_id = "exception_in_status"
149+
with self.assertRaises(Exception) as context:
150+
self.runner._wait_for_status(model_id)
151+
152+
mock_requests.get.assert_called()
153+
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
154+
mock_logger.assert_called_with(f"Status check failed:\n{exception_message}")
155+
assert str(context.exception) == "Waiting for status: compilation failed too many times.\n"
134156

135157
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
136158
@patch("centml.compiler.backend.Runner._compile_model")
@@ -151,6 +173,7 @@ def test_max_tries(self, mock_requests, mock_compile):
151173
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
152174
@patch("centml.compiler.backend.requests")
153175
def test_wait_on_compilation(self, mock_requests):
176+
# Mock the status check
154177
COMPILATION_STEPS = 10
155178
mock_response = MagicMock()
156179
mock_response.status_code = HTTPStatus.OK
@@ -163,6 +186,34 @@ def test_wait_on_compilation(self, mock_requests):
163186
# _wait_for_status should return True when compilation DONE
164187
assert self.runner._wait_for_status(model_id)
165188

189+
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
190+
@patch("centml.compiler.backend.requests")
191+
@patch("centml.compiler.backend.Runner._compile_model")
192+
@patch("logging.Logger.exception")
193+
def test_exception_in_compilation(self, mock_logger, mock_compile, mock_requests):
194+
# Mock the status check
195+
mock_response = MagicMock()
196+
mock_response.status_code = HTTPStatus.OK
197+
mock_response.json.return_value = {"status": CompilationStatus.NOT_FOUND.value}
198+
mock_requests.get.return_value = mock_response
199+
200+
# Mock the compile model function
201+
exception_message = "Exiting early"
202+
mock_compile.side_effect = Exception(exception_message)
203+
204+
model_id = "exception_in_compilation"
205+
with self.assertRaises(Exception) as context:
206+
self.runner._wait_for_status(model_id)
207+
208+
mock_requests.get.assert_called()
209+
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
210+
211+
mock_compile.assert_called()
212+
assert mock_compile.call_count == settings.MAX_RETRIES + 1
213+
214+
mock_logger.assert_called_with(f"Submitting compilation failed:\n{exception_message}")
215+
assert str(context.exception) == "Waiting for status: compilation failed too many times.\n"
216+
166217
@patch("centml.compiler.backend.requests")
167218
def test_compilation_done(self, mock_requests):
168219
mock_response = MagicMock()

0 commit comments

Comments
 (0)