diff --git a/syncode/__init__.py b/syncode/__init__.py
index c2092ba5..bc68afee 100644
--- a/syncode/__init__.py
+++ b/syncode/__init__.py
@@ -1,5 +1,5 @@
 from syncode.infer import Syncode
-from syncode.grammar_decoder import SyncodeLogitsProcessor
+from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
 from syncode.parsers.grammars import Grammar
 import syncode.common as common
diff --git a/syncode/evaluation/code_eval.py b/syncode/evaluation/code_eval.py
index 44bd46fb..4a2ef91f 100644
--- a/syncode/evaluation/code_eval.py
+++ b/syncode/evaluation/code_eval.py
@@ -86,12 +86,13 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
             # We tokenize the whole thing together since tokenizer just the generated_ids messes up with the
             # indentation and removes the initial whitespaces in some cases
             raw_completion = syncode.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
+            grammar_constrainer = syncode.model.logits_processor.grammar_engine

             # Post-processing to filter out using stop word
             if syncode.model.grammar != None and syncode.model.grammar.name == "python":
-                completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, syncode.model.grammar_decoder, raw_completion, stop_words)
+                completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words)
             elif syncode.model.grammar != None and syncode.model.grammar.name == "go":
-                completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, syncode.model.grammar_decoder, input_ids_cutoff)
+                completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff)
             else: # TODO: handle the case for other grammars
                 completion = raw_completion
@@ -121,31 +122,31 @@ def write_results(syncode, out_path, avg_time, functional_result, num_tasks=1):
             f.write(f"Averge time taken for each task: {avg_time:.2f}s\n")
         f.write("\n")

-    def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_decoder, raw_completion, stop_words):
+    def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words):
         generated_output = hf_model.tokenizer.decode(generated_ids[input_ids_cutoff:])

-        if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_decoder is not None:
-            # Use when the stop word does not exist in the completion and grammar_decoder is used
+        if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_constrainer is not None:
+            # Use when the stop word does not exist in the completion and grammar_constrainer is used
             function_incomplete = [False for _ in range(batch_size)]
-            completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
+            completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)
         else:
             completion = raw_completion
         return completion

-    def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_decoder, input_ids_cutoff):
+    def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff):
         if hf_model.mode != "original":
-            # When the grammar_decoder is used
+            # When the grammar_constrainer is used
             function_incomplete = [False for _ in range(batch_size)]
-            completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
+            completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)

             if function_incomplete[i]:
                 completion += "}"

         return completion

-    def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
-        if grammar_decoder.function_ends[i] is not None:
-            fn_ends = sorted(list(set(grammar_decoder.function_ends[i])))
+    def compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion):
+        if grammar_constrainer.function_ends[i] is not None:
+            fn_ends = sorted(list(set(grammar_constrainer.function_ends[i])))
             if len(fn_ends) > 1:
                 # if the function end is not None, then the last valid state is the function end
                 last_valid_state = fn_ends[1]
@@ -153,7 +154,7 @@ def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i,
             # otherwise, the last valid state is the last valid state
             function_incomplete[i] = True
-            last_valid_state = grammar_decoder.last_valid_state[i]
+            last_valid_state = grammar_constrainer.last_valid_state[i]

         # Use when the stop word does not exist in the completion
         backup_completion = raw_completion[:last_valid_state]
diff --git a/syncode/grammar_decoder.py b/syncode/grammar_mask/grammar_constrainer.py
similarity index 74%
rename from syncode/grammar_decoder.py
rename to syncode/grammar_mask/grammar_constrainer.py
index a8bdaec2..93e58bc1 100644
--- a/syncode/grammar_decoder.py
+++ b/syncode/grammar_mask/grammar_constrainer.py
@@ -10,49 +10,65 @@
 import logging
 logger = logging.getLogger(__name__)
-
-# Set to True for debugging
-DEBUG = True
-
-class SyncodeLogitsProcessor(LogitsProcessor):
+class GrammarConstrainer:
     """
-    This class is used to filter the logits of the model to only allow syntactically valid tokens for Python.
-
-    Args:
-        grammar (str): The grammar to use for parsing e.g. "python".
-        tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
-        use_cache (bool, optional): Whether to use the cache. Defaults to True.
-        parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
-        num_samples (int, optional): The number of sequences to generate. Defaults to 1.
-        dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
-        parser (str, optional): The parser to use. Defaults to 'lalr'.
-        mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
+    Core class for constraining LLM token generation based on formal grammar rules.
+
+    This class handles the parsing of generated code, validates its grammatical correctness,
+    and creates token masks to ensure syntactically valid generations.
+
+    The class supports two primary operating modes:
+
+    1. `grammar_mask` (Conservative/Overapproximation):
+       This mode is more permissive and overapproximates the set of acceptable tokens.
+       It allows a wider range of tokens that might be syntactically valid given the
+       limited lookahead of the parser. This mode preserves more of the LLM's original
+       token distribution while still enforcing basic syntactic correctness.
+
+    2. `grammar_strict` (Strict/Underapproximation):
+       This mode is stricter and underapproximates the set of acceptable tokens.
+       It enforces tighter grammatical constraints and may be more invasive in the
+       LLM's generation process. It sometimes breaks LLM tokens that would have been
+       syntactically correct when considered as a whole, potentially affecting the
+       fluency or accuracy of generation.
+
+    Example illustrating the difference:
+    Consider generating Python code with the partial input: `def calculate`
+
+    In `grammar_mask` mode, it might allow tokens like:
+    - "(num" (combining opening parenthesis and parameter name as one token)
+
+    In `grammar_strict` mode, it would force separate tokens:
+    - "(" followed by "num" (requiring two separate token generations)
+
+    For more details on the approximation methods, refer to the SynCode paper:
+    https://arxiv.org/abs/2403.01632
     """
     def __init__(self,
-        grammar: Grammar,
-        tokenizer: PreTrainedTokenizer,
-        use_cache=True,
-        parse_output_only=True,
-        num_samples=1,
-        dev_mode=False,
-        parser='lalr',
-        mode='grammar_mask'):
-
+        grammar: Grammar,
+        tokenizer: PreTrainedTokenizer,
+        byte_tokenizer: ByteTokenizer,
+        use_cache=True,
+        parse_output_only=True,
+        batch_size=1,
+        dev_mode=False,
+        parser='lalr',
+        mode='grammar_mask'):
+
         self.tokenizer = tokenizer
-        self.byte_tokenizer = ByteTokenizer(tokenizer)
-
+        self.byte_tokenizer = byte_tokenizer
         self.grammar = grammar
         self.dev_mode = dev_mode
-        self.batch_size = num_samples
+        self.batch_size = batch_size
         self.parse_failed = False

         # For backtracking to syntactically valid completions
-        self.last_valid_state: list = []
-        self.function_ends: list = []
+        self.last_valid_state = [0 for _ in range(self.batch_size)]
+        self.function_ends = [None for _ in range(self.batch_size)]

         # We use this when only the LLM output is parsed and not (input+output)
         self.parse_output_only = parse_output_only
-        self.start_from = None
+        self.start_from = None

         # Ignore whitespace tokens
         self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
@@ -60,31 +76,14 @@ def __init__(self,
         # Create parser
         self.inc_parser: IncrementalParser = create_parser(self.grammar, parser=parser, ignore_whitespace=self._ignore_whitespace)

-        # Load dfa mask store
+        # Load dfa mask store with specified mode (grammar_mask or grammar_strict)
         self.dfa_mask_store = MaskStore.init_mask_store(
             grammar=self.grammar,
             tokenizer=self.tokenizer,
             use_cache=use_cache,
-            mode=mode,
+            mode=mode, # Controls approximation strategy for token masking
         )
-
-
-    def _get_ignore_whitespace(self, grammar):
-        """
-        Check if the grammar allows whitespace tokens to be ignored.
-        """
-        base_parser = create_base_parser(grammar)
-        terminals = base_parser.terminals
-        ignore_terminals = base_parser.ignore_tokens
-
-        import regex
-        ignore_whitespace = False
-        for ig_name in ignore_terminals:
-            for terminal in terminals:
-                if terminal.name == ig_name:
-                    if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
-                        ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
-        return ignore_whitespace
+

     def reset(self):
         """
@@ -96,6 +95,15 @@ def reset(self):
         self.start_from = None
         self.inc_parser.reset()

+    def _set_start_from(self, input_ids):
+        """
+        Sets the starting point for parsing based on whether we're parsing only the output or the full input+output.
+ """ + if self.start_from is None: + if self.parse_output_only: + self.start_from = input_ids.size(1) + else: + self.start_from = 0 def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool: """ @@ -134,19 +142,26 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> is_valid = self.dfa_mask_store.is_valid_prefix(res) if is_valid: - self.update_valid_state(partial_code, 0, res) + self._update_valid_state(partial_code, 0, res) return is_valid - def _set_start_from(self, input_ids): - if self.start_from is None: - if self.parse_output_only: - self.start_from = input_ids.size(1) - else: - self.start_from = 0 - - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """ + Mask scores by zeroing out invalid next tokens based on grammar constraints. + + The exact behavior depends on whether we're using grammar_mask mode (conservative/ + overapproximation) or grammar_strict mode (strict/underapproximation). In both cases, + tokens that would lead to definitely invalid syntax are masked out by setting their + scores to negative infinity. + + Args: + input_ids (torch.LongTensor): The input ids. + scores (torch.FloatTensor): The scores to be masked. + + Returns: + torch.FloatTensor: The masked scores. + """ self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start partial_codes = self._get_partial_codes(input_ids) @@ -188,7 +203,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte else: res.remainder = res.remainder.encode('utf-8') - self.update_valid_state(partial_code, idx, res) + self._update_valid_state(partial_code, idx, res) except Exception as e: if self.dev_mode == True: raise e @@ -200,7 +215,6 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte skip = True return res, skip - def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]: """ Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string. @@ -219,9 +233,8 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]: ) output.append((partial_code, remainder_bytes)) return output - - def update_valid_state(self, partial_code: str, idx: int, r: ParseResult): + def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult): """ This a simple heuristic to cut off the generated output at the end of the function. TODO: Put this under a flag to enable/disable this heuristic. @@ -237,7 +250,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult): if accept_seq[0] == '$END' or accept_seq[0] == 'EOF': self.last_valid_state[idx] = len(partial_code) - len(r.remainder) - @staticmethod def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]: """ @@ -253,16 +265,6 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]: A tuple (string, remainder) where: - string is the longest valid UTF-8 prefix of the input as a Python string - remainder is the rest of the bytes that could not be decoded as UTF-8 - - Examples: - >>> bytes_to_string(b'Hello, world!') - ('Hello, world!', b'') - >>> bytes_to_string(b'Hello, \xe2\x82\xac!') # Euro symbol (€) followed by ! 
-        ('Hello, €!', b'')
-        >>> bytes_to_string(b'Hello, \xe2\x82!')  # Incomplete Euro symbol
-        ('Hello, ', b'\xe2\x82!')
-        >>> bytes_to_string(b'\xff\xfe')  # Invalid UTF-8
-        ('', b'\xff\xfe')
         """
         if not isinstance(byte_sequence, bytes):
             raise TypeError("Input must be a bytes object")
@@ -292,3 +294,21 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
         return byte_sequence[:valid_end].decode('utf-8'), byte_sequence[valid_end:]
     else:
         return "", byte_sequence
+
+    def _get_ignore_whitespace(self, grammar):
+        """
+        Check if the grammar allows whitespace tokens to be ignored.
+        """
+        base_parser = create_base_parser(grammar)
+        terminals = base_parser.terminals
+        ignore_terminals = base_parser.ignore_tokens
+
+        import regex
+        ignore_whitespace = False
+        for ig_name in ignore_terminals:
+            for terminal in terminals:
+                if terminal.name == ig_name:
+                    if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
+                        ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
+        return ignore_whitespace
+
\ No newline at end of file
diff --git a/syncode/grammar_mask/logits_processor.py b/syncode/grammar_mask/logits_processor.py
new file mode 100644
index 00000000..cfb1c370
--- /dev/null
+++ b/syncode/grammar_mask/logits_processor.py
@@ -0,0 +1,67 @@
+import torch
+from transformers import LogitsProcessor, PreTrainedTokenizer
+from syncode.grammar_mask.grammar_constrainer import GrammarConstrainer
+from syncode.mask_store.byte_tokenizer import ByteTokenizer
+from syncode.parsers.grammars import Grammar
+import logging
+logger = logging.getLogger(__name__)
+
+
+class SyncodeLogitsProcessor(LogitsProcessor):
+    """
+    This class is used to filter the logits of the model to only allow syntactically valid tokens for the given grammar.
+
+    Args:
+        grammar (Grammar): The grammar to use for parsing e.g. "python".
+        tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
+        use_cache (bool, optional): Whether to use the cache. Defaults to True.
+        parse_output_only (bool, optional): Whether to parse only the model output and not the prompt. Defaults to True.
+        num_samples (int, optional): The number of sequences to generate. Defaults to 1.
+        dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
+        parser (str, optional): The parser to use. Defaults to 'lalr'.
+        mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
+    """
+    def __init__(self,
+        grammar: Grammar,
+        tokenizer: PreTrainedTokenizer,
+        use_cache=True,
+        parse_output_only=True,
+        num_samples=1,
+        dev_mode=False,
+        parser='lalr',
+        mode='grammar_mask'):
+
+        self.tokenizer = tokenizer
+        self.byte_tokenizer = ByteTokenizer(tokenizer)
+
+        # Create the grammar constrainer that handles most of the logic
+        self.grammar_engine = GrammarConstrainer(
+            grammar=grammar,
+            tokenizer=tokenizer,
+            byte_tokenizer=self.byte_tokenizer,
+            use_cache=use_cache,
+            parse_output_only=parse_output_only,
+            batch_size=num_samples,
+            dev_mode=dev_mode,
+            parser=parser,
+            mode=mode
+        )
+
+    def reset(self):
+        """
+        Resets the decoder state on every new prompt.
+        """
+        self.grammar_engine.reset()
+
+    def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
+        """
+        Check if the next token is valid given the input_ids.
+        """
+        return self.grammar_engine.is_valid(input_ids, next_token)
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Filter the scores to only allow syntactically valid tokens.
+ """ + return self.grammar_engine.mask_scores(input_ids, scores) + \ No newline at end of file diff --git a/syncode/infer.py b/syncode/infer.py index ba7a994f..8bc68d0a 100644 --- a/syncode/infer.py +++ b/syncode/infer.py @@ -8,7 +8,7 @@ import syncode.common as common from syncode.language_model import HuggingFaceModel -from syncode.grammar_decoder import SyncodeLogitsProcessor +from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor from syncode.parsers.grammars import Grammar from syncode.dataset import Dataset from syncode.evaluation.code_eval import CodeEval diff --git a/syncode/language_model.py b/syncode/language_model.py index 132df3a3..8b687b8f 100644 --- a/syncode/language_model.py +++ b/syncode/language_model.py @@ -2,7 +2,7 @@ import time import torch import syncode.common as common -from syncode.grammar_decoder import SyncodeLogitsProcessor +from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria from syncode.parsers.grammars import Grammar from syncode.utils.generation import filter_code, fix_indents @@ -53,8 +53,8 @@ def __init__( self.device = device self.best_of = best_of self._before_prediction_hook = before_prediction_hook - self.grammar_decoder = grammar_decoder - self.grammar_processor: Iterable = LogitsProcessorList([self.grammar_decoder]) if self.grammar_decoder is not None else None + self.logits_processor = grammar_decoder + self.grammar_processor: Iterable = LogitsProcessorList([self.logits_processor]) if self.logits_processor is not None else None self.mode = mode self.grammar = grammar @@ -88,8 +88,8 @@ def generate_grammar_constrained_completion( inputs = self.get_tokenized_input(prompt, batch_size) # Reset the grammar decoder - if self.grammar_decoder is not None: - self.grammar_decoder.reset() + if self.logits_processor is not None: + self.logits_processor.reset() input_ids_cutoff = inputs.input_ids.size(dim=1) @@ -112,7 +112,7 @@ def generate_grammar_constrained_completion( inputs, gen_config, gen_mode, - grammar_decoder=self.grammar_decoder, + grammar_decoder=self.logits_processor, stop_criteria=stop_criteria, debug=debug ) diff --git a/syncode/process_eval.py b/syncode/process_eval.py deleted file mode 100644 index 2f77e2ce..00000000 --- a/syncode/process_eval.py +++ /dev/null @@ -1,60 +0,0 @@ -from human_eval.data import read_problems, write_jsonl, stream_jsonl -import glob -from tqdm import tqdm -from transformers import PreTrainedTokenizer -import argparse - -parser = argparse.ArgumentParser() - -# Inputs -parser.add_argument("--path", type=str, help="") -parser.add_argument("--out_path", type=str, help="") -parser.add_argument("--add_prompt", action="store_true", help="") - -args = parser.parse_args() - -files = sorted(glob.glob(args.path + "/*.jsonl")) -print("{} files in {}".format(len(files), args.path)) - -problems = read_problems() - -output = [] -a = 0 -for code_file in tqdm(files, total=len(files)): - codes = [c for c in stream_jsonl(code_file)] - if args.add_prompt: - for code in codes: - task_id = code["task_id"] - prompt = problems[task_id]["prompt"] - completion = code["completion"] - completion = completion.replace("\r", "") - if "```python" in completion: - def_line = completion.index("```python") - completion = completion[def_line:].strip() - completion = completion.replace("```python", "") - # print(completion) - try: - next_line = completion.index("```") - completion = completion[:next_line].strip() - except: - a += 1 
-                    print(completion)
-                    print("================\n")
-            # print(completion)
-            if '__name__ == "__main__"' in completion:
-                next_line = completion.index('if __name__ == "__main__":')
-                completion = completion[:next_line].strip()
-                # print(completion)
-
-            if "# Example usage" in completion:
-                # print(completion)
-                next_line = completion.index("# Example usage")
-                completion = completion[:next_line].strip()
-
-            code["completion"] = completion
-
-    output += codes
-
-print("save to {}".format(args.out_path))
-write_jsonl(args.out_path, output)
-print(a)
diff --git a/tests/test_misc.py b/tests/test_misc.py
index 54ce28db..73c5a138 100644
--- a/tests/test_misc.py
+++ b/tests/test_misc.py
@@ -5,7 +5,7 @@
 import torch

 from syncode.mask_store.mask_store import MaskStore
-from syncode.grammar_decoder import SyncodeLogitsProcessor
+from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor

 # Adjusting the path so the modules can be imported correctly
 sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
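
Usage after the refactor: the sketch below (not part of this diff) shows how the relocated `SyncodeLogitsProcessor` plugs into a standard `transformers` generation loop, delegating all grammar work to the new `GrammarConstrainer` via `grammar_engine`. The `"gpt2"` checkpoint and the `Grammar("python")` construction are illustrative assumptions, not taken from this patch.

```python
# Minimal usage sketch. Assumptions: "gpt2" is a placeholder checkpoint and
# Grammar("python") is how a built-in grammar is named; neither comes from this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor  # new import path
from syncode.parsers.grammars import Grammar

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

processor = SyncodeLogitsProcessor(
    grammar=Grammar("python"),   # constrain completions to Python syntax
    tokenizer=tokenizer,
    parse_output_only=True,      # parse only the completion, not the prompt
    mode="grammar_mask",         # overapproximation; "grammar_strict" underapproximates
)
processor.reset()  # reset incremental-parser state before each new prompt

inputs = tokenizer("def calculate", return_tensors="pt")
outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([processor]),
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Because `__call__` now simply forwards to `grammar_engine.mask_scores`, existing callers that treat the object as a plain `LogitsProcessor` keep working; only the import path changes.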
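The patch also drops the doctest examples from `_bytes_to_string`. For reference, the behavior they documented (longest valid UTF-8 prefix plus undecodable remainder) can be sketched independently; `utf8_prefix` below is an illustrative stand-in, not the class's actual helper.

```python
# Illustrative stand-in for the longest-valid-UTF-8-prefix behavior of
# GrammarConstrainer._bytes_to_string (the real method scans byte offsets directly).
def utf8_prefix(byte_sequence: bytes) -> tuple[str, bytes]:
    # Try progressively shorter prefixes; end == 0 (the empty prefix) always
    # decodes, so the loop is guaranteed to return.
    for end in range(len(byte_sequence), -1, -1):
        try:
            return byte_sequence[:end].decode("utf-8"), byte_sequence[end:]
        except UnicodeDecodeError:
            continue

assert utf8_prefix(b"Hello, \xe2\x82\xac!") == ("Hello, €!", b"")     # complete Euro sign
assert utf8_prefix(b"Hello, \xe2\x82!") == ("Hello, ", b"\xe2\x82!")  # truncated Euro sign
assert utf8_prefix(b"\xff\xfe") == ("", b"\xff\xfe")                  # invalid UTF-8
```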