Skip to content

Commit e034024

Browse files
authored
[Cleanup] Use pydantic base class for settings (#68)
* use pydantic BaseSettings * use Config instance instead of class * rename config_instance to settings
1 parent 42e78a0 commit e034024

File tree

10 files changed

+64
-57
lines changed

10 files changed

+64
-57
lines changed

centml/cli/login.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import click
22

3-
from centml.sdk import auth, config
3+
from centml.sdk import auth
4+
from centml.sdk.config import settings
45

56

67
@click.command(help="Login to CentML")
@@ -10,15 +11,15 @@ def login(token_file):
1011
auth.store_centml_cred(token_file)
1112

1213
if auth.load_centml_cred():
13-
click.echo(f"Authenticating with credentials from {config.Config.centml_cred_file}\n")
14+
click.echo(f"Authenticating with credentials from {settings.CENTML_CRED_FILE_PATH}\n")
1415
click.echo("Login successful")
1516
else:
1617
click.echo("Login with CentML authentication token")
1718
click.echo("Usage: centml login TOKEN_FILE\n")
1819
choice = click.confirm("Do you want to download the token?")
1920

2021
if choice:
21-
click.launch(f"{config.Config.centml_web_url}?isCliAuthenticated=true")
22+
click.launch(f"{settings.CENTML_WEB_URL}?isCliAuthenticated=true")
2223
else:
2324
click.echo("Login unsuccessful")
2425

centml/compiler/backend.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import requests
1212
import torch
1313
from torch.fx import GraphModule
14-
from centml.compiler.config import config_instance, CompilationStatus
14+
from centml.compiler.config import settings, CompilationStatus
1515
from centml.compiler.utils import get_backend_compiled_forward_path
1616

1717

@@ -54,13 +54,13 @@ def inputs(self):
5454

5555
def _serialize_model_and_inputs(self):
5656
self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with
57-
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_MODEL_FILE)
58-
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_INPUT_FILE)
57+
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_MODEL_FILE)
58+
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_INPUT_FILE)
5959

6060
# torch.save saves a zip file full of pickled files with the model's states.
6161
try:
62-
torch.save(self.module, self.serialized_model_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
63-
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
62+
torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.PICKLE_PROTOCOL)
63+
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.PICKLE_PROTOCOL)
6464
except Exception as e:
6565
raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e
6666

@@ -71,7 +71,7 @@ def _get_model_id(self) -> str:
7171
sha_hash = hashlib.sha256()
7272
with open(self.serialized_model_path, "rb") as serialized_model_file:
7373
# Read in chunks to not load too much into memory
74-
for block in iter(lambda: serialized_model_file.read(config_instance.HASH_CHUNK_SIZE), b""):
74+
for block in iter(lambda: serialized_model_file.read(settings.HASH_CHUNK_SIZE), b""):
7575
sha_hash.update(block)
7676

7777
model_id = sha_hash.hexdigest()
@@ -80,7 +80,7 @@ def _get_model_id(self) -> str:
8080

8181
def _download_model(self, model_id: str):
8282
download_response = requests.get(
83-
url=f"{config_instance.SERVER_URL}/download/{model_id}", timeout=config_instance.TIMEOUT
83+
url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.TIMEOUT
8484
)
8585
if download_response.status_code != HTTPStatus.OK:
8686
raise Exception(
@@ -104,9 +104,9 @@ def _compile_model(self, model_id: str):
104104

105105
with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file:
106106
compile_response = requests.post(
107-
url=f"{config_instance.SERVER_URL}/submit/{model_id}",
107+
url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}",
108108
files={"model": model_file, "inputs": input_file},
109-
timeout=config_instance.TIMEOUT,
109+
timeout=settings.TIMEOUT,
110110
)
111111

112112
if compile_response.status_code != HTTPStatus.OK:
@@ -118,9 +118,7 @@ def _wait_for_status(self, model_id: str) -> bool:
118118
tries = 0
119119
while True:
120120
# get server compilation status
121-
status_response = requests.get(
122-
f"{config_instance.SERVER_URL}/status/{model_id}", timeout=config_instance.TIMEOUT
123-
)
121+
status_response = requests.get(f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT)
124122
if status_response.status_code != HTTPStatus.OK:
125123
raise Exception(
126124
f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"
@@ -138,10 +136,10 @@ def _wait_for_status(self, model_id: str) -> bool:
138136
else:
139137
tries += 1
140138

141-
if tries > config_instance.MAX_RETRIES:
139+
if tries > settings.MAX_RETRIES:
142140
raise Exception("Waiting for status: compilation failed too many times.\n")
143141

144-
time.sleep(config_instance.COMPILING_SLEEP_TIME)
142+
time.sleep(settings.COMPILING_SLEEP_TIME)
145143

146144
def remote_compilation(self):
147145
self._serialize_model_and_inputs()

centml/compiler/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from enum import Enum
3+
from pydantic_settings import BaseSettings
34

45

56
class CompilationStatus(Enum):
@@ -8,20 +9,19 @@ class CompilationStatus(Enum):
89
DONE = "DONE"
910

1011

11-
class Config:
12+
class Config(BaseSettings):
1213
TIMEOUT: int = 10
1314
MAX_RETRIES: int = 3
1415
COMPILING_SLEEP_TIME: int = 15
1516

16-
CACHE_PATH: str = os.getenv("CENTML_CACHE_DIR", default=os.path.expanduser("~/.cache/centml"))
17+
CENTML_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
18+
BACKEND_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "backend")
19+
SERVER_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "server")
1720

18-
SERVER_URL: str = os.getenv("CENTML_SERVER_URL", default="http://0.0.0.0:8090")
19-
20-
BACKEND_BASE_PATH: str = os.path.join(CACHE_PATH, "backend")
21-
SERVER_BASE_PATH: str = os.path.join(CACHE_PATH, "server")
21+
CENTML_SERVER_URL: str = "http://0.0.0.0:8090"
2222

2323
# Use a constant path since torch.save uses the given file name in it's zipfile.
24-
# Thus, a different filename would result in a different hash.
24+
# Using a different filename would result in a different hash.
2525
SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
2626
SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
2727
PICKLE_PROTOCOL: int = 4
@@ -32,4 +32,4 @@ class Config:
3232
MINIMUM_GZIP_SIZE: int = 1000
3333

3434

35-
config_instance = Config()
35+
settings = Config()

centml/compiler/server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
from fastapi.middleware.gzip import GZipMiddleware
1111
from centml.compiler.server_compilation import hidet_backend_server
1212
from centml.compiler.utils import dir_cleanup
13-
from centml.compiler.config import config_instance, CompilationStatus
13+
from centml.compiler.config import settings, CompilationStatus
1414
from centml.compiler.utils import get_server_compiled_forward_path
1515

1616
app = FastAPI()
17-
app.add_middleware(GZipMiddleware, minimum_size=config_instance.MINIMUM_GZIP_SIZE) # type: ignore
17+
app.add_middleware(GZipMiddleware, minimum_size=settings.MINIMUM_GZIP_SIZE) # type: ignore
1818

1919

2020
def get_status(model_id: str):
21-
if not os.path.isdir(os.path.join(config_instance.SERVER_BASE_PATH, model_id)):
21+
if not os.path.isdir(os.path.join(settings.SERVER_BASE_PATH, model_id)):
2222
return CompilationStatus.NOT_FOUND
2323

2424
if not os.path.isfile(get_server_compiled_forward_path(model_id)):
@@ -50,7 +50,7 @@ def background_compile(model_id: str, tfx_graph, example_inputs):
5050
# To avoid this, we write to a tmp file and rename it to the correct path after saving.
5151
save_path = get_server_compiled_forward_path(model_id)
5252
tmp_path = save_path + ".tmp"
53-
torch.save(compiled_graph_module, tmp_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
53+
torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.PICKLE_PROTOCOL)
5454
os.rename(tmp_path, save_path)
5555
except Exception as e:
5656
logging.getLogger(__name__).exception(f"Saving graph module failed: {e}")
@@ -93,7 +93,7 @@ async def compile_model_handler(model_id: str, model: UploadFile, inputs: Upload
9393
return Response(status_code=200)
9494

9595
# This effectively sets the model's status to COMPILING
96-
os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id))
96+
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id))
9797

9898
tfx_graph, example_inputs = read_upload_files(model_id, model, inputs)
9999

@@ -110,7 +110,7 @@ async def download_handler(model_id: str):
110110

111111

112112
def run():
113-
parsed = urlparse(config_instance.SERVER_URL)
113+
parsed = urlparse(settings.CENTML_SERVER_URL)
114114
uvicorn.run(app, host=parsed.hostname, port=parsed.port)
115115

116116

centml/compiler/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
import os
22
import shutil
3-
from centml.compiler.config import config_instance
3+
from centml.compiler.config import settings
44

55

66
def get_backend_compiled_forward_path(model_id: str):
7-
os.makedirs(os.path.join(config_instance.BACKEND_BASE_PATH, model_id), exist_ok=True)
8-
return os.path.join(config_instance.BACKEND_BASE_PATH, model_id, "compilation_return.pkl")
7+
os.makedirs(os.path.join(settings.BACKEND_BASE_PATH, model_id), exist_ok=True)
8+
return os.path.join(settings.BACKEND_BASE_PATH, model_id, "compilation_return.pkl")
99

1010

1111
def get_server_compiled_forward_path(model_id: str):
12-
os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id), exist_ok=True)
13-
return os.path.join(config_instance.SERVER_BASE_PATH, model_id, "compilation_return.pkl")
12+
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id), exist_ok=True)
13+
return os.path.join(settings.SERVER_BASE_PATH, model_id, "compilation_return.pkl")
1414

1515

1616
# This function will delete the storage_path/{model_id} directory
1717
def dir_cleanup(model_id: str):
18-
dir_path = os.path.join(config_instance.SERVER_BASE_PATH, model_id)
18+
dir_path = os.path.join(settings.SERVER_BASE_PATH, model_id)
1919
if not os.path.exists(dir_path):
2020
return # Directory does not exist, return
2121

centml/sdk/api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from platform_api_client.models.deployment_status import DeploymentStatus
44

55
from centml.sdk import auth
6-
from centml.sdk.config import Config
6+
from centml.sdk.config import settings
77
from centml.sdk.utils import client_certs
88

99

1010
@contextlib.contextmanager
1111
def get_api():
12-
configuration = platform_api_client.Configuration(host=Config.platformapi_url, access_token=auth.get_centml_token())
12+
configuration = platform_api_client.Configuration(
13+
host=settings.PLATFORM_API_URL, access_token=auth.get_centml_token()
14+
)
1315

1416
with platform_api_client.ApiClient(configuration) as api_client:
1517
api_instance = platform_api_client.EXTERNALApi(api_client)

centml/sdk/auth.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import requests
66
import jwt
77

8-
from centml.sdk.config import Config
8+
from centml.sdk.config import settings
99

1010

1111
def refresh_centml_token(refresh_token):
12-
api_key = Config.firebase_api_key
12+
api_key = settings.FIREBASE_API_KEY
1313

1414
cred = requests.post(
1515
f"https://securetoken.googleapis.com/v1/token?key={api_key}",
@@ -18,7 +18,7 @@ def refresh_centml_token(refresh_token):
1818
timeout=3,
1919
).json()
2020

21-
with open(Config.centml_cred_file, 'w') as f:
21+
with open(settings.CENTML_CRED_FILE_PATH, 'w') as f:
2222
json.dump(cred, f)
2323

2424
return cred
@@ -27,7 +27,7 @@ def refresh_centml_token(refresh_token):
2727
def store_centml_cred(token_file):
2828
try:
2929
with open(token_file, 'r') as f:
30-
os.makedirs(Config.centml_config_dir, exist_ok=True)
30+
os.makedirs(settings.CENTML_CONFIG_PATH, exist_ok=True)
3131
refresh_token = json.load(f)["refreshToken"]
3232

3333
refresh_centml_token(refresh_token)
@@ -38,8 +38,8 @@ def store_centml_cred(token_file):
3838
def load_centml_cred():
3939
cred = None
4040

41-
if os.path.exists(Config.centml_cred_file):
42-
with open(Config.centml_cred_file, 'r') as f:
41+
if os.path.exists(settings.CENTML_CRED_FILE_PATH):
42+
with open(settings.CENTML_CRED_FILE_PATH, 'r') as f:
4343
cred = json.load(f)
4444

4545
return cred
@@ -60,5 +60,5 @@ def get_centml_token():
6060

6161

6262
def remove_centml_cred():
63-
if os.path.exists(Config.centml_cred_file):
64-
os.remove(Config.centml_cred_file)
63+
if os.path.exists(settings.CENTML_CRED_FILE_PATH):
64+
os.remove(settings.CENTML_CRED_FILE_PATH)

centml/sdk/config.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import os
2+
from pydantic_settings import BaseSettings
23

34

4-
class Config:
5-
centml_web_url = "https://main.d1tz9z8hgabab9.amplifyapp.com/"
6-
centml_config_dir = os.getenv("CENTML_CONFIG_PATH", default=os.path.expanduser("~/.centml"))
7-
centml_cred_file = centml_config_dir + "/" + os.getenv("CENTML_CRED_FILE", default="credential")
5+
class Config(BaseSettings):
6+
CENTML_WEB_URL: str = "https://main.d1tz9z8hgabab9.amplifyapp.com/"
7+
CENTML_CONFIG_PATH: str = os.path.expanduser("~/.centml")
8+
CENTML_CRED_FILE: str = "credential"
9+
CENTML_CRED_FILE_PATH: str = CENTML_CONFIG_PATH + "/" + CENTML_CRED_FILE
810

9-
platformapi_url = "https://api.centml.org"
11+
PLATFORM_API_URL: str = "https://api.centml.org"
1012

11-
firebase_api_key = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI"
13+
FIREBASE_API_KEY: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI"
14+
15+
16+
settings = Config()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ torch>=2.1.0
22
fastapi>=0.103.0
33
uvicorn>=0.23.0
44
python-multipart>=0.0.6
5+
pydantic-settings==2.0.*
56
Requests==2.32.2
67
tabulate>=0.9.0
78
pyjwt>=2.8.0

tests/test_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from parameterized import parameterized_class
77
from torch.fx import GraphModule
88
from centml.compiler.backend import Runner
9-
from centml.compiler.config import CompilationStatus, config_instance
9+
from centml.compiler.config import CompilationStatus, settings
1010
from .test_helpers import MODEL_SUITE
1111

1212

@@ -132,7 +132,7 @@ def test_invalid_status(self, mock_requests):
132132
mock_requests.get.assert_called_once()
133133
self.assertIn("Status check: request failed, exception from server", str(context.exception))
134134

135-
@patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0)
135+
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
136136
@patch("centml.compiler.backend.Runner._compile_model")
137137
@patch("centml.compiler.backend.requests")
138138
def test_max_tries(self, mock_requests, mock_compile):
@@ -145,10 +145,10 @@ def test_max_tries(self, mock_requests, mock_compile):
145145
with self.assertRaises(Exception) as context:
146146
self.runner._wait_for_status(model_id)
147147

148-
self.assertEqual(mock_compile.call_count, config_instance.MAX_RETRIES + 1)
148+
self.assertEqual(mock_compile.call_count, settings.MAX_RETRIES + 1)
149149
self.assertIn("Waiting for status: compilation failed too many times", str(context.exception))
150150

151-
@patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0)
151+
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
152152
@patch("centml.compiler.backend.requests")
153153
def test_wait_on_compilation(self, mock_requests):
154154
COMPILATION_STEPS = 10

0 commit comments

Comments
 (0)