1111import requests
1212import torch
1313from torch .fx import GraphModule
14- from centml .compiler .config import config_instance , CompilationStatus
14+ from centml .compiler .config import settings , CompilationStatus
1515from centml .compiler .utils import get_backend_compiled_forward_path
1616
1717
@@ -54,13 +54,13 @@ def inputs(self):
5454
5555 def _serialize_model_and_inputs (self ):
5656 self .serialized_model_dir = TemporaryDirectory () # pylint: disable=consider-using-with
57- self .serialized_model_path = os .path .join (self .serialized_model_dir .name , config_instance .SERIALIZED_MODEL_FILE )
58- self .serialized_input_path = os .path .join (self .serialized_model_dir .name , config_instance .SERIALIZED_INPUT_FILE )
57+ self .serialized_model_path = os .path .join (self .serialized_model_dir .name , settings .SERIALIZED_MODEL_FILE )
58+ self .serialized_input_path = os .path .join (self .serialized_model_dir .name , settings .SERIALIZED_INPUT_FILE )
5959
6060 # torch.save saves a zip file full of pickled files with the model's states.
6161 try :
62- torch .save (self .module , self .serialized_model_path , pickle_protocol = config_instance .PICKLE_PROTOCOL )
63- torch .save (self .inputs , self .serialized_input_path , pickle_protocol = config_instance .PICKLE_PROTOCOL )
62+ torch .save (self .module , self .serialized_model_path , pickle_protocol = settings .PICKLE_PROTOCOL )
63+ torch .save (self .inputs , self .serialized_input_path , pickle_protocol = settings .PICKLE_PROTOCOL )
6464 except Exception as e :
6565 raise Exception (f"Failed to save module or inputs with torch.save: { e } " ) from e
6666
@@ -71,7 +71,7 @@ def _get_model_id(self) -> str:
7171 sha_hash = hashlib .sha256 ()
7272 with open (self .serialized_model_path , "rb" ) as serialized_model_file :
7373 # Read in chunks to not load too much into memory
74- for block in iter (lambda : serialized_model_file .read (config_instance .HASH_CHUNK_SIZE ), b"" ):
74+ for block in iter (lambda : serialized_model_file .read (settings .HASH_CHUNK_SIZE ), b"" ):
7575 sha_hash .update (block )
7676
7777 model_id = sha_hash .hexdigest ()
@@ -80,7 +80,7 @@ def _get_model_id(self) -> str:
8080
8181 def _download_model (self , model_id : str ):
8282 download_response = requests .get (
83- url = f"{ config_instance . SERVER_URL } /download/{ model_id } " , timeout = config_instance .TIMEOUT
83+ url = f"{ settings . CENTML_SERVER_URL } /download/{ model_id } " , timeout = settings .TIMEOUT
8484 )
8585 if download_response .status_code != HTTPStatus .OK :
8686 raise Exception (
@@ -104,9 +104,9 @@ def _compile_model(self, model_id: str):
104104
105105 with open (self .serialized_model_path , 'rb' ) as model_file , open (self .serialized_input_path , 'rb' ) as input_file :
106106 compile_response = requests .post (
107- url = f"{ config_instance . SERVER_URL } /submit/{ model_id } " ,
107+ url = f"{ settings . CENTML_SERVER_URL } /submit/{ model_id } " ,
108108 files = {"model" : model_file , "inputs" : input_file },
109- timeout = config_instance .TIMEOUT ,
109+ timeout = settings .TIMEOUT ,
110110 )
111111
112112 if compile_response .status_code != HTTPStatus .OK :
@@ -118,9 +118,7 @@ def _wait_for_status(self, model_id: str) -> bool:
118118 tries = 0
119119 while True :
120120 # get server compilation status
121- status_response = requests .get (
122- f"{ config_instance .SERVER_URL } /status/{ model_id } " , timeout = config_instance .TIMEOUT
123- )
121+ status_response = requests .get (f"{ settings .CENTML_SERVER_URL } /status/{ model_id } " , timeout = settings .TIMEOUT )
124122 if status_response .status_code != HTTPStatus .OK :
125123 raise Exception (
126124 f"Status check: request failed, exception from server:\n { status_response .json ().get ('detail' )} "
@@ -138,10 +136,10 @@ def _wait_for_status(self, model_id: str) -> bool:
138136 else :
139137 tries += 1
140138
141- if tries > config_instance .MAX_RETRIES :
139+ if tries > settings .MAX_RETRIES :
142140 raise Exception ("Waiting for status: compilation failed too many times.\n " )
143141
144- time .sleep (config_instance .COMPILING_SLEEP_TIME )
142+ time .sleep (settings .COMPILING_SLEEP_TIME )
145143
146144 def remote_compilation (self ):
147145 self ._serialize_model_and_inputs ()
0 commit comments