diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 7e2b0b36..4c54f961 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -31,11 +31,11 @@ jobs:
      - name: Run Tests
        run: |
          python3 -m unittest tests.test_misc
-          python3 -m unittest tests.test_grammar_go
-          python3 -m unittest tests.test_grammar_sql
-          python3 -m unittest tests.test_grammar_python
-          python3 -m unittest tests.test_grammar_json
-          python3 -m unittest tests.test_grammar_java
+          python3 -m unittest tests.parser.test_grammar_go
+          python3 -m unittest tests.parser.test_grammar_sql
+          python3 -m unittest tests.parser.test_grammar_python
+          python3 -m unittest tests.parser.test_grammar_json
+          python3 -m unittest tests.parser.test_grammar_java
          python3 -m unittest tests.test_language_model
          python3 -m unittest tests.test_lr_parser
          python3 -m unittest tests.test_syncode
diff --git a/syncode/evaluation/code_eval.py b/syncode/evaluation/code_eval.py
index e6c0820e..661bc515 100644
--- a/syncode/evaluation/code_eval.py
+++ b/syncode/evaluation/code_eval.py
@@ -36,14 +36,18 @@ def run_code_eval(
        else:
            stop_words = None

-        pbar = tqdm(total=len(problems) * num_samples_per_task)
        if debug_task_id is None:
            time1 = time.time()
+            pbar = tqdm(total=len(problems) * num_samples_per_task)
+
            # Run evaluation for all tasks
            for task_id in list(problems.keys())[:num_tasks]:
-                outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, stop_words=stop_words))
+                outputs.append(
+                    CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, stop_words=stop_words)
+                )
+                pbar.update(num_samples_per_task)
+            pbar.close()

            if out_path is not None: write_jsonl(out_path, samples)
            avg_time = (time.time() - time1) / len(problems)
@@ -54,10 +58,12 @@ def run_code_eval(
            CodeEval.write_results(syncode, out_path, avg_time, functional_result, num_tasks)
        else: # Debugging a specific task
            debug_task_id = list(problems.keys())[debug_task_id]
-            return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
+            return CodeEval.run_eval_for_task(
+                syncode, num_samples_per_task, format_tabs, problems, samples, debug_task_id, logger=logger, stop_words=stop_words
+            )
        return outputs

-    def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger(), stop_words=None):
+    def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, logger=common.EmptyLogger(), stop_words=None):
        """
        run evaluation for a specific task
        """
@@ -96,7 +102,6 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
                )
            samples += [result]
            all_completions.append(completion)
-            pbar.update(num_samples_per_task)

        # Clear the cache
        torch.cuda.empty_cache()
diff --git a/syncode/grammar_decoder.py b/syncode/grammar_decoder.py
index 1be44090..53e05d63 100644
--- a/syncode/grammar_decoder.py
+++ b/syncode/grammar_decoder.py
@@ -206,9 +206,15 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
        output = []
        for idx in range(len(input_ids)):
            if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True))
+                partial_code, remainder_bytes = self._bytes_to_string(
+                    self.byte_tokenizer.decode(
+                        input_ids[idx, self.start_from:].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
+                )
            else:
-                partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True))
+                partial_code, remainder_bytes = self._bytes_to_string(
+                    self.byte_tokenizer.decode(
+                        input_ids[idx].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
+                )
            output.append((partial_code, remainder_bytes))
        return output
diff --git a/syncode/language_model.py b/syncode/language_model.py
index 16fe99fc..132df3a3 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -122,6 +122,12 @@ def generate_grammar_constrained_completion(
                print("WARNING: Opportunistic mode requires SAMPLE or GREEDY_SEARCH generation mode.")
            if not batch_size == 1:
                print("WARNING: Opportunistic mode requires batch_size of 1.")
+
+        # Ensure pad_token_id is set
+        if 'pad_token_id' not in dir(self.tokenizer):
+            if self.tokenizer.pad_token_id is None:
+                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
        # Use generate from transformers library for other modes
        generated_ids = self.model.generate(
            **inputs,
@@ -190,6 +196,7 @@ def _generate(
        logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])

        max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
+        self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

        while True:
            try:
diff --git a/syncode/mask_store/fsm_set.py b/syncode/mask_store/fsm_set.py
index ce19c123..a7a40453 100644
--- a/syncode/mask_store/fsm_set.py
+++ b/syncode/mask_store/fsm_set.py
@@ -12,14 +12,22 @@ class JointFSMState:
    def __init__(self, terminal: str, state_id: int):
        self.terminal = terminal
        self.state_id = state_id
-        self._hash = hash((self.terminal, self.state_id)) # Pre-compute hash on creation
-
+        self._hash = JointFSMState.det_hash(self.terminal, self.state_id)
+
    def __eq__(self, other: 'JointFSMState'):
        return self.terminal == other.terminal and self.state_id == other.state_id

    def __hash__(self):
        return self._hash

+    @staticmethod
+    def det_hash(terminal: str, state_id: int):
+        h = 0
+        for char in terminal:
+            h = (h * 31 + ord(char)) & 0xFFFFFFFF
+        h = (h * 31 + state_id) & 0xFFFFFFFF
+        return h
+
    def __repr__(self):
        return f"({self.terminal}, {self.state_id})"
diff --git a/syncode/mask_store/mask_store.py b/syncode/mask_store/mask_store.py
index a5a76c81..15ec1ef2 100644
--- a/syncode/mask_store/mask_store.py
+++ b/syncode/mask_store/mask_store.py
@@ -65,7 +65,7 @@ def __init__(self,

        followings_terminas_map = None
        if parse_table is not None:
-            followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
+            followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table, ignore_terminals)

        # Create consume prefix cache
        self._consume_prefix_cache = {}
@@ -105,8 +105,9 @@ def init_mask_store(
        if use_cache and os.path.exists(fsm_path):
            try:
-                mask_store = pickle.load(open(fsm_path, 'rb'))
-                return mask_store
+                with open(fsm_path, 'rb') as f:
+                    mask_store = pickle.load(f)
+                    return mask_store
            except Exception as e:
                logger.warning(f"Error loading mask store: {e}")
@@ -134,7 +135,8 @@ def init_mask_store(
    def _compute_following_terminals_map(
        self,
        terminals: Iterable[str],
-        parse_table
+        parse_table,
+        ignore_terminals: Iterable[str]
    ) -> defaultdict:
        """
        From terminals, filter out terminals that cannot follow the current terminal
@@ -150,9 +152,8 @@ def _compute_following_terminals_map(
        # We iterate through each cur_terminal:
        for cur_terminal in terminals:
            # Add all ignore terminals to the following terminals
-            for next_terminal in terminals:
-                if 'IGNORE' in next_terminal:
-                    following_terminals_map[cur_terminal].add(next_terminal)
+            for next_terminal in ignore_terminals:
+                following_terminals_map[cur_terminal].add(next_terminal)

        # We iterate through each parser_state:
        for _, row in parse_table.states.items():
diff --git a/tests/mask_store/tes b/tests/mask_store/tes
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/test_grammar_go.py b/tests/parser/test_grammar_go.py
similarity index 99%
rename from tests/test_grammar_go.py
rename to tests/parser/test_grammar_go.py
index e2e689fb..cca90f85 100644
--- a/tests/test_grammar_go.py
+++ b/tests/parser/test_grammar_go.py
@@ -1,7 +1,7 @@
 import unittest
 import sys
 import os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
 from syncode.parsers import create_parser
 from syncode.parsers.grammars.grammar import Grammar
 from syncode.parse_result import AcceptSequence, RemainderState
diff --git a/tests/test_grammar_java.py b/tests/parser/test_grammar_java.py
similarity index 99%
rename from tests/test_grammar_java.py
rename to tests/parser/test_grammar_java.py
index d2e0dda8..26c82912 100644
--- a/tests/test_grammar_java.py
+++ b/tests/parser/test_grammar_java.py
@@ -1,7 +1,7 @@
 import unittest
 import sys
 import os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
 from syncode.parsers import create_parser
 from syncode.parsers.grammars.grammar import Grammar
 from syncode.parse_result import AcceptSequence, RemainderState
diff --git a/tests/test_grammar_json.py b/tests/parser/test_grammar_json.py
similarity index 99%
rename from tests/test_grammar_json.py
rename to tests/parser/test_grammar_json.py
index 33dfdc5a..dc970527 100644
--- a/tests/test_grammar_json.py
+++ b/tests/parser/test_grammar_json.py
@@ -1,7 +1,7 @@
 import unittest
 import sys
 import os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
 from syncode.parsers import create_parser
 from syncode.parsers.grammars.grammar import Grammar
 from syncode.parse_result import AcceptSequence, RemainderState
diff --git a/tests/test_grammar_python.py b/tests/parser/test_grammar_python.py
similarity index 99%
rename from tests/test_grammar_python.py
rename to tests/parser/test_grammar_python.py
index 6316dafa..fa33e149 100644
--- a/tests/test_grammar_python.py
+++ b/tests/parser/test_grammar_python.py
@@ -1,6 +1,6 @@
 import unittest
 import sys, os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
 from syncode.parsers import create_parser
 from transformers import (
     LlamaTokenizer,
diff --git a/tests/test_grammar_sql.py b/tests/parser/test_grammar_sql.py
similarity index 98%
rename from tests/test_grammar_sql.py
rename to tests/parser/test_grammar_sql.py
index 9033fe35..0b1bc2aa 100644
--- a/tests/test_grammar_sql.py
+++ b/tests/parser/test_grammar_sql.py
@@ -1,7 +1,7 @@
 import unittest
 import sys
 import os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
 from syncode.parsers import create_parser
 from syncode.parsers.grammars.grammar import Grammar
 from syncode.parse_result import AcceptSequence, RemainderState
@@ -70,3 +70,7 @@ def test_sql_parser7(self):
        r = inc_parser.get_acceptable_next_terminals(partial_code)
        assert r.remainder == "'%Hey%'"
        assert r.remainder_state == RemainderState.MAYBE_COMPLETE
+
+if __name__ == "__main__":
+    unittest.main()
+    
\ No newline at end of file
diff --git a/tests/test_language_model.py b/tests/test_language_model.py
index 9882ec39..9cb4ba3e 100644
--- a/tests/test_language_model.py
+++ b/tests/test_language_model.py
@@ -38,7 +38,8 @@ class TestTokenizer:
    def __init__(self) -> None:
        vocab = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/', '(', ')', ' ', '\n', '\t', '=']
        self.vocab = vocab
-        self.eos_token_id = ''
+        self.eos_token_id = 1
+        self.pad_token_id = 2

    def __call__(self, input_batch: list[str], return_tensors="pt") -> BatchEncoding:
        # This works since we have single character tokens
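The change in syncode/grammar_decoder.py moves the generated token ids to host memory before converting them to a Python list. A minimal standalone sketch of the pattern, assuming made-up tensor values (only the .to('cpu', non_blocking=True).tolist() chain comes from the diff):

    import torch

    # Illustrative ids; these sit on the GPU in practice, CPU here so the snippet runs anywhere.
    input_ids = torch.tensor([[5, 9, 12, 7]], device="cuda" if torch.cuda.is_available() else "cpu")
    start_from = 1  # hypothetical prompt length

    # Transfer to CPU explicitly before building a Python list; calling .tolist()
    # on a CUDA tensor would otherwise trigger the same copy implicitly.
    token_list = input_ids[0, start_from:].to("cpu", non_blocking=True).tolist()
    print(token_list)  # [9, 12, 7]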
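The pad-token handling added in syncode/language_model.py falls back to the EOS token when the tokenizer defines no pad token, which HuggingFace generate() needs for padded, batched decoding. A hedged sketch of that fallback outside SynCode (gpt2 is only an example of a tokenizer that ships without a pad token):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token_id is None:
        # Reuse EOS for padding so generation does not fail on a missing pad token.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print(tokenizer.pad_token_id)  # 50256, gpt2's end-of-text token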
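The new JointFSMState.det_hash in syncode/mask_store/fsm_set.py replaces the built-in hash((terminal, state_id)). Python salts str hashes per interpreter process (PYTHONHASHSEED), so the previously pre-computed values were not reproducible across runs, presumably a problem when a cached mask store is pickled and reloaded as init_mask_store does above; the 31-multiplier rolling hash always yields the same 32-bit value for the same inputs. A standalone sketch (the function body mirrors the diff, the demo values are made up):

    def det_hash(terminal: str, state_id: int) -> int:
        # Deterministic 32-bit polynomial rolling hash over the terminal name,
        # then mixed with the FSM state id.
        h = 0
        for char in terminal:
            h = (h * 31 + ord(char)) & 0xFFFFFFFF
        h = (h * 31 + state_id) & 0xFFFFFFFF
        return h

    print(det_hash("NAME", 3))   # identical in every Python process
    print(hash(("NAME", 3)))     # differs between processes unless PYTHONHASHSEED is fixed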