diff --git a/README.md b/README.md
index d6b8c232..44cc50d6 100644
--- a/README.md
+++ b/README.md
@@ -207,6 +207,7 @@ export HF_ACCESS_TOKEN="your_huggingface_api_key"
 - `task_id` (int, optional): Problem task id for selecting a problem from a Dataset.
+- `device_map` (str, optional): Device map for the model. Defaults to None.
 - `kwargs`(void, optional): Currently supported `kwargs` are `max_length`, `max_new_tokens`, `min_length`, `min_new_tokens`, `early_stopping`, `do_sample`, `num_beams`, `use_cache`, `temperature`, `top_k`, `top_p`, `num_return_sequences`, `pad_token_id`, and `eos_token_id`. Refer to the [HuggingFace Text Generation Documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation) for more information.
@@ -237,6 +238,7 @@
 python3 syncode/infer.py
     --new_mask_store [True, False]
     --parser ["lr", "lalr"]
     --task_id [task_id]
+    --device_map [device_map]
 ```

diff --git a/syncode/common.py b/syncode/common.py
index baf5370f..cbf06662 100644
--- a/syncode/common.py
+++ b/syncode/common.py
@@ -11,16 +11,22 @@
 HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None
 
-def load_model(model_name, device, quantize):
+def load_model(model_name, device, quantize, device_map = None):
     if model_name == 'test':
         model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
     elif model_name == 'test-instruct':
         model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
-        if (quantize):
-            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+        if device_map is not None:
+            if (quantize):
+                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
         else:
-            model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            if (quantize):
+                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
     return model
 
 def load_tokenizer(model_name):

diff --git a/syncode/infer.py b/syncode/infer.py
index 8bc68d0a..deaa0bac 100644
--- a/syncode/infer.py
+++ b/syncode/infer.py
@@ -50,6 +50,7 @@ def __init__(
         parser: Literal["lr", "lalr"] = "lalr",
         seed: Optional[int] = None,
         opp: bool = True,
+        device_map: Optional[str] = None,
         **kwargs
     ):
         # Check inputs
@@ -85,7 +86,7 @@ def __init__(
         self.grammar = Grammar(grammar) if self._is_grammar_mode() else None
 
         # Load model and tokenizer
-        model = common.load_model(self.model_name, device, quantize)
+        model = common.load_model(self.model_name, device, quantize, device_map)
         tokenizer = common.load_tokenizer(self.model_name)
 
         # Initialize grammar decoder if needed
@@ -259,6 +260,7 @@ def main(
     parse_output_only: bool = True,
     prompt_type: str = 'original',
     format_tabs: bool = False,
+    device_map: Optional[str] = None,
     **kwargs
 ):
     """Run Syncode with the specified configuration.
@@ -309,6 +311,7 @@ def main(
         seed=seed,
         opp=opp,
         parse_output_only=parse_output_only,
+        device_map=device_map,
         **kwargs
     )

diff --git a/syncode/language_model.py b/syncode/language_model.py
index 4b1607c2..2db3c11d 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -50,7 +50,7 @@ def __init__(
         self.prompt_template = prompt_template
         self.model: PreTrainedModel = model
         self.tokenizer = tokenizer
-        self.device = device
+        self.device = self.model.device
         self.best_of = best_of
         self._before_prediction_hook = before_prediction_hook
         self.logits_processor = grammar_decoder
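As a usage reference, here is a minimal sketch of how the new `device_map` argument might be exercised directly through `load_model`. It assumes `syncode.common` is importable as shown; the model name is illustrative, and passing `device_map="auto"` relies on `accelerate` being installed so that HuggingFace can place the weights automatically.

```python
from syncode import common  # assumes the syncode package is on the Python path

# Sketch only: when device_map is provided, the load_model change above skips
# the explicit .to(device) move and lets transformers/accelerate decide where
# the weights live (e.g. sharded across the available GPUs with "auto").
model = common.load_model(
    "bigcode/tiny_starcoder_py",  # illustrative model name
    device="cuda",                # not used for placement when device_map is set
    quantize=False,
    device_map="auto",
)
print(model.device)  # the language model wrapper now reads the device from the model itself
```

The option is exposed end to end: on the command line via the new `--device_map` flag, and as the `device_map` keyword argument of `Syncode.__init__` and `main`, both of which simply forward it to `common.load_model`.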