This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree: 3 files changed, +30 −2 lines changed. Original file line number / diff line number / diff line change @@ -21,6 +21,9 @@ def is_true(value):
2121MODEL_REVISION = os .getenv ("MODEL_REVISION" , "" )
2222MODEL_CACHE_DIR = os .getenv ("MODEL_CACHE_DIR" , "models" )
2323MODEL_LOAD_IN_8BIT = is_true (os .getenv ("MODEL_LOAD_IN_8BIT" , "" ))
24+ MODEL_LOAD_IN_4BIT = is_true (os .getenv ("MODEL_LOAD_IN_4BIT" , "" ))
25+ MODEL_4BIT_QUANT_TYPE = os .getenv ("MODEL_4BIT_QUANT_TYPE" , "fp4" )
26+ MODEL_4BIT_DOUBLE_QUANT = is_true (os .getenv ("MODEL_4BIT_DOUBLE_QUANT" , "" ))
2427MODEL_LOCAL_FILES_ONLY = is_true (os .getenv ("MODEL_LOCAL_FILES_ONLY" , "" ))
2528MODEL_TRUST_REMOTE_CODE = is_true (os .getenv ("MODEL_TRUST_REMOTE_CODE" , "" ))
2629MODEL_HALF_PRECISION = is_true (os .getenv ("MODEL_HALF_PRECISION" , "" ))
Original file line number Diff line number Diff line change 2020from . import MODEL_REVISION
2121from . import MODEL_CACHE_DIR
2222from . import MODEL_LOAD_IN_8BIT
23+ from . import MODEL_LOAD_IN_4BIT
24+ from . import MODEL_4BIT_QUANT_TYPE
25+ from . import MODEL_4BIT_DOUBLE_QUANT
2326from . import MODEL_LOCAL_FILES_ONLY
2427from . import MODEL_TRUST_REMOTE_CODE
2528from . import MODEL_HALF_PRECISION
4245 revision = MODEL_REVISION ,
4346 cache_dir = MODEL_CACHE_DIR ,
4447 load_in_8bit = MODEL_LOAD_IN_8BIT ,
48+ load_in_4bit = MODEL_LOAD_IN_4BIT ,
49+ quant_type = MODEL_4BIT_QUANT_TYPE ,
50+ double_quant = MODEL_4BIT_DOUBLE_QUANT ,
4551 local_files_only = MODEL_LOCAL_FILES_ONLY ,
4652 trust_remote_code = MODEL_TRUST_REMOTE_CODE ,
4753 half_precision = MODEL_HALF_PRECISION ,
Original file line number Diff line number Diff line change 1212 MinNewTokensLengthLogitsProcessor ,
1313 TemperatureLogitsWarper ,
1414 TopPLogitsWarper ,
15+ BitsAndBytesConfig
1516)
1617
1718from .choice import map_choice
@@ -302,6 +303,9 @@ def load_model(
302303 revision = None ,
303304 cache_dir = None ,
304305 load_in_8bit = False ,
306+ load_in_4bit = False ,
307+ quant_type = "fp4" ,
308+ double_quant = False ,
305309 local_files_only = False ,
306310 trust_remote_code = False ,
307311 half_precision = False ,
@@ -319,12 +323,27 @@ def load_model(
319323
320324 # Set device mapping and quantization options if CUDA is available.
321325 if torch .cuda .is_available ():
326+ # Set quantization options if specified.
327+ quant_config = None
328+ if load_in_8bit and load_in_4bit :
329+ raise ValueError ("Only one of load_in_8bit and load_in_4bit can be True" )
330+ if load_in_8bit :
331+ quant_config = BitsAndBytesConfig (
332+ load_in_8bit = True ,
333+ )
334+ elif load_in_4bit :
335+ quant_config = BitsAndBytesConfig (
336+ load_in_4bit = True ,
337+ bnb_4bit_quant_type = quant_type ,
338+ bnb_4bit_use_double_quant = double_quant ,
339+ bnb_4bit_compute_dtype = torch .bfloat16 ,
340+ )
322341 kwargs = kwargs .copy ()
323342 kwargs ["device_map" ] = "auto"
324- kwargs ["load_in_8bit " ] = load_in_8bit
343+ kwargs ["quantization_config " ] = quant_config
325344
326345 # Cast all parameters to float16 if quantization is enabled.
327- if half_precision or load_in_8bit :
346+ if half_precision or load_in_8bit or load_in_4bit :
328347 kwargs ["torch_dtype" ] = torch .float16
329348
330349 # Support both decoder-only and encoder-decoder models.
You can’t perform that action at this time.
0 commit comments