Skip to content

Commit 80cc0b3

Browse files
authored
Add CENTML_ to start of config vars (#74)
* add CENTML to config vars
1 parent 71a7e13 commit 80cc0b3

File tree

5 files changed

+41
-41
lines changed

5 files changed

+41
-41
lines changed

centml/compiler/backend.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ def inputs(self):
5454

5555
def _serialize_model_and_inputs(self):
5656
self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with
57-
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_MODEL_FILE)
58-
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_INPUT_FILE)
57+
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_MODEL_FILE)
58+
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_INPUT_FILE)
5959

6060
# torch.save saves a zip file full of pickled files with the model's states.
6161
try:
62-
torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.PICKLE_PROTOCOL)
63-
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.PICKLE_PROTOCOL)
62+
torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
63+
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
6464
except Exception as e:
6565
raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e
6666

@@ -71,7 +71,7 @@ def _get_model_id(self) -> str:
7171
sha_hash = hashlib.sha256()
7272
with open(self.serialized_model_path, "rb") as serialized_model_file:
7373
# Read in chunks to not load too much into memory
74-
for block in iter(lambda: serialized_model_file.read(settings.HASH_CHUNK_SIZE), b""):
74+
for block in iter(lambda: serialized_model_file.read(settings.CENTML_HASH_CHUNK_SIZE), b""):
7575
sha_hash.update(block)
7676

7777
model_id = sha_hash.hexdigest()
@@ -80,7 +80,7 @@ def _get_model_id(self) -> str:
8080

8181
def _download_model(self, model_id: str):
8282
download_response = requests.get(
83-
url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.TIMEOUT
83+
url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
8484
)
8585
if download_response.status_code != HTTPStatus.OK:
8686
raise Exception(
@@ -106,7 +106,7 @@ def _compile_model(self, model_id: str):
106106
compile_response = requests.post(
107107
url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}",
108108
files={"model": model_file, "inputs": input_file},
109-
timeout=settings.TIMEOUT,
109+
timeout=settings.CENTML_COMPILER_TIMEOUT,
110110
)
111111
if compile_response.status_code != HTTPStatus.OK:
112112
raise Exception(
@@ -120,7 +120,7 @@ def _wait_for_status(self, model_id: str) -> bool:
120120
status = None
121121
try:
122122
status_response = requests.get(
123-
f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT
123+
f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
124124
)
125125
if status_response.status_code != HTTPStatus.OK:
126126
raise Exception(
@@ -144,10 +144,10 @@ def _wait_for_status(self, model_id: str) -> bool:
144144
else:
145145
tries += 1
146146

147-
if tries > settings.MAX_RETRIES:
147+
if tries > settings.CENTML_COMPILER_MAX_RETRIES:
148148
raise Exception("Waiting for status: compilation failed too many times.\n")
149149

150-
time.sleep(settings.COMPILING_SLEEP_TIME)
150+
time.sleep(settings.CENTML_COMPILER_SLEEP_TIME)
151151

152152
def remote_compilation_starter(self):
153153
try:

centml/compiler/config.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,26 @@ class CompilationStatus(Enum):
1010

1111

1212
class Config(BaseSettings):
13-
TIMEOUT: int = 10
14-
MAX_RETRIES: int = 3
15-
COMPILING_SLEEP_TIME: int = 15
13+
CENTML_COMPILER_TIMEOUT: int = 10
14+
CENTML_COMPILER_MAX_RETRIES: int = 3
15+
CENTML_COMPILER_SLEEP_TIME: int = 15
1616

17-
CENTML_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
18-
BACKEND_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "backend")
19-
SERVER_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "server")
17+
CENTML_BASE_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
18+
CENTML_BACKEND_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "backend")
19+
CENTML_SERVER_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "server")
2020

2121
CENTML_SERVER_URL: str = "http://0.0.0.0:8090"
2222

2323
# Use a constant path since torch.save uses the given file name in it's zipfile.
2424
# Using a different filename would result in a different hash.
25-
SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
26-
SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
27-
PICKLE_PROTOCOL: int = 4
25+
CENTML_SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
26+
CENTML_SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
27+
CENTML_PICKLE_PROTOCOL: int = 4
2828

29-
HASH_CHUNK_SIZE: int = 4096
29+
CENTML_HASH_CHUNK_SIZE: int = 4096
3030

3131
# If the server response is smaller than this, don't gzip it
32-
MINIMUM_GZIP_SIZE: int = 1000
32+
CENTML_MINIMUM_GZIP_SIZE: int = 1000
3333

3434

3535
settings = Config()

centml/compiler/server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from centml.compiler.utils import get_server_compiled_forward_path
1515

1616
app = FastAPI()
17-
app.add_middleware(GZipMiddleware, minimum_size=settings.MINIMUM_GZIP_SIZE) # type: ignore
17+
app.add_middleware(GZipMiddleware, minimum_size=settings.CENTML_MINIMUM_GZIP_SIZE) # type: ignore
1818

1919

2020
def get_status(model_id: str):
21-
if not os.path.isdir(os.path.join(settings.SERVER_BASE_PATH, model_id)):
21+
if not os.path.isdir(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id)):
2222
return CompilationStatus.NOT_FOUND
2323

2424
if not os.path.isfile(get_server_compiled_forward_path(model_id)):
@@ -50,7 +50,7 @@ def background_compile(model_id: str, tfx_graph, example_inputs):
5050
# To avoid this, we write to a tmp file and rename it to the correct path after saving.
5151
save_path = get_server_compiled_forward_path(model_id)
5252
tmp_path = save_path + ".tmp"
53-
torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.PICKLE_PROTOCOL)
53+
torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
5454
os.rename(tmp_path, save_path)
5555
except Exception as e:
5656
logging.getLogger(__name__).exception(f"Saving graph module failed: {e}")
@@ -93,7 +93,7 @@ async def compile_model_handler(model_id: str, model: UploadFile, inputs: Upload
9393
return Response(status_code=200)
9494

9595
# This effectively sets the model's status to COMPILING
96-
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id))
96+
os.makedirs(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id))
9797

9898
tfx_graph, example_inputs = read_upload_files(model_id, model, inputs)
9999

centml/compiler/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44

55

66
def get_backend_compiled_forward_path(model_id: str):
7-
os.makedirs(os.path.join(settings.BACKEND_BASE_PATH, model_id), exist_ok=True)
8-
return os.path.join(settings.BACKEND_BASE_PATH, model_id, "compilation_return.pkl")
7+
os.makedirs(os.path.join(settings.CENTML_BACKEND_BASE_PATH, model_id), exist_ok=True)
8+
return os.path.join(settings.CENTML_BACKEND_BASE_PATH, model_id, "compilation_return.pkl")
99

1010

1111
def get_server_compiled_forward_path(model_id: str):
12-
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id), exist_ok=True)
13-
return os.path.join(settings.SERVER_BASE_PATH, model_id, "compilation_return.pkl")
12+
os.makedirs(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id), exist_ok=True)
13+
return os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id, "compilation_return.pkl")
1414

1515

1616
# This function will delete the storage_path/{model_id} directory
1717
def dir_cleanup(model_id: str):
18-
dir_path = os.path.join(settings.SERVER_BASE_PATH, model_id)
18+
dir_path = os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id)
1919
if not os.path.exists(dir_path):
2020
return # Directory does not exist, return
2121

tests/test_backend.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_successful_download(self, mock_requests, mock_load, mock_open, mock_mak
119119

120120

121121
class TestWaitForStatus(SetUpGraphModule):
122-
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
122+
@patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0)
123123
@patch("centml.compiler.backend.requests")
124124
@patch("logging.Logger.exception")
125125
def test_invalid_status(self, mock_logger, mock_requests):
@@ -132,13 +132,13 @@ def test_invalid_status(self, mock_logger, mock_requests):
132132
self.runner._wait_for_status(model_id)
133133

134134
mock_requests.get.assert_called()
135-
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
136-
assert len(mock_logger.call_args_list) == settings.MAX_RETRIES + 1
135+
assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1
136+
assert len(mock_logger.call_args_list) == settings.CENTML_COMPILER_MAX_RETRIES + 1
137137
print(mock_logger.call_args_list)
138138
assert mock_logger.call_args_list[0].startswith("Status check failed:")
139139
assert "Waiting for status: compilation failed too many times.\n" == str(context.exception)
140140

141-
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
141+
@patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0)
142142
@patch("centml.compiler.backend.requests")
143143
@patch("logging.Logger.exception")
144144
def test_exception_in_status(self, mock_logger, mock_requests):
@@ -150,11 +150,11 @@ def test_exception_in_status(self, mock_logger, mock_requests):
150150
self.runner._wait_for_status(model_id)
151151

152152
mock_requests.get.assert_called()
153-
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
153+
assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1
154154
mock_logger.assert_called_with(f"Status check failed:\n{exception_message}")
155155
assert str(context.exception) == "Waiting for status: compilation failed too many times.\n"
156156

157-
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
157+
@patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0)
158158
@patch("centml.compiler.backend.Runner._compile_model")
159159
@patch("centml.compiler.backend.requests")
160160
def test_max_tries(self, mock_requests, mock_compile):
@@ -167,10 +167,10 @@ def test_max_tries(self, mock_requests, mock_compile):
167167
with self.assertRaises(Exception) as context:
168168
self.runner._wait_for_status(model_id)
169169

170-
self.assertEqual(mock_compile.call_count, settings.MAX_RETRIES + 1)
170+
self.assertEqual(mock_compile.call_count, settings.CENTML_COMPILER_MAX_RETRIES + 1)
171171
self.assertIn("Waiting for status: compilation failed too many times", str(context.exception))
172172

173-
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
173+
@patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0)
174174
@patch("centml.compiler.backend.requests")
175175
def test_wait_on_compilation(self, mock_requests):
176176
# Mock the status check
@@ -186,7 +186,7 @@ def test_wait_on_compilation(self, mock_requests):
186186
# _wait_for_status should return True when compilation DONE
187187
assert self.runner._wait_for_status(model_id)
188188

189-
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
189+
@patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0)
190190
@patch("centml.compiler.backend.requests")
191191
@patch("centml.compiler.backend.Runner._compile_model")
192192
@patch("logging.Logger.exception")
@@ -206,10 +206,10 @@ def test_exception_in_compilation(self, mock_logger, mock_compile, mock_requests
206206
self.runner._wait_for_status(model_id)
207207

208208
mock_requests.get.assert_called()
209-
assert mock_requests.get.call_count == settings.MAX_RETRIES + 1
209+
assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1
210210

211211
mock_compile.assert_called()
212-
assert mock_compile.call_count == settings.MAX_RETRIES + 1
212+
assert mock_compile.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1
213213

214214
mock_logger.assert_called_with(f"Submitting compilation failed:\n{exception_message}")
215215
assert str(context.exception) == "Waiting for status: compilation failed too many times.\n"

0 commit comments

Comments
 (0)