10 changes: 5 additions & 5 deletions .github/workflows/run_tests.yml
@@ -31,11 +31,11 @@ jobs:
- name: Run Tests
run: |
python3 -m unittest tests.test_misc
python3 -m unittest tests.test_grammar_go
python3 -m unittest tests.test_grammar_sql
python3 -m unittest tests.test_grammar_python
python3 -m unittest tests.test_grammar_json
python3 -m unittest tests.test_grammar_java
python3 -m unittest tests.parser.test_grammar_go
python3 -m unittest tests.parser.test_grammar_sql
python3 -m unittest tests.parser.test_grammar_python
python3 -m unittest tests.parser.test_grammar_json
python3 -m unittest tests.parser.test_grammar_java
python3 -m unittest tests.test_language_model
python3 -m unittest tests.test_lr_parser
python3 -m unittest tests.test_syncode
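The workflow now invokes the relocated grammar tests by their dotted module paths under tests.parser. A quick sanity check of the same resolution, assuming it is run from the repository root (module name taken from the hunk above):

import importlib

# Succeeds only if tests/parser/ is importable as a package from the repo root,
# which is exactly what `python3 -m unittest tests.parser.test_grammar_go` relies on.
importlib.import_module("tests.parser.test_grammar_go")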
15 changes: 10 additions & 5 deletions syncode/evaluation/code_eval.py
@@ -36,14 +36,18 @@ def run_code_eval(
else:
stop_words = None

pbar = tqdm(total=len(problems) * num_samples_per_task)
if debug_task_id is None:
time1 = time.time()
pbar = tqdm(total=len(problems) * num_samples_per_task)

# Run evaluation for all tasks
for task_id in list(problems.keys())[:num_tasks]:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, stop_words=stop_words))
outputs.append(
CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, stop_words=stop_words)
)
pbar.update(num_samples_per_task)

pbar.close()
if out_path is not None: write_jsonl(out_path, samples)

avg_time = (time.time() - time1) / len(problems)
@@ -54,10 +58,12 @@ def run_code_eval(
CodeEval.write_results(syncode, out_path, avg_time, functional_result, num_tasks)
else: # Debugging a specific task
debug_task_id = list(problems.keys())[debug_task_id]
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
return CodeEval.run_eval_for_task(
syncode, num_samples_per_task, format_tabs, problems, samples, debug_task_id, logger=logger, stop_words=stop_words
)
return outputs

def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger(), stop_words=None):
def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, logger=common.EmptyLogger(), stop_words=None):
"""
run evaluation for a specific task
"""
@@ -96,7 +102,6 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
)
samples += [result]
all_completions.append(completion)
pbar.update(num_samples_per_task)

# Clear the cache
torch.cuda.empty_cache()
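The hunks above move progress-bar ownership out of run_eval_for_task: the caller creates the bar only when not debugging a single task and advances it once per task, so the helper no longer takes a pbar argument. A minimal sketch of that pattern with illustrative names:

from tqdm import tqdm

def run_all(problems, num_samples_per_task, run_one):
    # The caller owns the bar and advances it after each task finishes;
    # run_one(task_id) stands in for CodeEval.run_eval_for_task.
    pbar = tqdm(total=len(problems) * num_samples_per_task)
    outputs = []
    for task_id in problems:
        outputs.append(run_one(task_id))
        pbar.update(num_samples_per_task)
    pbar.close()
    return outputs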
10 changes: 8 additions & 2 deletions syncode/grammar_decoder.py
@@ -206,9 +206,15 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
output = []
for idx in range(len(input_ids)):
if self.parse_output_only:
partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True))
partial_code, remainder_bytes = self._bytes_to_string(
self.byte_tokenizer.decode(
input_ids[idx, self.start_from:].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
)
else:
partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True))
partial_code, remainder_bytes = self._bytes_to_string(
self.byte_tokenizer.decode(
input_ids[idx].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
)
output.append((partial_code, remainder_bytes))
return output

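Both decode calls now copy the selected input ids to the CPU before converting them to a Python list. The same pattern in isolation, assuming a CUDA tensor of token ids (the ids themselves are illustrative):

import torch

token_ids = torch.tensor([[5, 17, 42, 8]])
if torch.cuda.is_available():
    token_ids = token_ids.to('cuda')

# Copy back to host memory (asynchronously when possible) before .tolist(),
# which always materializes a plain Python list on the CPU.
ids = token_ids[0].to('cpu', non_blocking=True).tolist()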
7 changes: 7 additions & 0 deletions syncode/language_model.py
@@ -122,6 +122,12 @@ def generate_grammar_constrained_completion(
print("WARNING: Opportunistic mode requires SAMPLE or GREEDY_SEARCH generation mode.")
if not batch_size == 1:
print("WARNING: Opportunistic mode requires batch_size of 1.")

# Ensure pad_token_id is set
if 'pad_token_id' not in dir(self.tokenizer):
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

# Use generate from transformers library for other modes
generated_ids = self.model.generate(
**inputs,
@@ -190,6 +196,7 @@ def _generate(
logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])

max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

while True:
try:
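Both hunks above guard against tokenizers that ship without a pad token by falling back to the EOS token, which Hugging Face generate() needs for padding batched inputs. The same defaulting pattern in isolation (GPT-2 is used only as an example of a tokenizer with no pad token):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model without a pad token
if tokenizer.pad_token_id is None:
    # Reuse EOS for padding so batched generation does not fail.
    tokenizer.pad_token_id = tokenizer.eos_token_id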
12 changes: 10 additions & 2 deletions syncode/mask_store/fsm_set.py
@@ -12,14 +12,22 @@ class JointFSMState:
def __init__(self, terminal: str, state_id: int):
self.terminal = terminal
self.state_id = state_id
self._hash = hash((self.terminal, self.state_id)) # Pre-compute hash on creation

self._hash = JointFSMState.det_hash(self.terminal, self.state_id)
def __eq__(self, other: 'JointFSMState'):
return self.terminal == other.terminal and self.state_id == other.state_id

def __hash__(self):
return self._hash

@staticmethod
def det_hash(terminal: str, state_id: int):
h = 0
for char in terminal:
h = (h * 31 + ord(char)) & 0xFFFFFFFF
h = (h * 31 + state_id) & 0xFFFFFFFF
return h

def __repr__(self):
return f"({self.terminal}, {self.state_id})"

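det_hash replaces the built-in hash of the (terminal, state_id) pair with a 31-multiplier polynomial hash masked to 32 bits, so JointFSMState hashes stay identical across interpreter runs (Python's str hash is salted per process via PYTHONHASHSEED). A small usage sketch, assuming the class is importable from syncode.mask_store.fsm_set:

from syncode.mask_store.fsm_set import JointFSMState

s1 = JointFSMState("NAME", 3)
s2 = JointFSMState("NAME", 3)
assert s1 == s2 and hash(s1) == hash(s2)               # usable as dict/set keys
assert hash(s1) == JointFSMState.det_hash("NAME", 3)   # same value in every run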
15 changes: 8 additions & 7 deletions syncode/mask_store/mask_store.py
@@ -65,7 +65,7 @@ def __init__(self,

followings_terminas_map = None
if parse_table is not None:
followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table, ignore_terminals)

# Create consume prefix cache
self._consume_prefix_cache = {}
@@ -105,8 +105,9 @@ def init_mask_store(

if use_cache and os.path.exists(fsm_path):
try:
mask_store = pickle.load(open(fsm_path, 'rb'))
return mask_store
with open(fsm_path, 'rb') as f:
mask_store = pickle.load(f)
return mask_store
except Exception as e:
logger.warning(f"Error loading mask store: {e}")

@@ -134,7 +135,8 @@ def init_mask_store(
def _compute_following_terminals_map(
self,
terminals: Iterable[str],
parse_table
parse_table,
ignore_terminals: Iterable[str]
) -> defaultdict:
"""
From terminals, filter out terminals that cannot follow the current terminal
@@ -150,9 +152,8 @@ def _compute_following_terminals_map(
# We iterate through each cur_terminal:
for cur_terminal in terminals:
# Add all ignore terminals to the following terminals
for next_terminal in terminals:
if 'IGNORE' in next_terminal:
following_terminals_map[cur_terminal].add(next_terminal)
for next_terminal in ignore_terminals:
following_terminals_map[cur_terminal].add(next_terminal)

# We iterate through each parser_state:
for _, row in parse_table.states.items():
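Two things change above: the pickled cache is read through a context manager so the file handle is closed even when unpickling fails, and the ignore terminals are passed explicitly into _compute_following_terminals_map instead of being rediscovered by name matching. A standalone sketch of the cache-loading part (function name and error handling are illustrative):

import os
import pickle

def load_cached_mask_store(fsm_path: str, use_cache: bool = True):
    # Return the cached store if it can be read; fall back to None so the
    # caller rebuilds it from the grammar instead of crashing.
    if use_cache and os.path.exists(fsm_path):
        try:
            with open(fsm_path, 'rb') as f:   # closed on success and on failure
                return pickle.load(f)
        except Exception as e:
            print(f"Error loading mask store: {e}")
    return None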
Empty file removed tests/mask_store/tes
Empty file.
@@ -1,7 +1,7 @@
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser
from syncode.parsers.grammars.grammar import Grammar
from syncode.parse_result import AcceptSequence, RemainderState
@@ -1,7 +1,7 @@
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser
from syncode.parsers.grammars.grammar import Grammar
from syncode.parse_result import AcceptSequence, RemainderState
@@ -1,7 +1,7 @@
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser
from syncode.parsers.grammars.grammar import Grammar
from syncode.parse_result import AcceptSequence, RemainderState
@@ -1,6 +1,6 @@
import unittest
import sys, os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser
from transformers import (
LlamaTokenizer,
@@ -1,7 +1,7 @@
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser
from syncode.parsers.grammars.grammar import Grammar
from syncode.parse_result import AcceptSequence, RemainderState
@@ -70,3 +70,7 @@ def test_sql_parser7(self):
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert r.remainder == "'%Hey%'"
assert r.remainder_state == RemainderState.MAYBE_COMPLETE

if __name__ == "__main__":
unittest.main()
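The repeated sys.path edits in the test files above reflect the move from tests/ into tests/parser/: the repository root is now two directories up, and the new __main__ guard lets a grammar test file be run directly. A minimal sketch of that shim, with the import taken from the files above:

import os
import sys
import unittest

# Two levels up from tests/parser/<this file> is the repository root,
# which puts the syncode package on the import path.
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
from syncode.parsers import create_parser

if __name__ == "__main__":
    unittest.main()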

3 changes: 2 additions & 1 deletion tests/test_language_model.py
@@ -38,7 +38,8 @@ class TestTokenizer:
def __init__(self) -> None:
vocab = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/', '(', ')', ' ', '\n', '\t', '=']
self.vocab = vocab
self.eos_token_id = ''
self.eos_token_id = 1
self.pad_token_id = 2

def __call__(self, input_batch: list[str], return_tensors="pt") -> BatchEncoding:
# This works since we have single character tokens
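The stub tokenizer now uses integer sentinel ids (eos 1, pad 2) instead of an empty string, since the generation loop compares sampled token ids, which are ints, against eos_token_id and copies pad_token_id into the model config. A tiny illustration of why a string sentinel could never match (the sampled ids are hypothetical):

eos_token_id = 1                          # an id the model can actually emit
sampled_ids = [7, 4, 1]                   # hypothetical generation output
assert sampled_ids[-1] == eos_token_id    # stopping condition fires
assert not (sampled_ids[-1] == '')        # the old string sentinel never matches an int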