diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 4c54f961..7b02d490 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -27,7 +27,7 @@ jobs: uses: actions/cache@v3 with: path: /home/runner/work/syncode/syncode/cache/mask_stores/ - key: files-${{ hashFiles('syncode/parsers/grammars/python_grammar.lark', 'syncode/dfa_mask_store.py') }} + key: files-${{ hashFiles('syncode/parsers/grammars/python.lark', 'syncode/dfa_mask_store.py') }} - name: Run Tests run: | python3 -m unittest tests.test_misc diff --git a/README.md b/README.md index e6a480cd..4fa8dd97 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/tra | SynCode version | Required transformers version | Python version | | -------------- | ----------------------------- | -------------- | -| `v0.4.10` (latest) | `v4.44.0` | 3.6 - 3.12 | +| `v0.4.11` (latest) | `v4.51.0` | 3.6 - 3.12 | **Note:** Python 3.13 is not currently supported due to dependency constraints. diff --git a/pyproject.toml b/pyproject.toml index 2dfadc7f..89b68576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "syncode" -version="0.4.10" +version="0.4.11" requires-python = ">=3.6,<3.13" description = "Grammar-guided code generation tool" readme = "README.md" @@ -24,7 +24,7 @@ dependencies = [ "regex==2023.8.8", "torch", "tqdm", - "transformers==4.44.0", + "transformers==4.51.0", "datasets", "jsonschema", ] diff --git a/requirements.txt b/requirements.txt index 6a98aea6..aca1dee8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ +accelerate fire interegular regex==2023.8.8 torch tqdm -transformers==4.44.0; python_version < "3.13" +transformers==4.51.0; python_version < "3.13" datasets jsonschema diff --git a/setup.py b/setup.py index fcc423f5..9a78553d 100644 --- a/setup.py +++ b/setup.py @@ -11,14 +11,14 @@ "regex==2023.8.8", "torch", "tqdm", - "transformers==4.44.0", + "transformers==4.51.0", "datasets", "jsonschema" ] setuptools.setup( name="syncode", - version="0.4.10", + version="0.4.11", author="Shubham Ugare", author_email="shubhamugare@gmail.com", description="This package provides the tool for grammar augmented LLM generation.", diff --git a/syncode/common.py b/syncode/common.py index cbf06662..e480265b 100644 --- a/syncode/common.py +++ b/syncode/common.py @@ -12,21 +12,32 @@ def load_model(model_name, device, quantize, device_map = None): + torch_dtype = torch.bfloat16 if quantize else "auto" + device_map = device_map if device_map is not None else "auto" + + attn_implementation = None + if "gemma-3" in model_name: + # This is due to the gemma-3 issue with SDPA implementation + # https://github.com/google-deepmind/gemma/issues/169 + attn_implementation = "eager" + logging.info("Using slower \"eager\" attention implementation for gemma-3 due to issue with SDPA implementation") + if model_name == 'test': model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device) elif model_name == 'test-instruct': model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct") else: if device_map is not None: - if (quantize): - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval() - else: - model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, 
token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval() - else: - if (quantize): - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device) - else: - model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device) + logging.info(f"Loading model {model_name} with device:{device}, device_map:{device_map}, torch_dtype:{torch_dtype}") + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + cache_dir=HF_CACHE, + token=HF_ACCESS_TOKEN, + trust_remote_code=True, + device_map = device_map, + attn_implementation=attn_implementation + ).eval() return model def load_tokenizer(model_name): @@ -35,7 +46,12 @@ def load_tokenizer(model_name): elif model_name == 'test-instruct': tokenizer = AutoTokenizer.from_pretrained("rahuldshetty/tiny-starcoder-instruct") else: - tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_name, + cache_dir=HF_CACHE, + token=HF_ACCESS_TOKEN, + trust_remote_code=True + ) return tokenizer def get_output_path(model_name, grammar, dataset, num_samples, mode): diff --git a/syncode/evaluation/json_eval.py b/syncode/evaluation/json_eval.py index bf540379..bd73e10e 100644 --- a/syncode/evaluation/json_eval.py +++ b/syncode/evaluation/json_eval.py @@ -72,7 +72,10 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas else: problem["prompt"][0]['content'] = f"{problem['prompt'][0]['content']}\nOnly output JSON.\nJSON:\n" - prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False) + if syncode.model.tokenizer.chat_template is not None: + prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False) + else: + prompt = problem["prompt"][0]['content'] batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task) for completion_id, completion in enumerate(batch_completions): diff --git a/syncode/language_model.py b/syncode/language_model.py index 2db3c11d..e1cac7ff 100644 --- a/syncode/language_model.py +++ b/syncode/language_model.py @@ -1,15 +1,19 @@ -from ast import Tuple +from ast import Dict, Tuple import time import torch import syncode.common as common from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel from syncode.parsers.grammars import Grammar -from syncode.utils.generation import filter_code, fix_indents -from typing import Callable, Iterable, Union +from typing import Any, Callable, Iterable, Union from transformers.generation.utils import GenerationMode from transformers.generation.configuration_utils import GenerationConfig - +from transformers.generation.logits_process import ( + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.cache_utils import Cache class KeywordsStoppingCriteria(StoppingCriteria): ''' @@ -172,8 +176,10 @@ def get_tokenized_input(self, prompt: Union[str, list], batch_size: int): raise ValueError("Prompt should be either a string or a list! 
It is currently of type: "+str(type(prompt))) input_batch = [prompt_str for _ in range(batch_size)] - inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device) - + inputs = self.tokenizer( + input_batch, + return_tensors="pt", + ).to(self.model.device) return inputs @torch.inference_mode() @@ -189,34 +195,50 @@ def _generate( """ We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library. """ - token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None - + + # Get the input ids and attention mask + token_ids = inputs['input_ids'] + model_kwargs = {} + model_kwargs['attention_mask'] = inputs['attention_mask'] + model_kwargs['use_cache'] = True + model_kwargs = self._get_initial_cache_position(token_ids, model_kwargs) + # This does not include grammar decoder - self.model._prepare_special_tokens(gen_config, False, device=self.device) + self.model._prepare_special_tokens(gen_config, True, device=self.device) # Add logits processor for generation parameters such as top_k, top_p, temperature, etc. - logits_processor = self.model._get_logits_warper(gen_config, self.device) + logits_processor = self._get_logits_processors(gen_config) max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1) - self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + self.model.config.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + + # Prepare the cache. (This is copied from the transformers generation_utils.py) + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `gen_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = max_tokens-1 + self.model._prepare_cache_for_generation( + gen_config, + model_kwargs, + assistant_model=None, + batch_size=token_ids.shape[0], + max_cache_length=max_cache_length, + device=self.device + ) while True: + model_inputs = self.model.prepare_inputs_for_generation(token_ids, **model_kwargs) try: - if past_key_values: # Get the last token if kv is cached for all previous tokens - input_ids = token_ids[..., -1].unsqueeze(-1) - else: - input_ids = token_ids - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values - ) + outputs = self.model(**model_inputs, return_dict=True) except IndexError as e: raise ValueError(f"The input length exceeds the context length of the model. 
{e}") - next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values + model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) + # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_scores = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=token_ids.device) + if grammar_decoder is not None: next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores) is_valid = grammar_decoder.is_valid(token_ids, next_token) @@ -240,12 +262,6 @@ def _generate( if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens: break - # Update attention mask - attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1) - - if debug: - grammar_decoder.print_debug() - return token_ids def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scores): @@ -258,20 +274,20 @@ def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scor return next_token def _get_generation_mode( - self, generation_config: GenerationConfig + self, gen_config: GenerationConfig ) -> GenerationMode: """ Returns the generation mode triggered by a [`GenerationConfig`] instance. """ - if generation_config.constraints is not None or generation_config.force_words_ids is not None: + if gen_config.constraints is not None or gen_config.force_words_ids is not None: generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH - elif generation_config.num_beams == 1: - if generation_config.do_sample is False: + elif gen_config.num_beams == 1: + if gen_config.do_sample is False: if ( - generation_config.top_k is not None - and generation_config.top_k > 1 - and generation_config.penalty_alpha is not None - and generation_config.penalty_alpha > 0 + gen_config.top_k is not None + and gen_config.top_k > 1 + and gen_config.penalty_alpha is not None + and gen_config.penalty_alpha > 0 ): generation_mode = GenerationMode.CONTRASTIVE_SEARCH else: @@ -279,9 +295,9 @@ def _get_generation_mode( else: generation_mode = GenerationMode.SAMPLE else: - if generation_config.num_beam_groups > 1: + if gen_config.num_beam_groups > 1: generation_mode = GenerationMode.GROUP_BEAM_SEARCH - elif generation_config.do_sample is True: + elif gen_config.do_sample is True: generation_mode = GenerationMode.BEAM_SAMPLE else: generation_mode = GenerationMode.BEAM_SEARCH @@ -289,3 +305,109 @@ def _get_generation_mode( def tokenize(self, s: str) -> 'Iterable[int]': return self.tokenizer.encode(s, add_special_tokens=False) + + def _get_logits_processors(self, gen_config: GenerationConfig) -> LogitsProcessorList: + """ + Returns a [`~transformers.generation.LogitsProcessorList`] with the appropriate [`LogitsProcessor`]s to use for + generation. + """ + processors = LogitsProcessorList() + if gen_config.do_sample: + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. 
keep len(list(gen_config._eos_token_tensor)) + 1) + if gen_config.num_beams > 1: + if isinstance(gen_config._eos_token_tensor, list): + min_tokens_to_keep = len(gen_config._eos_token_tensor) + 1 + elif isinstance(gen_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = gen_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if gen_config.temperature is not None and gen_config.temperature != 1.0: + processors.append(TemperatureLogitsWarper(gen_config.temperature)) + if gen_config.top_k is not None and gen_config.top_k != 0: + processors.append( + TopKLogitsWarper(top_k=gen_config.top_k, min_tokens_to_keep=min_tokens_to_keep) + ) + if gen_config.top_p is not None and gen_config.top_p < 1.0: + processors.append( + TopPLogitsWarper(top_p=gen_config.top_p, min_tokens_to_keep=min_tokens_to_keep) + ) + return processors + + def _update_model_kwargs_for_generation( + self, + outputs, + model_kwargs: dict[str, Any], + ) -> dict[str, Any]: + # Variable names used to hold the cache at generation time + ALL_CACHE_NAMES = [ + "past_key_values", # default + "cache_params", # mamba-based models + "state", # rwkv + "mems", # xlnet + "past_buckets_states", # reformer + ] + + # update past_key_values keeping its naming used in model code + for possible_cache_name in ALL_CACHE_NAMES: + if possible_cache_name in outputs: + if possible_cache_name in ("past_buckets_states", "mems"): + cache_name = "past_key_values" + else: + cache_name = possible_cache_name + model_kwargs[cache_name] = getattr(outputs, possible_cache_name) + break + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # assuming is_encoder_decoder = False + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + if model_kwargs.get("use_cache", True): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 # num_new_tokens = 1 + else: + past_positions = model_kwargs.pop("cache_position") + new_positions = torch.arange( + past_positions[-1] + 1, past_positions[-1] + 2, dtype=past_positions.dtype + ).to(past_positions.device) + model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) + + return model_kwargs + + def _get_initial_cache_position(self, input_ids, model_kwargs): + """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) + else: + cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + + past_length = 0 + if 
model_kwargs.get("past_key_values") is not None: + cache = model_kwargs["past_key_values"] + past_length = 0 + if not isinstance(cache, Cache): + past_length = cache[0][0].shape[2] + elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: + past_length = cache.get_seq_length() + + cache_position = cache_position[past_length:] + + model_kwargs["cache_position"] = cache_position + return model_kwargs \ No newline at end of file diff --git a/syncode/parsers/grammars/c_grammar.lark b/syncode/parsers/grammars/c.lark similarity index 100% rename from syncode/parsers/grammars/c_grammar.lark rename to syncode/parsers/grammars/c.lark diff --git a/syncode/parsers/grammars/calc_grammar.lark b/syncode/parsers/grammars/calc.lark similarity index 100% rename from syncode/parsers/grammars/calc_grammar.lark rename to syncode/parsers/grammars/calc.lark diff --git a/syncode/parsers/grammars/go_grammar.lark b/syncode/parsers/grammars/go.lark similarity index 100% rename from syncode/parsers/grammars/go_grammar.lark rename to syncode/parsers/grammars/go.lark diff --git a/syncode/parsers/grammars/grammar.py b/syncode/parsers/grammars/grammar.py index 4fdb9e9b..76d13fce 100644 --- a/syncode/parsers/grammars/grammar.py +++ b/syncode/parsers/grammars/grammar.py @@ -13,8 +13,8 @@ def __init__(self, name): self.ebnf = None grammar_filename = None assert name is not None, 'Grammar name not provided in grammar mode!' - if name in ['python', 'go', 'sql', 'tiny', 'calc', 'json', 'c', 'java', 'prover9']: - grammar_filename = f'{os.path.dirname(__file__)}/{name}_grammar.lark' + if name in ['python', 'go', 'sql', 'tiny', 'calc', 'json', 'c', 'java', 'prover9', 'invariants']: + grammar_filename = f'{os.path.dirname(__file__)}/{name}.lark' elif name.endswith('.lark'): if os.path.exists(name): # In this case we assume that the user provides the full path to the grammar file diff --git a/syncode/parsers/grammars/java_grammar.lark b/syncode/parsers/grammars/java.lark similarity index 100% rename from syncode/parsers/grammars/java_grammar.lark rename to syncode/parsers/grammars/java.lark diff --git a/syncode/parsers/grammars/json_grammar.lark b/syncode/parsers/grammars/json.lark similarity index 100% rename from syncode/parsers/grammars/json_grammar.lark rename to syncode/parsers/grammars/json.lark diff --git a/syncode/parsers/grammars/prover9_grammar.lark b/syncode/parsers/grammars/prover9.lark similarity index 100% rename from syncode/parsers/grammars/prover9_grammar.lark rename to syncode/parsers/grammars/prover9.lark diff --git a/syncode/parsers/grammars/python_grammar.lark b/syncode/parsers/grammars/python.lark similarity index 100% rename from syncode/parsers/grammars/python_grammar.lark rename to syncode/parsers/grammars/python.lark diff --git a/syncode/parsers/grammars/sql_grammar.lark b/syncode/parsers/grammars/sql.lark similarity index 100% rename from syncode/parsers/grammars/sql_grammar.lark rename to syncode/parsers/grammars/sql.lark diff --git a/syncode/parsers/grammars/tiny_grammar.lark b/syncode/parsers/grammars/tiny.lark similarity index 100% rename from syncode/parsers/grammars/tiny_grammar.lark rename to syncode/parsers/grammars/tiny.lark diff --git a/tests/test_language_model.py b/tests/test_language_model.py index 9cb4ba3e..266a7bbd 100644 --- a/tests/test_language_model.py +++ b/tests/test_language_model.py @@ -56,40 +56,17 @@ def get_vocab(self) -> Dict[str, int]: return {v: i for i, v in enumerate(self.vocab)} class TestHuggingFaceModel(unittest.TestCase): - def 
test_generate_grammar_constrained_completion(self): - torch.manual_seed(0) - model = TestModel() - tokenizer = TestTokenizer() - logger = common.EmptyLogger() - lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu') - prompt = "113 + 235 + 17" - output = lm.generate_grammar_constrained_completion(prompt, 1) - self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.") - - def test_generate_grammar_constrained_completion2(self): - torch.manual_seed(0) - model = TestModel() - tokenizer = TestTokenizer() - logger = common.EmptyLogger() - lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu') - prompt = "113 + 235 + 17" - output = lm.generate_grammar_constrained_completion(prompt, 2) - self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.") - self.assertEqual(len(output[1]), 15, "The output length does not match the expected value.") - - @unittest.skip("Only for local testing") def test_stop_word(self): torch.manual_seed(0) - syncode = Syncode(model="microsoft/phi-2", mode='original') + syncode = Syncode(model="microsoft/phi-2", mode='original', device='cpu') prompt = "Generate a json for the country nigeria.\n```json\n" stop_words = ["```"] output = syncode.infer(prompt, stop_words=stop_words)[0] assert output.endswith('```') - @unittest.skip("Only for local testing") def test_stop_word2(self): torch.manual_seed(0) - syncode = Syncode(model="microsoft/phi-2", mode='original') + syncode = Syncode(model="microsoft/phi-2", mode='original', device='cpu') prompt = "def add(a, b):\n" stop_words = ["\n\n"] output = syncode.infer(prompt, stop_words=stop_words)[0] diff --git a/tests/test_lr_parser.py b/tests/test_lr_parser.py index 3b88d32b..ecdd0cf0 100644 --- a/tests/test_lr_parser.py +++ b/tests/test_lr_parser.py @@ -10,8 +10,8 @@ class ParserTests(unittest.TestCase): @classmethod def setUpClass(cls): # Common setup that runs once before all tests - cls.tiny_grammar = Grammar('syncode/parsers/grammars/tiny_grammar.lark') - cls.calc_grammar = Grammar('syncode/parsers/grammars/calc_grammar.lark') + cls.tiny_grammar = Grammar('tiny') + cls.calc_grammar = Grammar('calc') cls.python_grammar = Grammar('python') def test_tiny(self):
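
A minimal usage sketch (illustrative, not part of the diff) tying the changes together: built-in grammars are now referenced by short name after the `*_grammar.lark` -> `*.lark` renames, stop-word generation is exercised exactly as in tests/test_language_model.py, and `load_model` now derives `torch_dtype`/`device_map` defaults with an "eager"-attention fallback for gemma-3. The `from syncode import Syncode` import and the `google/gemma-3-1b-it` model id are assumptions not confirmed by this diff; the remaining calls mirror code shown above.

    from syncode import Syncode                       # assumed top-level import
    from syncode.parsers.grammars import Grammar      # import path as used in syncode/language_model.py
    from syncode.common import load_model, load_tokenizer

    # Built-in grammars are loaded by short name, resolved to <name>.lark after the rename
    # (see the updated tests/test_lr_parser.py setup).
    calc_grammar = Grammar('calc')

    # Unconstrained generation with stop words, mirroring test_stop_word2 in tests/test_language_model.py.
    syncode = Syncode(model="microsoft/phi-2", mode='original', device='cpu')
    output = syncode.infer("def add(a, b):\n", stop_words=["\n\n"])[0]

    # load_model now picks torch_dtype/device_map defaults and falls back to "eager"
    # attention for gemma-3 models; the model id here is only a placeholder example.
    model = load_model("google/gemma-3-1b-it", device='cuda', quantize=True)
    tokenizer = load_tokenizer("google/gemma-3-1b-it")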