diff --git a/syncode/mask_store/byte_fsm.py b/syncode/mask_store/byte_fsm.py
index c675470e..c50fa265 100644
--- a/syncode/mask_store/byte_fsm.py
+++ b/syncode/mask_store/byte_fsm.py
@@ -50,7 +50,7 @@ def _build_byte_fsm(self, regex_fsm):
         self.transitions = {}
 
         # Create a mapping from byte values to category numbers
-        self.byte_to_category = {}
+        self.byte_to_category = {}
 
         # Extract the mapping from the regex FSM's alphabet and build our byte-level alphabet
         for char, category in regex_fsm.alphabet.items():
@@ -84,8 +84,8 @@
         # Copy the transitions from the regex FSM to our byte FSM
         for state, category_transitions in regex_fsm.map.items():
             for category, target in category_transitions.items():
-                self.transitions[state][category] = target
-
+                self.transitions[state][category] = target
+
         # Handle multi-byte Unicode characters separately
         # This is needed because a multi-byte character might need special handling
         for char, category in regex_fsm.alphabet.items():
@@ -95,30 +95,42 @@
             char_bytes = char.encode('utf-8')
             if len(char_bytes) <= 1:
                 continue
-
+
+            # Add an explicit dead state for invalid transitions
+            dead_state = f"DEAD"
+            if dead_state not in self.transitions:
+                self.transitions[dead_state] = {}
+
             # For multi-byte characters, we need to add special transitions
             # Make a copy of states to avoid modifying the dictionary during iteration
             states_to_process = list(self.transitions.keys())
             for state in states_to_process:
                 if category in self.transitions[state]:
                     target = self.transitions[state][category]
+                else:
+                    target = dead_state
 
-                    # Create intermediate states for the multi-byte character
-                    current = state
-                    for i, byte in enumerate(char_bytes):
-                        if byte not in self.alphabet:
-                            # Add the byte to the alphabet with a new category
-                            byte_category = f"{byte}_{i}"
-                            self.byte_to_category[byte] = byte_category
+                # Create intermediate states for the multi-byte character
+                current = state
+                for i, byte in enumerate(char_bytes):
+                    if byte not in self.alphabet:
+                        # Add the byte to the alphabet with a new category
+                        byte_category = f"{byte}_{i}"
+                        self.byte_to_category[byte] = byte_category
 
-                        if i < len(char_bytes) - 1:
+                    if i < len(char_bytes) - 1:
+                        if byte_category not in self.transitions[current]:
+                            # Create a new state for this byte
                             next_state = f"{current}_{byte}_{i}_{char}"
                             if next_state not in self.transitions:
                                 self.transitions[next_state] = {}
                             self.transitions[current][byte_category] = next_state
                             current = next_state
                         else:
-                            self.transitions[current][byte_category] = target
+                            # Transition already exists
+                            current = self.transitions[current][byte_category]
+                    else:
+                        self.transitions[current][byte_category] = target
 
     @lru_cache(maxsize=100000)
     def _get_category(self, byte_val: int) -> Any:
diff --git a/syncode/mask_store/fsm_set.py b/syncode/mask_store/fsm_set.py
index a7a40453..40e59779 100644
--- a/syncode/mask_store/fsm_set.py
+++ b/syncode/mask_store/fsm_set.py
@@ -1,6 +1,6 @@
 import time
 import interegular
-from typing import Any, Optional, Tuple, Iterable, Dict
+from typing import Any, Optional, Tuple, Iterable, Dict, Union
 from syncode.mask_store.byte_fsm import ByteFSM
 import logging
 logger = logging.getLogger(__name__)
@@ -21,11 +21,20 @@ def __hash__(self):
         return self._hash
 
     @staticmethod
-    def det_hash(terminal: str, state_id: int):
+    def det_hash(terminal: str, state_id: Union[str, int]):
         h = 0
         for char in terminal:
             h = (h * 31 + ord(char)) & 0xFFFFFFFF
-        h = (h * 31 + state_id) & 0xFFFFFFFF
+
+        # Handle state_id based on its type
+        if isinstance(state_id, str):
+            # If state_id is a string, hash each character
+            for char in state_id:
+                h = (h * 31 + ord(char)) & 0xFFFFFFFF
+        else:
+            # If state_id is an integer, hash it directly
+            h = (h * 31 + state_id) & 0xFFFFFFFF
+
         return h
 
     def __repr__(self):
diff --git a/tests/mask_store/test_byte_fsm.py b/tests/mask_store/test_byte_fsm.py
index ec75917a..42acd499 100644
--- a/tests/mask_store/test_byte_fsm.py
+++ b/tests/mask_store/test_byte_fsm.py
@@ -120,6 +120,9 @@ def test_consume_prefix(self):
                 ("user@example.net", (False, None)),
                 ("user@", (True, b"")),  # Live state
                 ("invalid", (True, b""))  # Live state for [a-z]+
+            ]),
+            ('"[^"”“]+"', [
+                ('\"key”', (False, None)),
             ])
         ]
 
diff --git a/tests/mask_store/test_mask_store_go.py b/tests/mask_store/test_mask_store_go.py
deleted file mode 100644
index a6364804..00000000
--- a/tests/mask_store/test_mask_store_go.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import sys
-import os
-import time
-import unittest
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
-import syncode.common as common
-from syncode.parsers.incremental_parser import ParseResult
-from syncode.parse_result import AcceptSequence, RemainderState
-from syncode.mask_store.mask_store import MaskStore
-from syncode.parsers.grammars.grammar import Grammar
-from tests.test_utils import CustomAssertMixin
-
-
-# Initialize these outside the test class if they're shared across tests
-model = 'Qwen/Qwen2.5-1.5B-Instruct'
-tokenizer = common.load_tokenizer(model)
-mask_store = MaskStore.init_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
-
-class TestDFAMask(unittest.TestCase, CustomAssertMixin):
-    def test_dfa_mask(self):
-        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
-        mask_store.get_accept_mask(r, get_list=True)
-        result_list = mask_store.get_accept_mask(r, get_list=True)
-        for token in [' +', ' +=', ' ++']:
-            self.assertInWithLimit(token, result_list, f"{token} not found in result list")
-
-    def test_dfa_mask2(self):
-        r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
-        result_list = mask_store.get_accept_mask(r, get_list=True)
-        self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")
-
-    def test_dfa_mask3(self):
-        r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
-        result_list = mask_store.get_accept_mask(r, get_list=True)
-        # Uncomment the following line if you want to assert presence of specific tokens
-        self.assertInWithLimit(":=", result_list, ":= not found in result list")
-
-    def test_dfa_mask4(self):
-        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
-        self.assertInWithLimit("\t", mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
-
-    def test_dfa_mask5(self):
-        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
-        self.assertInWithLimit("\t", mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
-
-    def test_dfa_mask6(self):
-        r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
-        self.assertInWithLimit(" {", mask_store.get_accept_mask(r, get_list=True), "Opening brace not found in result list")
diff --git a/tests/mask_store/test_mask_store_misc.py b/tests/mask_store/test_mask_store_misc.py
new file mode 100644
index 00000000..54438a20
--- /dev/null
+++ b/tests/mask_store/test_mask_store_misc.py
@@ -0,0 +1,107 @@
+import sys
+import os
+import time
+import unittest
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+import syncode.common as common
+from syncode.parsers.incremental_parser import ParseResult
+from syncode.parse_result import AcceptSequence, RemainderState
+from syncode.mask_store.mask_store import MaskStore
+from syncode.parsers.grammars.grammar import Grammar
+from tests.test_utils import CustomAssertMixin
+
+
+class TestMaskGo(unittest.TestCase, CustomAssertMixin):
+    def setUp(self):
+        model = 'Qwen/Qwen2.5-1.5B-Instruct'
+        tokenizer = common.load_tokenizer(model)
+        self.mask_store = MaskStore.init_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
+        return super().setUp()
+
+    def test_mask(self):
+        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
+        self.mask_store.get_accept_mask(r, get_list=True)
+        result_list = self.mask_store.get_accept_mask(r, get_list=True)
+        for token in [' +', ' +=', ' ++']:
+            self.assertInWithLimit(token, result_list, f"{token} not found in result list")
+
+    def test_mask2(self):
+        r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
+        result_list = self.mask_store.get_accept_mask(r, get_list=True)
+        self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")
+
+    def test_mask3(self):
+        r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
+        result_list = self.mask_store.get_accept_mask(r, get_list=True)
+        # Uncomment the following line if you want to assert presence of specific tokens
+        self.assertInWithLimit(":=", result_list, ":= not found in result list")
+
+    def test_mask4(self):
+        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
+        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
+
+    def test_mask5(self):
+        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
+        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
+
+    def test_mask6(self):
+        r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
+        self.assertInWithLimit(" {", self.mask_store.get_accept_mask(r, get_list=True), "Opening brace not found in result list")
+
+
+class TestMaskJSON(unittest.TestCase, CustomAssertMixin):
+    def setUp(self):
+        model = 'google/gemma-2-2b-it'
+        tokenizer = common.load_tokenizer(model)
+
+        custom_json_grammar = f"""
+        ?start: start_value
+        ?start_value: object
+                    | array
+
+        ?value: object
+              | array
+              | EMPTY_STRING
+              | NONEMPTY_STRING
+              | SIGNED_NUMBER -> number
+              | "true" -> true
+              | "false" -> false
+              | "null" -> null
+
+        array : "[" [value ("," value)*] "]"
+        object : "{{" [pair ("," pair)*] "}}"
+        pair : NONEMPTY_STRING ":" value
+
+        NONEMPTY_STRING: /\"[^"”“]+\"/
+        EMPTY_STRING: /\"\"/
+
+        DIGIT: "0".."9"
+        HEXDIGIT: "a".."f"|"A".."F"|DIGIT
+        INT: DIGIT+
+        SIGNED_INT: ["+"|"-"] INT
+        DECIMAL: INT "." INT? | "." INT
+
+
+        _EXP: ("e"|"E") SIGNED_INT
+        FLOAT: INT _EXP | DECIMAL _EXP?
+        NUMBER: FLOAT | INT
+        SIGNED_NUMBER: ["+"|"-"] NUMBER
+        WS: /[ \t\f\r\n]/+
+
+        %ignore WS
+        """
+        self.mask_store = MaskStore.init_mask_store(grammar=Grammar(custom_json_grammar), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
+        return super().setUp()
+
+    def test_mask(self):
+        r = ParseResult({AcceptSequence(['NONEMPTY_STRING'])}, b'"key', RemainderState.INCOMPLETE)
+        result_list = self.mask_store.get_accept_mask(r, get_list=True)
+        self.assertInWithLimit('"', result_list, '" not found in result list')
+        self.assertNotIn('”', result_list)
+        self.assertNotIn('“', result_list)
+
+
+if __name__ == '__main__':
+    # Run JSON tests
+    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaskJSON)
+    unittest.TextTestRunner().run(suite)
\ No newline at end of file
diff --git a/tests/mask_store/test_mask_store_python.py b/tests/mask_store/test_mask_store_python.py
index 050cb0b5..b064fc88 100644
--- a/tests/mask_store/test_mask_store_python.py
+++ b/tests/mask_store/test_mask_store_python.py
@@ -19,7 +19,7 @@ class TestDFAMaskLlama(unittest.TestCase, CustomAssertMixin):
     mask_store = MaskStore.init_mask_store(
         grammar=Grammar('python'),
         tokenizer=tokenizer,
-        use_cache=True,
+        use_cache=False,
         indent=True,
         mode="grammar_strict"
     )
diff --git a/tests/parser/test_grammar_json.py b/tests/parser/test_grammar_json.py
index dc970527..28f7f208 100644
--- a/tests/parser/test_grammar_json.py
+++ b/tests/parser/test_grammar_json.py
@@ -25,4 +25,13 @@ def test_json_parser2(self):
         r = inc_parser.get_acceptable_next_terminals(partial_code)
         assert r.remainder == ''
         assert r.remainder_state == RemainderState.COMPLETE
-        
\ No newline at end of file
+
+    def test_json_parser3(self):
+        # Tests when the last incomplete word is unparsed
+        inc_parser.reset()
+        partial_code = '{\n "key'
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert AcceptSequence(['NONEMPTY_STRING']) in r.accept_sequences
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
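
Note on the det_hash change: FSM state ids can now be either interegular's integer states or the synthetic string states introduced in byte_fsm.py ("DEAD" and the intermediate states created for multi-byte UTF-8 characters), so the deterministic hash has to accept both. Below is a minimal standalone sketch of the updated function, lifted out of its class purely for illustration; the example inputs are arbitrary. (Python's built-in hash() is salted per process for strings, which is presumably why a hand-rolled deterministic hash is used.)

    from typing import Union

    def det_hash(terminal: str, state_id: Union[str, int]) -> int:
        # 32-bit rolling hash over the terminal name ...
        h = 0
        for char in terminal:
            h = (h * 31 + ord(char)) & 0xFFFFFFFF
        # ... then fold in the state id: character by character for
        # string states, directly for integer states.
        if isinstance(state_id, str):
            for char in state_id:
                h = (h * 31 + ord(char)) & 0xFFFFFFFF
        else:
            h = (h * 31 + state_id) & 0xFFFFFFFF
        return h

    # Stable across runs for both kinds of state id:
    assert det_hash("NONEMPTY_STRING", 3) == det_hash("NONEMPTY_STRING", 3)
    assert det_hash("NONEMPTY_STRING", "DEAD") == det_hash("NONEMPTY_STRING", "DEAD")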