2 changes: 1 addition & 1 deletion syncode/__init__.py
@@ -1,5 +1,5 @@
from syncode.infer import Syncode
from syncode.grammar_decoder import SyncodeLogitsProcessor
from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
from syncode.parsers.grammars import Grammar
import syncode.common as common

27 changes: 14 additions & 13 deletions syncode/evaluation/code_eval.py
@@ -86,12 +86,13 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
# We tokenize the whole thing together since tokenizing just the generated_ids messes up
# the indentation and removes the initial whitespaces in some cases
raw_completion = syncode.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
grammar_constrainer = syncode.model.logits_processor.grammar_engine

# Post-processing to filter out using stop word
if syncode.model.grammar != None and syncode.model.grammar.name == "python":
completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, syncode.model.grammar_decoder, raw_completion, stop_words)
completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words)
elif syncode.model.grammar != None and syncode.model.grammar.name == "go":
completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, syncode.model.grammar_decoder, input_ids_cutoff)
completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff)
else: # TODO: handle the case for other grammars
completion = raw_completion
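
Note: the decode-the-whole-sequence pattern above matters because decoding only the generated ids can drop leading whitespace with some tokenizers, silently breaking Python indentation. A minimal sketch of the idea, assuming a Hugging Face tokenizer (decode_completion is a hypothetical helper, not part of this PR):

from transformers import PreTrainedTokenizer

def decode_completion(tokenizer: PreTrainedTokenizer, all_ids: list, prompt: str) -> str:
    # Decode prompt + completion together, then strip the prompt text.
    # Decoding a slice of the ids alone can lose the completion's leading
    # whitespace with some tokenizers, which breaks indentation-sensitive code.
    full_text = tokenizer.decode(all_ids, skip_special_tokens=True)
    return full_text[len(prompt):]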

@@ -121,39 +122,39 @@ def write_results(syncode, out_path, avg_time, functional_result, num_tasks=1):
f.write(f"Averge time taken for each task: {avg_time:.2f}s\n")
f.write("\n")

def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_decoder, raw_completion, stop_words):
def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words):
generated_output = hf_model.tokenizer.decode(generated_ids[input_ids_cutoff:])

if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_decoder is not None:
# Use when the stop word does not exist in the completion and grammar_decoder is used
if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_constrainer is not None:
# Used when no stop word appears in the completion and the grammar_constrainer is active
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)
else:
completion = raw_completion
return completion

def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_decoder, input_ids_cutoff):
def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff):
if hf_model.mode != "original":
# When the grammar_decoder is used
# When the grammar_constrainer is used
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)

if function_incomplete[i]:
completion += "}"

return completion

def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
if grammar_decoder.function_ends[i] is not None:
fn_ends = sorted(list(set(grammar_decoder.function_ends[i])))
def compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion):
if grammar_constrainer.function_ends[i] is not None:
fn_ends = sorted(list(set(grammar_constrainer.function_ends[i])))
if len(fn_ends) > 1:
# if more than one function end was recorded, cut the completion at the second recorded end
last_valid_state = fn_ends[1]
return raw_completion[:last_valid_state]

# otherwise, fall back to the parser's last recorded valid state
function_incomplete[i] = True
last_valid_state = grammar_decoder.last_valid_state[i]
last_valid_state = grammar_constrainer.last_valid_state[i]

# Use when the stop word does not exist in the completion
backup_completion = raw_completion[:last_valid_state]
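
A hypothetical trace of the fallback path above (all values invented for illustration):

raw_completion = "def add(a, b):\n    return a + b\n\nprint(ad"  # cut off mid-token
last_valid_state = 32  # parser's last offset at which the prefix was syntactically valid
backup_completion = raw_completion[:last_valid_state]
# backup_completion == "def add(a, b):\n    return a + b\n"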
@@ -10,81 +10,80 @@
import logging
logger = logging.getLogger(__name__)


# Set to True for debugging
DEBUG = True

class SyncodeLogitsProcessor(LogitsProcessor):
class GrammarConstrainer:
"""
This class is used to filter the logits of the model to only allow syntactically valid tokens for Python.

Args:
grammar (str): The grammar to use for parsing e.g. "python".
tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
use_cache (bool, optional): Whether to use the cache. Defaults to True.
parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
num_samples (int, optional): The number of sequences to generate. Defaults to 1.
dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
parser (str, optional): The parser to use. Defaults to 'lalr'.
mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
Core class for constraining LLM token generation based on formal grammar rules.

This class handles the parsing of generated code, validates its grammatical correctness,
and creates token masks to ensure syntactically valid generations.

The class supports two primary operating modes:

1. `grammar_mask` (Conservative/Overapproximation):
This mode is more permissive and overapproximates the set of acceptable tokens.
It allows a wider range of tokens that might be syntactically valid given the
limited lookahead of the parser. This mode preserves more of the LLM's original
token distribution while still enforcing basic syntactic correctness.

2. `grammar_strict` (Strict/Underapproximation):
This mode is stricter and underapproximates the set of acceptable tokens.
It enforces tighter grammatical constraints and may be more invasive in the
LLM's generation process. It sometimes breaks LLM tokens that would have been
syntactically correct when considered as a whole, potentially affecting the
fluency or accuracy of generation.

Example illustrating the difference:
Consider generating Python code with the partial input: `def calculate`

In `grammar_mask` mode, it might allow tokens like:
- "(num" (combining opening parenthesis and parameter name as one token)

In `grammar_strict` mode, it would force separate tokens:
- "(" followed by "num" (requiring two separate token generations)

For more details on the approximation methods, refer to the SynCode paper:
https://arxiv.org/abs/2403.01632
"""
def __init__(self,
grammar: Grammar,
tokenizer: PreTrainedTokenizer,
use_cache=True,
parse_output_only=True,
num_samples=1,
dev_mode=False,
parser='lalr',
mode='grammar_mask'):

grammar: Grammar,
tokenizer: PreTrainedTokenizer,
byte_tokenizer: ByteTokenizer,
use_cache=True,
parse_output_only=True,
batch_size=1,
dev_mode=False,
parser='lalr',
mode='grammar_mask'):

self.tokenizer = tokenizer
self.byte_tokenizer = ByteTokenizer(tokenizer)

self.byte_tokenizer = byte_tokenizer
self.grammar = grammar
self.dev_mode = dev_mode
self.batch_size = num_samples
self.batch_size = batch_size
self.parse_failed = False

# For backtracking to syntactically valid completions
self.last_valid_state: list = []
self.function_ends: list = []
self.last_valid_state = [0 for _ in range(self.batch_size)]
self.function_ends = [None for _ in range(self.batch_size)]

# We use this when only the LLM output is parsed and not (input+output)
self.parse_output_only = parse_output_only
self.start_from = None
self.start_from = None

# Ignore whitespace tokens
self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)

# Create parser
self.inc_parser: IncrementalParser = create_parser(self.grammar, parser=parser, ignore_whitespace=self._ignore_whitespace)

# Load dfa mask store
# Load dfa mask store with specified mode (grammar_mask or grammar_strict)
self.dfa_mask_store = MaskStore.init_mask_store(
grammar=self.grammar,
tokenizer=self.tokenizer,
use_cache=use_cache,
mode=mode,
mode=mode, # Controls approximation strategy for token masking
)
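
A hedged instantiation sketch of the refactored class: the byte_tokenizer is now injected by the caller rather than built internally, and mode selects the approximation strategy described in the docstring. The tokenizer name and the Grammar("python") call are illustrative assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer
constrainer = GrammarConstrainer(
    grammar=Grammar("python"),                # assumption: Grammar built from a name
    tokenizer=tokenizer,
    byte_tokenizer=ByteTokenizer(tokenizer),  # same call the old __init__ made internally
    batch_size=1,
    mode="grammar_mask",                      # or "grammar_strict" for the underapproximation
)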


def _get_ignore_whitespace(self, grammar):
"""
Check if the grammar allows whitespace tokens to be ignored.
"""
base_parser = create_base_parser(grammar)
terminals = base_parser.terminals
ignore_terminals = base_parser.ignore_tokens

import regex
ignore_whitespace = False
for ig_name in ignore_terminals:
for terminal in terminals:
if terminal.name == ig_name:
if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
ignore_whitespace = True # an ignored terminal matches whitespace
return ignore_whitespace


def reset(self):
"""
@@ -96,6 +95,15 @@ def reset(self):
self.start_from = None
self.inc_parser.reset()

def _set_start_from(self, input_ids):
"""
Sets the starting point for parsing based on whether we're parsing only the output or the full input+output.
"""
if self.start_from is None:
if self.parse_output_only:
self.start_from = input_ids.size(1)
else:
self.start_from = 0

def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
"""
@@ -134,19 +142,26 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
is_valid = self.dfa_mask_store.is_valid_prefix(res)

if is_valid:
self.update_valid_state(partial_code, 0, res)
self._update_valid_state(partial_code, 0, res)

return is_valid

def _set_start_from(self, input_ids):
if self.start_from is None:
if self.parse_output_only:
self.start_from = input_ids.size(1)
else:
self.start_from = 0


def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""
Mask scores by zeroing out invalid next tokens based on grammar constraints.

The exact behavior depends on whether we're using grammar_mask mode (conservative/
overapproximation) or grammar_strict mode (strict/underapproximation). In both cases,
tokens that would lead to definitely invalid syntax are masked out by setting their
scores to negative infinity.

Args:
input_ids (torch.LongTensor): The input ids.
scores (torch.FloatTensor): The scores to be masked.

Returns:
torch.FloatTensor: The masked scores.
"""
self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
partial_codes = self._get_partial_codes(input_ids)
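
Since GrammarConstrainer no longer subclasses LogitsProcessor, mask_scores is meant to be driven by a wrapper during generation. A hedged sketch of such a wrapper; this is illustrative, not this PR's actual SyncodeLogitsProcessor (note that code_eval.py above reaches the engine via syncode.model.logits_processor.grammar_engine):

import torch
from transformers import LogitsProcessor

class GrammarMaskingProcessor(LogitsProcessor):
    def __init__(self, grammar_engine):
        self.grammar_engine = grammar_engine  # a GrammarConstrainer instance

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Delegate to the constrainer, which masks syntactically invalid
        # next tokens by setting their scores to -inf.
        return self.grammar_engine.mask_scores(input_ids, scores)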

@@ -188,7 +203,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
else:
res.remainder = res.remainder.encode('utf-8')

self.update_valid_state(partial_code, idx, res)
self._update_valid_state(partial_code, idx, res)
except Exception as e:
if self.dev_mode == True:
raise e
@@ -200,7 +215,6 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
skip = True
return res, skip


def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
"""
Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
@@ -219,9 +233,8 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
)
output.append((partial_code, remainder_bytes))
return output


def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
"""
This is a simple heuristic to cut off the generated output at the end of the function.
TODO: Put this under a flag to enable/disable this heuristic.
@@ -237,7 +250,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
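
An illustrative trace of this heuristic (values invented): when '$END' or 'EOF' heads the accept sequence, the length of the consumed prefix is recorded so the completion can later be cut there.

partial_code = "def f():\n    return 1\n\nx"  # generation ran past the function body
remainder = "x"                               # unconsumed tail reported by the parser
accept_seq = ("$END",)                        # end-of-input is acceptable here
if accept_seq[0] in ("$END", "EOF"):
    last_valid_state = len(partial_code) - len(remainder)  # offset just before "x"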


@staticmethod
def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
"""
@@ -253,16 +265,6 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
A tuple (string, remainder) where:
- string is the longest valid UTF-8 prefix of the input as a Python string
- remainder is the rest of the bytes that could not be decoded as UTF-8

Examples:
>>> bytes_to_string(b'Hello, world!')
('Hello, world!', b'')
>>> bytes_to_string(b'Hello, \xe2\x82\xac!') # Euro symbol (€) followed by !
('Hello, €!', b'')
>>> bytes_to_string(b'Hello, \xe2\x82!') # Incomplete Euro symbol
('Hello, ', b'\xe2\x82!')
>>> bytes_to_string(b'\xff\xfe') # Invalid UTF-8
('', b'\xff\xfe')
"""
if not isinstance(byte_sequence, bytes):
raise TypeError("Input must be a bytes object")
@@ -292,3 +294,21 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
return byte_sequence[:valid_end].decode('utf-8'), byte_sequence[valid_end:]
else:
return "", byte_sequence

def _get_ignore_whitespace(self, grammar):
"""
Check if the grammar allows whitespace tokens to be ignored.
"""
base_parser = create_base_parser(grammar)
terminals = base_parser.terminals
ignore_terminals = base_parser.ignore_tokens

import regex
ignore_whitespace = False
for ig_name in ignore_terminals:
for terminal in terminals:
if terminal.name == ig_name:
if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
ignore_whitespace = True # an ignored terminal matches whitespace
return ignore_whitespace
