diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 64a9de88..7e2b0b36 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- - name: Cache DFA mask store
+ - name: Cache mask store
uses: actions/cache@v3
with:
path: /home/runner/work/syncode/syncode/cache/mask_stores/
@@ -39,3 +39,7 @@ jobs:
python3 -m unittest tests.test_language_model
python3 -m unittest tests.test_lr_parser
python3 -m unittest tests.test_syncode
+ python3 -m unittest tests.mask_store.test_byte_fsm
+ python3 -m unittest tests.mask_store.test_fsm_set
+ python3 -m unittest tests.mask_store.test_byte_tokenizer
+ python3 -m unittest tests.mask_store.test_lookup_table
diff --git a/.gitignore b/.gitignore
index 17ea9642..59b4e7a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ syncode/core/__pycache__
.vscode/
tmp*
cache/
-.ipynb_checkpoints/
\ No newline at end of file
+.ipynb_checkpoints/
+*.prof
diff --git a/notebooks/tests/builtin_grammar.ipynb b/notebooks/tests/builtin_grammar.ipynb
index f69d0543..8c3bc18d 100644
--- a/notebooks/tests/builtin_grammar.ipynb
+++ b/notebooks/tests/builtin_grammar.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -17,7 +17,6 @@
"\n",
"device = 'cuda'\n",
"model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
- "# model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)"
diff --git a/notebooks/tests/lexer_ambiguity.ipynb b/notebooks/tests/lexer_ambiguity.ipynb
index 57688863..7be12b72 100644
--- a/notebooks/tests/lexer_ambiguity.ipynb
+++ b/notebooks/tests/lexer_ambiguity.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -36,7 +36,6 @@
"source": [
"from syncode.infer import Syncode\n",
"\n",
- "# Load the unconstrained original model\n",
"model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"\n",
"trying = \"\"\" \n",
diff --git a/notebooks/tests/non_ascii.ipynb b/notebooks/tests/non_ascii.ipynb
new file mode 100644
index 00000000..07d3b6d2
--- /dev/null
+++ b/notebooks/tests/non_ascii.ipynb
@@ -0,0 +1,122 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/shubham/codex/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 10.73it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Creating DFA mask store for PreTrainedTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/PreTrainedTokenizerFast/grammar_strict_4470738745_128000.pkl.\n",
+ "Ignore whitespace tokens is False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 16/16 [00:03<00:00, 4.32it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Time taken to create mask store: 4.161165714263916 seconds\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from syncode.infer import Syncode\n",
+ "\n",
+ "grammar = r\"\"\"\n",
+ " start: \"∀∃∀∃∀\" \n",
+ " \"\"\"\n",
+ "\n",
+ "syn_llm = Syncode(model=\"meta-llama/Llama-3.1-8B-Instruct\", grammar=grammar, new_mask_store=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Syncode augmented LLM output:\n",
+ "∀∃∀∃∀\n",
+ "\n"
+ ]
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+ "\u001b[1;31mClick here for more info. \n",
+ "\u001b[1;31mView Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "p = \"You are an expert in writing print something random.\"\n",
+ " \n",
+ "output = syn_llm.infer(p)[0]\n",
+ "print(f\"Syncode augmented LLM output:\\n{output}\\n\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "codex",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/syncode/__init__.py b/syncode/__init__.py
index 4f86c4e3..cb3d12b5 100644
--- a/syncode/__init__.py
+++ b/syncode/__init__.py
@@ -1,3 +1,6 @@
from syncode.infer import Syncode
from grammar_decoder import SyncodeLogitsProcessor
from parsers.grammars import Grammar
+import common
+
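+# Configure root logging (stdout handler, level taken from the LOG_LEVEL env var) at import time.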
+common.setup_logging()
diff --git a/syncode/common.py b/syncode/common.py
index 2be53afc..baf5370f 100644
--- a/syncode/common.py
+++ b/syncode/common.py
@@ -1,4 +1,6 @@
+import logging
import os
+import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -8,23 +10,6 @@
SYNCODE_CACHE = os.environ['SYNCODE_CACHE'] if 'SYNCODE_CACHE' in os.environ else 'cache/'
HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None
-def get_vocab_from_tokenizer(tokenizer):
- # self.vocab is a list of readable token strings (e.g., ' hello' and '\n')
- # sorted by their token IDs (so self.vocab[0] is the first token, etc).
- vocab = [v for k, v in
- sorted([(t_id, tokenizer.decode([t_id]))
- for _, t_id in tokenizer.get_vocab().items()])]
-
- # HACK: Is there a better way to know if a token has a prefix space?
- if 'Llama' in tokenizer.__class__.__name__:
- for i in range(len(vocab)):
- t = vocab[i]
- if 2*len(t) != len(tokenizer.decode([i, i], add_special_tokens=False)):
- vocab[i] = ' ' + t
- if t == '':
- vocab[i] = ' '
-
- return vocab
def load_model(model_name, device, quantize):
if model_name == 'test':
@@ -53,6 +38,46 @@ def get_output_path(model_name, grammar, dataset, num_samples, mode):
os.makedirs(out_dir, exist_ok=True)
return out_dir,out_path
+# Python logging setup shared by the library and the tests
+def setup_logging(level=None):
+ """
+ Configure the root logger for both application and test usage.
+
+    This function is safe to call multiple times - existing handlers are
+    cleared before a new one is added, so repeated calls do not create duplicates.
+
+ Args:
+ level: Override the logging level. If None, uses the LOG_LEVEL
+ environment variable or defaults to INFO.
+
+ Returns:
+ The root logger
+ """
+ # Determine the logging level
+ if level is None:
+ # Get level from environment or default to INFO
+ level_name = os.environ.get('LOG_LEVEL', 'INFO')
+ level = getattr(logging, level_name.upper(), logging.INFO)
+
+ # Get the root logger
+ root_logger = logging.getLogger()
+
+ # Clear any existing handlers to avoid duplicates
+ for handler in root_logger.handlers[:]:
+ root_logger.removeHandler(handler)
+
+ # Set the logging level
+ root_logger.setLevel(level)
+
+ # Create a stdout handler
+ handler = logging.StreamHandler(sys.stdout)
+ formatter = logging.Formatter('[%(asctime)s-%(name)s] - %(message)s')
+ handler.setFormatter(formatter)
+ root_logger.addHandler(handler)
+
+ return root_logger
+
+
class Logger:
"""
Logger class for logging the output of the model
diff --git a/syncode/dfa_mask_store.py b/syncode/dfa_mask_store.py
deleted file mode 100644
index 08968b05..00000000
--- a/syncode/dfa_mask_store.py
+++ /dev/null
@@ -1,581 +0,0 @@
-from collections import defaultdict
-import copy, os, pickle
-import time
-import interegular
-import torch
-import regex
-import syncode.common as common
-from tqdm import tqdm
-from syncode.parsers import create_base_parser
-from syncode.larkm.lexer import TerminalDef
-from syncode.parse_result import IndentationConstraint, RemainderState, ParseResult
-from syncode.parsers.grammars.grammar import Grammar
-from typing import Any, Optional, Tuple, Iterable, Dict
-
-
-class DFAState:
- """
- Represents the state of the DFA. It consists of the current terminal and the DFA state for the current terminal.
- """
- def __init__(self, terminal, state_id):
- self.terminal = terminal
- self.state_id = state_id
-
- def __eq__(self, other: 'DFAState'):
- return self.terminal == other.terminal and self.state_id == other.state_id
-
- def __hash__(self):
- return hash((self.terminal, self.state_id))
-
- def __repr__(self):
- return f"({self.terminal}, {self.state_id})"
-
-
-class DFAs:
- """
- Stores the DFAs for each terminal and provides the method to consume the input string and get the DFA state.
- """
- def __init__(self, terminals: Iterable[TerminalDef], simplifications: Dict[str, str] = {}):
- self._terminals_to_dfa: Dict[str, interegular.FSM] = {}
- self.anything_else = interegular.fsm.anything_else # This is special character used for the
- self._simplifications: Dict[str, str] = simplifications
-
- for terminal in terminals:
- if terminal.name in simplifications:
- terminal_regex = simplifications[terminal.name]
- else:
- terminal_regex = terminal.pattern.to_regexp()
- # We store the DFA for each terminal (with name as the key) in the dictionary
- self._terminals_to_dfa[terminal.name] = interegular.parse_pattern(terminal_regex).to_fsm()
-
- def states(self):
- return [DFAState(terminal_name, state_id) for terminal_name, dfa in self._terminals_to_dfa.items() for state_id in dfa.states]
-
- def initial(self, terminal: str):
- return DFAState(terminal, self._terminals_to_dfa[terminal].initial)
-
- def compute_dfa_states(self, input_str: str) -> Iterable[DFAState]:
- """
- consume input_str and get the list of pairs of (terminal, dfa_state). This denotes our current DFA state.
-
- NOTE: The returned DFA state is always a live state
- """
- dfa_states = []
- for (terminal, dfa) in self._terminals_to_dfa.items():
- state_id = self._consume_input(dfa, input_str)
- if state_id is not None:
- dfa_states.append(DFAState(terminal, state_id))
- return dfa_states
-
- def _consume_input(self, dfa: interegular.FSM, input_str: str) -> Optional[int]:
- """
- Conumses the input string and returns the final state if it is live, otherwise returns None
- """
- state_id = dfa.initial
- for i, symbol in enumerate(input_str):
- if not symbol in dfa.alphabet:
- if not self.anything_else in dfa.alphabet:
- return None
- symbol = self.anything_else
-
- # Missing transition = transition to dead state
- if not (state_id in dfa.map and dfa.alphabet[symbol] in dfa.map[state_id]):
- return None
- state_id = dfa.map[state_id][dfa.alphabet[symbol]]
- return state_id
-
- def is_final(self, dfa_state: DFAState) -> bool:
- """
- Returns True if the dfa state is a final state
- """
- return dfa_state.state_id in self._terminals_to_dfa[dfa_state.terminal].finals
-
- def consume_prefix(self, dfa_state: DFAState, input_str: str) -> Tuple[bool, Optional[str]]:
- """
- Consume longest prefix of input_str that is accepted by dfa and return the remainder.
- If we reach a dead state, return (False, None).
- If the consumption ends at any live state that is not an accept state, return (True, '').
- If we reach a final state, return (True, remainder).
- """
- cur_state_id = dfa_state.state_id
- dfa: interegular.FSM = self._terminals_to_dfa[dfa_state.terminal]
-
- longest_accept_index = -1
-
- if cur_state_id in dfa.finals:
- longest_accept_index = 0
-
- for i, symbol in enumerate(input_str):
- if not symbol in dfa.alphabet:
- if not self.anything_else in dfa.alphabet:
- cur_state_id = None
- break
- symbol = self.anything_else
-
- # Missing transition = transition to dead state
- if not (cur_state_id in dfa.map and dfa.alphabet[symbol] in dfa.map[cur_state_id]):
- cur_state_id = None
- break
-
- cur_state_id = dfa.map[cur_state_id][dfa.alphabet[symbol]]
-
- if cur_state_id in dfa.finals:
- longest_accept_index = i+1
-
- if longest_accept_index != -1: # reached accept state at some point
- return (True, input_str[longest_accept_index:])
- elif cur_state_id != None and dfa.islive(cur_state_id): # if state is a live state
- return (True, '')
-
- # if we never reach a final state and reach a dead state at some point
- return (False, None)
-
-class LookupTable:
- """
- Stores the overapproximate tokens
- """
- def __init__(self, vocab: Iterable[str], special_token_ids: Iterable[int], indentation=False, mode='grammar_mask'):
- self._dfa_state_and_next_terminal_to_tokens: defaultdict = defaultdict(list)
- self._overapprox_lookup: Dict[DFAState, Any] = {}
- self._exact_lookup: dict = {}
- self._mode = mode
- self._vocab: Iterable[str] = vocab
- self.indentation = indentation
-
- self._default_mask = self._get_default_mask(special_token_ids)
- if indentation:
- self._whitespace_tokens_map: defaultdict = defaultdict(list)
- self._indentation_to_tokens_map: defaultdict = defaultdict(list)
- self._create_indentation_to_tokens_map()
-
- def incomplete_case_lookup(self, dfa_state: DFAState) -> Any:
- assert isinstance(dfa_state, DFAState)
- if self._mode == 'grammar_mask':
- return self._overapprox_lookup[dfa_state]
- elif self._mode == 'grammar_strict':
- if dfa_state in self._exact_lookup:
- return self._exact_lookup[dfa_state]
- else:
- return self._overapprox_lookup[dfa_state]
- raise ValueError(f"Invalid mode: {self._mode}")
-
- def store_overapprox_lookup(self, dfa_state: DFAState, mask: torch.Tensor):
- assert isinstance(dfa_state, DFAState)
- if dfa_state not in self._overapprox_lookup:
- self._overapprox_lookup[dfa_state] = self._get_default_mask()
- self._overapprox_lookup[dfa_state] |= mask
-
- def complete_case_lookup(self, dfa_state: DFAState) -> Any:
- assert isinstance(dfa_state, DFAState)
- return self._exact_lookup[dfa_state]
-
- def add_exact_lookup(self, dfa_state: DFAState, token):
- assert isinstance(dfa_state, DFAState)
- if dfa_state not in self._exact_lookup:
- self._exact_lookup[dfa_state] = []
- self._exact_lookup[dfa_state].append(token)
-
- def dfa_state_and_next_terminal_to_tokens(self, dfa_state: DFAState, next_terminal) -> torch.Tensor:
- assert isinstance(dfa_state, DFAState)
- return self._dfa_state_and_next_terminal_to_tokens[(dfa_state, next_terminal)]
-
- def dfa_state_and_next_terminal_to_tokens_store(self, dfa_state: DFAState, next_terminal, mask: torch.Tensor):
- assert isinstance(dfa_state, DFAState)
- self._dfa_state_and_next_terminal_to_tokens[(dfa_state, next_terminal)] = mask
-
- def dfa_state_and_next_terminal_to_tokens_add(self, dfa_state: DFAState, next_terminal, token):
- assert isinstance(dfa_state, DFAState)
- self._dfa_state_and_next_terminal_to_tokens[(dfa_state, next_terminal)].append(token)
-
- def _list_to_mask(self, tokens_idx_list) -> torch.Tensor:
- indices = torch.tensor(tokens_idx_list)
- tokens_mask = self._get_default_mask()
- tokens_mask[indices] = 1
- return tokens_mask
-
- def convert_lookups_from_list_to_mask(self):
- """
- Converts the lookups from list of tokens to boolean tensor mask
- """
- for key, val in self._dfa_state_and_next_terminal_to_tokens.items():
- m = self._list_to_mask(val)
- self._dfa_state_and_next_terminal_to_tokens[key] = m
- (dfa_state, _) = key
- self.store_overapprox_lookup(dfa_state, m)
-
- for key, val in self._exact_lookup.items():
- self._exact_lookup[key] = self._list_to_mask(val)
-
- # TODO: move this logic to the lookup table
- if self.indentation:
- for key, val in self._whitespace_tokens_map.items():
- self._whitespace_tokens_map[key] = self._list_to_mask(val)
- for key, val in self._indentation_to_tokens_map.items():
- self._indentation_to_tokens_map[key] = self._list_to_mask(val)
-
- def _get_default_mask(self, special_token_ids=None) -> torch.Tensor:
- if special_token_ids is not None:
- mask = torch.zeros(len(self._vocab), dtype=torch.bool)
- else:
- mask = copy.deepcopy(self._default_mask)
- return mask
-
- def _create_indentation_to_tokens_map(self):
- """
- We create a mapping from indentation to overapproximate tokens. This is useful for the indentation rule.
- """
- for token_idx, token in enumerate(self._vocab):
- full_match, indent = self._get_indent_type(token)
- if full_match:
- self._whitespace_tokens_map[indent].append(token_idx)
- else:
- self._indentation_to_tokens_map[indent].append(token_idx)
-
- def _get_indent_type(self, s: str) -> Tuple[bool, int]:
- m = regex.match(r'[\t ]+', s, partial=True)
- full_match = False
- if m != None:
- start, end = m.start(), m.end()
- if end == len(s):
- full_match = True
- return full_match, s[start: end].count(' ') + 4*s[start: end].count('\t')
- return False, 0
-
- def get_indentation_tokens(self, indent_constraint: IndentationConstraint, get_list=False):
- """
- Returns the tokens mask for the indentation constraint
- """
- out_mask = self._get_default_mask()
- if indent_constraint.greater_than_indent_val is not None:
- for indent in self._indentation_to_tokens_map.keys():
- if indent > indent_constraint.greater_than_indent_val:
- out_mask |= self._indentation_to_tokens_map[indent]
-
- for indent in self._whitespace_tokens_map.keys(): # We are ok with any num of whitespace
- out_mask |= self._whitespace_tokens_map[indent]
-
- elif indent_constraint.accept_indents is not None:
- for indent in indent_constraint.accept_indents:
- if indent in self._indentation_to_tokens_map:
- out_mask |= self._indentation_to_tokens_map[indent]
-
- max_acceptable_indent = max(indent_constraint.accept_indents)
- for indent in self._whitespace_tokens_map.keys(): # We are ok with num whitespace <= largest accepted indent
- if indent <= max_acceptable_indent:
- out_mask |= self._whitespace_tokens_map[indent]
-
- if get_list: # This is useful for testing
- return self._get_tokens_list(out_mask)
- return out_mask
-
- def _get_tokens_list(self, token_mask) -> Iterable[str]:
- return [self._vocab[idx.item()] for idx in torch.where(token_mask == True)[0]]
-
-
-class DFAMaskStore:
- """
- We build an DFA that consists of DFAs for each terminal. We simulate the DFA by consuming the input string for each terminal DFA.
-
- There are 3 possible cases for the remainder string:
-
- 1. COMPLETE: In this case, the last token is complete (and cannot be further extended) and we know the type of next terminal. Thus, we need to compute all tokens such that consuming the token leads to a live state for the terminal DFA or it reaches a final state for the terminal DFA.
-
- 2. INCOMPLETE: In this case, the remainder is incomplete and does not match any terminal regex. Thus, we need to compute all tokens such that consuming the token leads to a live state for the current terminal DFA or again it reaches a final state for the current terminal DFA at some point.
-
- 3. MAYBE_COMPLETE: In this case the remainder matches a type of terminal. It may happen that we add to the same matched part of the remainder. In that case, there are two possibilities. i) the matched terminal type does not change and thus we can use the next terminal set computed by assuming that. ii) the matched terminal type changes and then we do not know the next terminal set. Thus, we need to compute all tokens such that consuming the token leads to a live state for the current terminal DFA or again it reaches a final state for the current terminal DFA at some point.
- """
- def __init__(self,
- terminals: Iterable[TerminalDef],
- vocab: Iterable[str],
- simplifications: dict={},
- special_token_ids: Iterable=[],
- indentation: bool=True,
- mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
- ignore_terminals: Iterable[str]=[],
- parse_table=None
- ):
- self._vocab = vocab
- self.special_token_ids = special_token_ids
- self._mode = mode
- self._dfas = DFAs(terminals, simplifications)
-
- # Check if whitespace is in ignore terminals
- self._ignore_whitespace = self.set_ignore_whitespace(terminals, ignore_terminals)
- print(f"Ignore whitespace tokens is {self._ignore_whitespace}", flush=True)
-
- # Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
- self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
- terminal_names = [terminal.name for terminal in terminals]
-
- followings_terminas_map = None
- if parse_table is not None:
- followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
- self._store_overapproximate_tokens(terminal_names, vocab, followings_terminas_map)
-
- self.indentation = indentation
-
- # NOTE: This should be called at the end of the constructor
- self._lookup_table.convert_lookups_from_list_to_mask()
-
- def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_terminals: Iterable[str]) -> bool:
- ignore_whitespace = False
- for ig_name in ignore_terminals:
- for terminal in terminals:
- if terminal.name == ig_name:
- if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
- ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
- return ignore_whitespace
-
- @staticmethod
- def load_dfa_mask_store(
- grammar: Grammar,
- tokenizer,
- use_cache=True,
- logger=None,
- mode='grammar_strict',
- parse_table=None
- ):
- '''
- Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
- '''
- tokenizer_name = type(tokenizer).__name__
- dfa_dir = common.SYNCODE_CACHE + 'mask_stores/' + tokenizer_name + '/'
- grammar_hash = grammar.hash()
-
- # TODO: Hashing using the tokenizer vocab size, this may be problmatic if we have two fine-tuned models with same tokenizer, same vocab size but different vocab
- dfa_path = f'{dfa_dir}{mode}_{grammar_hash}_{tokenizer.vocab_size}.pkl'
-
- if use_cache and os.path.exists(dfa_path):
- try:
- mask_store = pickle.load(open(dfa_path, 'rb'))
- return mask_store
- except: # If we cannot load the file, we will create the dfa from scratch
- pass
-
- print(f"Creating DFA mask store for {tokenizer_name} and {grammar}, may take more than 10 minutes. Caching at {os.path.abspath(dfa_path)}.", flush=True)
- vocab = common.get_vocab_from_tokenizer(tokenizer)
-
- base_parser = create_base_parser(grammar)
-
- simplifications = grammar.simplifications()
- os.makedirs(dfa_dir, exist_ok=True)
-
- start_time = time.time()
- mask_store = DFAMaskStore(
- base_parser.terminals,
- vocab,
- simplifications=simplifications,
- special_token_ids=[tokenizer.eos_token_id],
- mode=mode,
- ignore_terminals=base_parser.ignore_tokens,
- parse_table=parse_table
- )
- print(f"Time taken to create DFA mask store: {time.time() - start_time} seconds", flush=True)
-
- pickle.dump(mask_store, open(dfa_path, 'wb'))
- return mask_store
-
- def _get_default_mask(self) -> torch.Tensor:
- mask = torch.zeros(len(self._vocab), dtype=torch.bool)
- return mask
-
- def _compute_following_terminals_map(self, terminals: Iterable[str], parse_table) -> defaultdict:
- """
- From terminals, filter out terminals that cannot follow the current terminal
- according to the grammar.
-
- If in the parsing table Action[cur_terminal, parser_state] = 'shift, new_parser_state' then next terminals
- are the terminals that are legal in new_parser_state.
- """
- following_terminals_map = defaultdict(set)
- terminals_set = set(terminals)
-
- # We iterate through each cur_terminal:
- for cur_terminal in terminals:
- # We iterate through each parser_state:
- for _, row in parse_table.states.items():
- if cur_terminal in row:
- action = row[cur_terminal]
- # -> If we see a shift action to new_parser_state
- if str(action[0]) == 'Shift':
- new_parser_state = action[1]
- for next_terminal in parse_table.states[new_parser_state]:
- # Lark parse_table stores non-terminals and terminals together
- if next_terminal in terminals_set:
- # -> -> we add the terminals that are legal in new_parser_state
- following_terminals_map[cur_terminal].add(next_terminal)
-
- return following_terminals_map
-
-
- def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str], followings_terminas_map: dict=None):
- """
- Stores the overapproximate tokens for each dfa state and next terminals
- """
- all_dfa_states = self._dfas.states()
- pbar = tqdm(total=len(all_dfa_states))
-
- for dfa_state in all_dfa_states:
- for token_idx, token in enumerate(vocab):
- is_special_token = token_idx in self.special_token_ids
-
- if is_special_token:
- if self._dfas.is_final(dfa_state):
- self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(
- dfa_state, '$END', token_idx)
- else:
- if followings_terminas_map is not None and dfa_state.terminal in followings_terminas_map:
- following_terminals = followings_terminas_map[dfa_state.terminal]
- else:
- following_terminals = terminals
-
- self._process_regular_tokens(following_terminals, dfa_state, token_idx, token)
-
- pbar.update(1)
-
- def _process_regular_tokens(self, terminals, dfa_state, token_idx, token):
- remainder = token.replace('\t', ' ')
-
- is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
- if is_valid:
- if remainder == '':
- # We reached a live state for the current terminal, thus we add the token in all overapproximate sets of next terminals
- for next_terminal in terminals:
- self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(dfa_state, next_terminal, token_idx)
- else:
- remainder = self.remove_left_whitespace(dfa_state, remainder)
-
- # We reached the final state while consuming the token, thus we conusme the remainder with all next terminals
- for next_terminal in terminals:
- initial_state = self._dfas.initial(next_terminal)
- is_valid, remainder_new = self._dfas.consume_prefix(initial_state, remainder)
- if self._mode == 'grammar_mask':
- if is_valid: # In the non-strict mode we overapproximate
- # We reached a live state for the next terminal, thus we add the
- # token in the overapproximate sets of next terminals
- self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(dfa_state, next_terminal, token_idx)
- elif self._mode == 'grammar_strict':
- if is_valid and remainder_new == '':
- # We reached a live state for the next terminal and the remainder
- # is empty, thus we add the token in the exact set of next terminals
- self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(dfa_state, next_terminal, token_idx)
- else:
- raise ValueError(f"Invalid mode: {self._mode}")
-
- # For COMPLETE case:
- remainder = token
- remainder = self.remove_left_whitespace(dfa_state, remainder)
-
- is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
- if is_valid and remainder == '':
- self._lookup_table.add_exact_lookup(dfa_state, token_idx)
-
- def remove_left_whitespace(self, dfa_state, remainder):
- """
- Ignore left space at the start of the terminal. This only helps the efficiency
- e.g. without this say if the model wants to generate ' def' then syncode will force it to generate ' ' and 'def' seperately
- """
- if self._dfas.initial(dfa_state.terminal) == dfa_state and remainder.startswith(' ') and self._ignore_whitespace:
-
- remainder = remainder[1:]
- return remainder
-
-
- def _lookup_next_tokens_for_dfa_state(self, dfa_state: DFAState, next_terminal) -> torch.Tensor:
- tokens = self._lookup_table.dfa_state_and_next_terminal_to_tokens(dfa_state, next_terminal)
- if tokens == []:
- # TODO: This is a hack. Do something better
- return self._get_default_mask()
- return tokens
-
- def _lookup_next_tokens(self, dfa_states: Iterable[DFAState], r: ParseResult) -> torch.Tensor:
- overapprox_token_ids = self._get_default_mask()
-
- # Case when the final string may be incomplete
- for dfa_state in dfa_states:
- for accept_sequence in r.accept_sequences:
- if dfa_state.terminal != accept_sequence[0]:
- continue
-
- if r.remainder_state == RemainderState.COMPLETE:
- assert len(accept_sequence) == 1 # Since we only store length 1 accept sequences in this case
- overapprox_token_ids |= self._lookup_table.complete_case_lookup(dfa_state)
-
- if r.remainder_state == RemainderState.INCOMPLETE:
- overapprox_token_ids |= self._lookup_table.incomplete_case_lookup(dfa_state)
-
- if r.remainder_state == RemainderState.MAYBE_COMPLETE:
- if len(accept_sequence) == 1:
- overapprox_token_ids |= self._lookup_table.complete_case_lookup(dfa_state)
- elif len(accept_sequence) == 2:
- overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(dfa_state, accept_sequence[1])
- elif len(accept_sequence) == 3:
- # If the DFA state is a final state we can jump to the start of next terminal
- if self._dfas.is_final(dfa_state):
- ignore_init_state = self._dfas.initial(accept_sequence[1])
- overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(ignore_init_state, accept_sequence[2])
- else:
- raise ValueError(f"Invalid accept sequence: {accept_sequence}")
- return overapprox_token_ids
-
- def get_dfa_states(self, r: ParseResult) -> Iterable[DFAState]:
- """
- Returns the DFA state for the current partial code
- """
- cur_incomplete_string = r.remainder
- if cur_incomplete_string is None:
- return []
-
- cur_dfa_states = self._dfas.compute_dfa_states(cur_incomplete_string)
- return cur_dfa_states
-
- def get_accept_mask(
- self,
- r: ParseResult,
- get_list=False,
- logger: common.Logger=common.EmptyLogger()
- ) -> torch.Tensor:
- """
- Returns the mask for the acceptable tokens for the current partial code
-
- Args:
- r (ParseResult): The parse result
- get_list (bool, optional): If True, returns the list of tokens instead of the mask. Defaults to False.
- logger (common.Logger, optional): The logger. Defaults to common.EmptyLogger().
- """
- cur_incomplete_string = r.remainder
- if cur_incomplete_string is None:
- return torch.ones(len(self._vocab), dtype=torch.bool)
-
- cur_dfa_states = self._dfas.compute_dfa_states(cur_incomplete_string)
- accept_token_mask = self._lookup_next_tokens(cur_dfa_states, r)
-
- if self.indentation and r.next_ac_indents is not None:
- indent_ac_token = self._lookup_table.get_indentation_tokens(r.next_ac_indents)
- accept_token_mask &= indent_ac_token
-
- if get_list: # This is useful for testing
- return self._get_tokens_list(accept_token_mask)
- return accept_token_mask
-
- def is_valid_prefix(self, r: ParseResult) -> bool:
- """
- Check if r.remainder is a valid prefix for accept sequences in r
- """
- cur_incomplete_string = r.remainder
-
- cur_dfa_states = self._dfas.compute_dfa_states(cur_incomplete_string)
- for accept_sequence in r.accept_sequences:
- for dfa_state in cur_dfa_states:
- if dfa_state.terminal == accept_sequence[0]:
- return True
- return False
-
- def _list_to_mask(self, tokens_idx_list) -> torch.Tensor:
- indices = torch.tensor(tokens_idx_list)
- tokens_mask = self._get_default_mask()
- tokens_mask[indices] = 1
- return tokens_mask
-
- def _get_tokens_list(self, token_mask) -> Iterable[str]:
- return [self._vocab[idx.item()] for idx in torch.where(token_mask == True)[0]]
diff --git a/syncode/grammar_decoder.py b/syncode/grammar_decoder.py
index 3a55b011..1be44090 100644
--- a/syncode/grammar_decoder.py
+++ b/syncode/grammar_decoder.py
@@ -1,11 +1,15 @@
import torch
import syncode.common as common
from transformers import LogitsProcessor, PreTrainedTokenizer
+from syncode.mask_store.byte_tokenizer import ByteTokenizer
from syncode.parse_result import AcceptSequence, RemainderState
from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
from syncode.parsers import create_parser, create_base_parser
-from syncode.dfa_mask_store import DFAMaskStore
+from syncode.mask_store.mask_store import MaskStore
from syncode.parsers.grammars import Grammar
+import logging
+logger = logging.getLogger(__name__)
+
# Set to True for debugging
DEBUG = True
@@ -17,15 +21,16 @@ class SyncodeLogitsProcessor(LogitsProcessor):
Args:
grammar (str): The grammar to use for parsing e.g. "python".
tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
- logger (common.Logger): The logger to use for logging.
use_cache (bool, optional): Whether to use the cache. Defaults to True.
parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
+ num_samples (int, optional): The number of sequences to generate. Defaults to 1.
dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
+        parser (str, optional): The parser to use, 'lalr' or 'lr'. Defaults to 'lalr'.
+        mode (str, optional): Masking mode, 'grammar_mask' or 'grammar_strict'. Defaults to 'grammar_mask'.
"""
def __init__(self,
grammar: Grammar,
tokenizer: PreTrainedTokenizer,
- logger: common.Logger=common.EmptyLogger(),
use_cache=True,
parse_output_only=True,
num_samples=1,
@@ -34,8 +39,9 @@ def __init__(self,
mode='grammar_mask'):
self.tokenizer = tokenizer
+ self.byte_tokenizer = ByteTokenizer(tokenizer)
+
self.grammar = grammar
- self.logger = logger
self.dev_mode = dev_mode
self.batch_size = num_samples
self.parse_failed = False
@@ -52,23 +58,17 @@ def __init__(self,
self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
# Create parser
- self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
+ self.inc_parser: IncrementalParser = create_parser(self.grammar, parser=parser, ignore_whitespace=self._ignore_whitespace)
# Load dfa mask store
- self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
+ self.dfa_mask_store = MaskStore.init_mask_store(
grammar=self.grammar,
tokenizer=self.tokenizer,
use_cache=use_cache,
- logger=self.logger,
mode=mode,
- parse_table=self.inc_parser.base_parser.parser.parser._parse_table,
)
-
+
- def _log_current_status(self, partial_code, r: ParseResult):
- self.logger.log_code('Partial code', partial_code)
- self.logger.log(repr(r))
-
def _get_ignore_whitespace(self, grammar):
"""
Check if the grammar allows whitespace tokens to be ignored.
@@ -109,60 +109,53 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
bool: True if the next token is valid, False otherwise.
"""
assert len(input_ids) == 1, "Only one input is supported for now."
- input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
- partial_code = self._get_partial_codes(input_ids)[0]
+ self._set_start_from(input_ids)
- try:
- r = self.inc_parser.get_acceptable_next_terminals(partial_code)
- except Exception as e:
- self.logger.log(f"Exception while parsing:\n {e}")
- return False
+ input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
+ partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+
+ res, skip = self._parse_partial_code(
+ idx=0,
+ partial_code=partial_code,
+ remainder_bytes=remainder_bytes,
+ accepted_generation=False
+ )
+
+ if skip: return False
if input_ids[0, -1] == self.tokenizer.eos_token_id:
# Do not allow the model to generate EOS token until $END in the grammar is reached
- return AcceptSequence(['$END']) in r.accept_sequences
+ return AcceptSequence(['$END']) in res.accept_sequences
- if r.remainder_state == RemainderState.COMPLETE or r.remainder_state == RemainderState.MAYBE_COMPLETE:
+ if res.remainder_state == RemainderState.COMPLETE or res.remainder_state == RemainderState.MAYBE_COMPLETE:
is_valid = True
# Check if the remainder is a valid prefix for the last terminal
- is_valid = self.dfa_mask_store.is_valid_prefix(r)
+ is_valid = self.dfa_mask_store.is_valid_prefix(res)
if is_valid:
- self.update_valid_state(partial_code, 0, r)
+ self.update_valid_state(partial_code, 0, res)
return is_valid
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
- # start_from is used for choosing where the parsing should start
+ def _set_start_from(self, input_ids):
if self.start_from is None:
if self.parse_output_only:
self.start_from = input_ids.size(1)
else:
self.start_from = 0
+
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
partial_codes = self._get_partial_codes(input_ids)
- for idx, partial_code in enumerate(partial_codes):
+ for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
## Parsing
- try: # returns the accept sequences that are currently accepted.
- r = self.inc_parser.get_acceptable_next_terminals(partial_code)
- self.update_valid_state(partial_code, idx, r)
- except Exception as e:
- if self.dev_mode == True:
- raise e
- elif self.parse_failed == False:
- self.parse_failed = True
- print("-"*50)
- print(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
- print("-"*50)
- continue # Skip altering the scores for this batch
-
- accept_mask = self.dfa_mask_store.get_accept_mask(r, logger=self.logger)
+ res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+ if skip: continue
- if DEBUG:
- self._log_current_status(partial_code, r)
- greedy_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
+ accept_mask = self.dfa_mask_store.get_accept_mask(res)
if torch.sum(accept_mask) != 0: # If there are acceptable tokens for the current partial code
if len(scores[idx]) > len(accept_mask):
@@ -172,17 +165,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
accept_mask = accept_mask[: len(scores[idx])]
scores[idx] = scores[idx].masked_fill(~accept_mask.to(scores.device), -float("inf"))
else: # Otherwise, report the error and mask no tokens
- self.logger.log('No acceptable tokens for the current partial code!')
- self._log_current_status(partial_code, r)
-
- # For debugging - remove later
- if DEBUG: self._debug_greedy(scores, idx, partial_code, r, greedy_token)
+ logger.debug('No acceptable tokens for the current partial code!')
+ logger.debug(repr(res))
return scores
- def _get_partial_codes(self, input_ids: torch.LongTensor):
- partial_codes = self.tokenizer.batch_decode(input_ids[:, self.start_from:], skip_special_tokens=True)
- return partial_codes
+ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+ """
+        Parse the partial code and return a (result, skip) pair; skip is True if parsing
+        raised an exception and the caller should not use the result for this sequence.
+ """
+ skip = False
+ res = None
+
+ try:
+ res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+
+ if len(remainder_bytes) > 0:
+ res.remainder_state = RemainderState.INCOMPLETE
+ res.remainder = res.remainder.encode('utf-8') + remainder_bytes
+ else:
+ res.remainder = res.remainder.encode('utf-8')
+
+ self.update_valid_state(partial_code, idx, res)
+ except Exception as e:
+ if self.dev_mode == True:
+ raise e
+ elif self.parse_failed == False and accepted_generation:
+ self.parse_failed = True
+ print("-"*50)
+ print(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+ print("-"*50)
+ skip = True
+ return res, skip
+
+
+    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[tuple[str, bytes]]:
+ """
+        Decode each sequence in input_ids into a (partial_code, remainder_bytes) pair, where
+        remainder_bytes holds any trailing bytes that do not yet form a valid UTF-8 string.
+ """
+ output = []
+ for idx in range(len(input_ids)):
+ if self.parse_output_only:
+ partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True))
+ else:
+ partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True))
+ output.append((partial_code, remainder_bytes))
+ return output
+
def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
"""
@@ -200,25 +229,58 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
- def _debug_greedy(self, scores, idx, partial_code, r, greedy_token):
- greedy_grammar_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
- if greedy_token != greedy_grammar_token:
- self._log_greedy_difference(greedy_grammar_token, partial_code, r, greedy_token)
-
- def _log_greedy_difference(self, greedy_grammar_token, partial_code, r, greedy_token):
- self.logger.log_check(f"Greedy token and greedy grammar-based token do not match!")
- self.logger.log(f"Greedy token: {repr(greedy_token)}")
- self.logger.log(f"Greedy grammar-based token: {repr(greedy_grammar_token)}")
- self._log_current_status(partial_code, r)
- def print_debug(self):
- print('-'*50)
- print('Parsed terminals:')
-
- name_to_pattern = {}
- for term in self.inc_parser.base_parser.terminals:
- name_to_pattern[term.name] = term.pattern
-
- for token in self.inc_parser.parsed_lexer_tokens:
- print(f"(type: {name_to_pattern[token.type]} | value: '{token.value}')")
- print('-'*50)
+ @staticmethod
+ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
+ """
+ Convert a byte sequence into a UTF-8 string plus a remainder that is not valid UTF-8.
+
+ This function finds the longest valid UTF-8 prefix of the input byte sequence,
+ converts it to a Python string, and returns any remaining bytes that couldn't be decoded.
+
+ Args:
+ byte_sequence: The input byte sequence
+
+ Returns:
+ A tuple (string, remainder) where:
+ - string is the longest valid UTF-8 prefix of the input as a Python string
+ - remainder is the rest of the bytes that could not be decoded as UTF-8
+
+ Examples:
+        >>> _bytes_to_string(b'Hello, world!')
+        ('Hello, world!', b'')
+        >>> _bytes_to_string(b'Hello, \xe2\x82\xac!') # Euro symbol (€) followed by !
+        ('Hello, €!', b'')
+        >>> _bytes_to_string(b'Hello, \xe2\x82!') # Incomplete Euro symbol
+        ('Hello, ', b'\xe2\x82!')
+        >>> _bytes_to_string(b'\xff\xfe') # Invalid UTF-8
+ ('', b'\xff\xfe')
+ """
+ if not isinstance(byte_sequence, bytes):
+ raise TypeError("Input must be a bytes object")
+
+ if not byte_sequence:
+ return "", b""
+
+ # Try to decode the entire sequence first - common case optimization
+ try:
+ return byte_sequence.decode('utf-8'), b""
+ except UnicodeDecodeError:
+ pass
+
+ # Find the longest valid prefix by incrementally checking each additional byte
+ valid_end = 0
+
+ while valid_end < len(byte_sequence):
+ try:
+ # Try to decode up to the current position
+ byte_sequence[:valid_end+1].decode('utf-8')
+ valid_end += 1
+ except UnicodeDecodeError:
+ break
+
+ # Return the valid prefix and remainder
+ if valid_end > 0:
+ return byte_sequence[:valid_end].decode('utf-8'), byte_sequence[valid_end:]
+ else:
+ return "", byte_sequence
diff --git a/syncode/infer.py b/syncode/infer.py
index e970e759..f72da8cd 100644
--- a/syncode/infer.py
+++ b/syncode/infer.py
@@ -17,7 +17,7 @@
def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs):
- syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs)
+ syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, dev_mode=dev_mode, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs)
if dataset == "input":
syncode.infer(debug=debug)
@@ -56,8 +56,6 @@ class Syncode:
new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
dev_mode (bool, optional): Development mode. Defaults to False.
-
- log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
opp (bool, optional): Whether to use opportunistic generation. Defaults to True.
"""
@@ -70,7 +68,6 @@ def __init__(
grammar: Optional[str] = None,
parse_output_only: bool = True,
dev_mode: bool = False,
- log_level: int = 1,
new_mask_store: bool = False,
parser: Literal["lr", "lalr"] = "lalr",
seed: Optional[int] = None,
@@ -91,7 +88,6 @@ def __init__(
self.num_samples = kwargs.get('num_return_sequences', 1)
self.new_mask_store = new_mask_store
self.parser = parser
- self.log_level = log_level
# Set seed
if seed is not None:
diff --git a/syncode/language_model.py b/syncode/language_model.py
index 1d55fe16..16fe99fc 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -58,7 +58,6 @@ def __init__(
self.mode = mode
self.grammar = grammar
- self.vocab = common.get_vocab_from_tokenizer(self.tokenizer)
self.gen_args = kwargs
self.opp = opp
diff --git a/syncode/mask_store/__init__.py b/syncode/mask_store/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/syncode/mask_store/byte_fsm.py b/syncode/mask_store/byte_fsm.py
new file mode 100644
index 00000000..c675470e
--- /dev/null
+++ b/syncode/mask_store/byte_fsm.py
@@ -0,0 +1,356 @@
+from functools import lru_cache
+import interegular
+from typing import Tuple, Optional, Any, Union
+
+class ByteFSM:
+ """
+ A finite state machine that operates on bytes rather than characters.
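+
+    Example (illustrative): ByteFSM(r"a+").accepts(b"aaa") is True, while
+    ByteFSM(r"a+").accepts(b"ab") is False.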
+ """
+
+ def __init__(self, pattern: str):
+ """
+ Initialize a ByteFSM from a regular expression pattern.
+
+ Args:
+ pattern: A regular expression pattern string
+ """
+ self.pattern = pattern
+
+ # Parse the regex pattern and create the character FSM
+ regex_fsm = interegular.parse_pattern(pattern).to_fsm()
+
+ # Store FSM properties
+ self.initial = regex_fsm.initial
+ self.finals = set(regex_fsm.finals)
+
+ # Get the special "anything_else" marker from interegular
+ self.anything_else = interegular.fsm.anything_else
+
+ # Create a byte-level alphabet for our FSM
+ self.alphabet = {}
+
+ # Create transitions dictionary for the byte FSM
+ self.transitions = {}
+ self.byte_to_category = {}
+
+ # Cache for live states to avoid recomputing
+ self._live_states_cache = set()
+
+ # Build the byte-level transitions from the character-level FSM
+ self._build_byte_fsm(regex_fsm)
+
+ def _build_byte_fsm(self, regex_fsm):
+ """
+ Build a byte-level FSM from the character-level FSM.
+
+ This method handles the conversion from character transitions to byte transitions,
+ properly handling the alphabet categories.
+ """
+ # Create a new transitions dictionary
+ self.transitions = {}
+
+ # Create a mapping from byte values to category numbers
+ self.byte_to_category = {}
+
+ # Extract the mapping from the regex FSM's alphabet and build our byte-level alphabet
+ for char, category in regex_fsm.alphabet.items():
+ if char == self.anything_else:
+ # Keep track of the anything_else category, but don't add to byte mappings
+ self.alphabet[self.anything_else] = category
+ continue
+
+ if isinstance(char, str):
+ # Handle string characters
+ char_bytes = char.encode('utf-8')
+ if len(char_bytes) == 1:
+ # Single byte character - add to our alphabet and mapping
+ byte_val = char_bytes[0]
+ self.alphabet[byte_val] = category
+ self.byte_to_category[byte_val] = category
+ else:
+ # Multi-byte characters are handled separately
+ # For now, add the full character to our alphabet
+ self.alphabet[char] = category
+ elif isinstance(char, int):
+ # Handle integer (byte) characters
+ self.alphabet[char] = category
+ self.byte_to_category[char] = category
+
+ # Make sure all regex FSM states are in our transitions
+ for state in regex_fsm.map:
+ if state not in self.transitions:
+ self.transitions[state] = {}
+
+ # Copy the transitions from the regex FSM to our byte FSM
+ for state, category_transitions in regex_fsm.map.items():
+ for category, target in category_transitions.items():
+ self.transitions[state][category] = target
+
+ # Handle multi-byte Unicode characters separately
+ # This is needed because a multi-byte character might need special handling
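+        # (e.g. '€' = b'\xe2\x82\xac' is expanded into a chain of three byte transitions
+        # through intermediate states that end at the original character's target state)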
+ for char, category in regex_fsm.alphabet.items():
+ if char == self.anything_else or not isinstance(char, str):
+ continue
+
+ char_bytes = char.encode('utf-8')
+ if len(char_bytes) <= 1:
+ continue
+
+ # For multi-byte characters, we need to add special transitions
+ # Make a copy of states to avoid modifying the dictionary during iteration
+ states_to_process = list(self.transitions.keys())
+ for state in states_to_process:
+ if category in self.transitions[state]:
+ target = self.transitions[state][category]
+
+ # Create intermediate states for the multi-byte character
+ current = state
+ for i, byte in enumerate(char_bytes):
+ if byte not in self.alphabet:
+ # Add the byte to the alphabet with a new category
+ byte_category = f"{byte}_{i}"
+ self.byte_to_category[byte] = byte_category
+
+ if i < len(char_bytes) - 1:
+ next_state = f"{current}_{byte}_{i}_{char}"
+ if next_state not in self.transitions:
+ self.transitions[next_state] = {}
+ self.transitions[current][byte_category] = next_state
+ current = next_state
+ else:
+ self.transitions[current][byte_category] = target
+
+ @lru_cache(maxsize=100000)
+ def _get_category(self, byte_val: int) -> Any:
+ """
+ Get the category for a byte value.
+
+ Args:
+ byte_val: The byte value
+
+ Returns:
+ The category for the byte value, or the 'anything_else' category if not found
+ """
+ # Check if we have a specific mapping for this byte
+ if byte_val in self.byte_to_category:
+ return self.byte_to_category[byte_val]
+
+ # If not, return the 'anything_else' category if it exists
+ if self.anything_else in self.alphabet:
+ return self.alphabet[self.anything_else]
+
+ # If there's no 'anything_else', return None (no transition)
+ return None
+
+ @property
+ def num_states(self) -> int:
+ """Returns the number of states in the FSM."""
+ return len(self.transitions)
+
+ @property
+ def alphabet_size(self) -> int:
+ """Returns the size of the alphabet (unique categories) used in transitions."""
+ categories = set()
+ for state_transitions in self.transitions.values():
+ for category in state_transitions:
+ if category is not None: # Skip epsilon transitions
+ categories.add(category)
+ return len(categories)
+
+ @property
+ def num_transitions(self) -> int:
+ """Returns the total number of transitions in the FSM."""
+ return sum(len(transitions) for transitions in self.transitions.values())
+
+ def get_next_state(self, current: Any, byte_val: int) -> Optional[Any]:
+ """
+ Get the next state based on the current state and the input byte value.
+
+ Args:
+ current: The current state
+ byte_val: The input byte value
+
+ Returns:
+ The next state or None if no transition is defined
+ """
+ if current is None or current not in self.transitions:
+ return None
+
+ # If not, get the category for this byte and check if there's a transition on that category
+ category = self._get_category(byte_val)
+ if category is not None and category in self.transitions[current]:
+ return self.transitions[current][category]
+
+ return None
+
+ def accepts(self, data: Union[str, bytes]) -> bool:
+ """
+ Check if the FSM accepts the given input data.
+
+ Args:
+ data: The input string or bytes
+
+ Returns:
+ True if the FSM accepts the input, False otherwise
+ """
+ # Convert string to bytes if needed
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+
+ # Start from the initial state
+ current = self.initial
+
+ # Process each byte
+ for byte in data:
+            current = self.get_next_state(current, byte)
+ if current is None:
+ return False
+
+ # Check if the final state is an accepting state
+ return current in self.finals
+
+ def try_consume_all(self, input_bytes: bytes) -> Optional[Any]:
+ """
+ Try to consume all input bytes and return the final state reached.
+
+ Args:
+ input_bytes: The input bytes to consume
+
+ Returns:
+ The final state reached after consuming all bytes if successful and in a final state,
+ otherwise None if any transition is invalid or if not in a final state.
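+
+        Example (illustrative): for ByteFSM("ab"), try_consume_all(b"ab") returns the accepting
+        state, while try_consume_all(b"a") and try_consume_all(b"abc") both return None.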
+ """
+ if not input_bytes:
+ # For empty input, check if the initial state is final
+ return self.initial if self.initial in self.finals else None
+
+ # Start from the initial state
+ current = self.initial
+
+ # Process each byte
+ for byte in input_bytes:
+ current = self.get_next_state(current, byte)
+ if current is None:
+ return None
+
+ # Return the final state only if it's an accepting state
+ return current if current in self.finals else None
+
+ def islive(self, state: Any) -> bool:
+ """
+ Check if a state is "live", meaning it can potentially reach a final state.
+
+ Args:
+ state: The state to check
+
+ Returns:
+ True if the state is live, False otherwise
+ """
+ # Check cache first
+ if state in self._live_states_cache:
+ return True
+
+ # Final states are always live
+ if state in self.finals:
+ self._live_states_cache.add(state)
+ return True
+
+ # Simple BFS to see if we can reach a final state from this state
+ visited = set()
+ queue = [state]
+
+ while queue:
+ current = queue.pop(0)
+
+ if current in self.finals or current in self._live_states_cache:
+ # Update cache for all states in the path
+ self._live_states_cache.add(state)
+ return True
+
+ if current in visited:
+ continue
+
+ visited.add(current)
+
+ # Add all reachable states to the queue
+ if current in self.transitions:
+ for symbol, next_state in self.transitions[current].items():
+ if next_state not in visited:
+ queue.append(next_state)
+
+ return False
+
+ def consume_prefix(self, data: Union[str, bytes], current_state: Optional[Any] = None) -> Tuple[bool, Optional[bytes]]:
+ """
+ Consume longest prefix of data starting from current_state that is accepted by the FSM and return the remainder.
+
+ Args:
+ data: The input string or bytes
+ current_state: The state to start from (defaults to initial state if None)
+
+ Returns:
+ A tuple (success, remainder) where:
+ - success is True if an accept state was reached or if we ended in a live state
+ - remainder is the remaining bytes after the consumed prefix, or None if no valid prefix was found
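+
+        Example (illustrative): for ByteFSM("ab"), consume_prefix(b"abc") returns (True, b"c"),
+        and consume_prefix(b"a") returns (True, b"") because 'a' leaves the FSM in a live state.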
+ """
+ # Convert to bytes if not already - only happens once per call
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+
+ # Use the provided state or the initial state
+ cur_state = self.initial if current_state is None else current_state
+
+ # Pre-check if we're already in a final state
+ longest_accept_index = 0 if cur_state in self.finals else -1
+
+ # Cache membership checking methods
+ is_final = self.finals.__contains__ # Direct method access is faster than 'in'
+
+ # Pre-compute length to avoid repeated len() calls
+ data_len = len(data)
+ if data_len == 0:
+            # Early return for empty data: succeed only if the current state is live
+ if cur_state is not None and self.islive(cur_state):
+ return True, b""
+ else:
+ return False, None
+
+ # Main byte processing loop
+ i = 0
+ while i < data_len:
+ byte = data[i]
+
+ # Get transitions for current state only once
+ state_transitions = self.transitions.get(cur_state, {})
+ if not state_transitions: # No transitions - dead state
+ break
+
+ # Direct byte transition - most common case first
+ if byte in state_transitions:
+ cur_state = state_transitions[byte]
+ else:
+ # Only get category if needed - reduces _get_category calls
+ category = self._get_category(byte)
+ if category is not None and category in state_transitions:
+ cur_state = state_transitions[category]
+ else:
+ # No valid transition - we've reached a "dead" state
+ cur_state = None
+ break
+
+ # Check if we're in a final state - using cached method
+ if is_final(cur_state):
+ longest_accept_index = i + 1
+
+ i += 1
+
+ if longest_accept_index != -1: # Reached accept state at some point
+ return True, data[longest_accept_index:]
+ elif cur_state is not None and self.islive(cur_state):
+ # Reached a live state but never an accept state
+ return True, b""
+ else:
+ # Never reached a final state or ended in a dead state
+ return False, None
diff --git a/syncode/mask_store/byte_tokenizer.py b/syncode/mask_store/byte_tokenizer.py
new file mode 100644
index 00000000..a443e917
--- /dev/null
+++ b/syncode/mask_store/byte_tokenizer.py
@@ -0,0 +1,361 @@
+"""byte_tokenizer.py: A flexible wrapper to handle different HuggingFace tokenizer types.
+
+This module provides a ByteTokenizer class that can adapt to different types of tokenizers:
+- RAW: Tokens in original form without processing (e.g., tiktoken)
+- BYTE_FALLBACK: Tokens encoded with byte-fallback conversion (e.g., Llama-2)
+- BYTE_LEVEL: Tokens encoded with byte-to-unicode conversion (e.g., GPT-2, Llama-3)
+
+The ByteTokenizer allows working with these tokenizers in a consistent byte-level manner.
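+
+Illustrative round-trip (a sketch, not part of the API surface; assumes the 'gpt2'
+tokenizer is available locally or can be downloaded):
+
+>>> tok = ByteTokenizer.from_pretrained('gpt2')
+>>> tok.decode(tok.encode(b'hello world'))
+b'hello world'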
+"""
+from transformers import AutoTokenizer
+from functools import reduce, cache
+from enum import Enum
+import re
+import dataclasses
+from typing import List, Optional, Dict, Tuple, Union, Any
+import time
+
+
+class VocabType(Enum):
+ """The type of vocabulary used by the tokenizer."""
+ RAW = 0
+ BYTE_FALLBACK = 1
+ BYTE_LEVEL = 2
+
+
+def bytes_to_unicode():
+ """
+ Returns a mapping between utf-8 bytes and unicode strings.
+ Used for byte-level BPE tokenization (GPT-2 style).
+
+ This makes the tokens representable in a standard text editor by mapping
+ bytes to printable unicode characters.
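+
+    Example (the values follow directly from the mapping built below; printable ASCII
+    maps to itself, while e.g. the space byte maps to 'Ġ'):
+
+    >>> bytes_to_unicode()[ord('a')]
+    'a'
+    >>> bytes_to_unicode()[ord(' ')]
+    'Ġ'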
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+@cache
+def enbyte_bytelevel(token: str) -> bytes:
+ """Turn a byte-level BPE token into the corresponding bytes.
+
+ Example:
+ --------
+ >>> enbyte_bytelevel('âĪ')
+ b'\xe2\x88'
+ """
+ dict_bytes = {v: k for k, v in bytes_to_unicode().items()}
+    # Replace non-ASCII special characters with ASCII equivalents for compatibility
+    token = token.replace('｜', '|').replace('▁', '_')
+ try:
+ return bytes([dict_bytes[c] for c in token])
+ except KeyError as e:
+ # For characters not in the mapping, attempt a direct encoding
+ # This handles special tokens not covered by the mapping
+ return token.encode('utf-8')
+
+
+@cache
+def enbyte_bytefallback(token: str) -> bytes:
+ """Turn a byte-fallback token into the corresponding bytes.
+
+ Example:
+ --------
+ >>> enbyte_bytefallback('<0x1B>')
+ b'\x1b'
+ >>> enbyte_bytefallback('▁apple')
+ b' apple'
+ """
+ # Handle byte fallback format like <0x1B>
+ if re.match(r'<0x[0-9A-F]{2}>', token):
+ byte_value = int(token[3:5], 16)
+ return bytes([byte_value])
+
+ # Handle space prefix format - Gemma/Llama style with '▁'
+ if token.startswith('▁'):
+ # If it's just a single or multiple '▁', it's just spaces (common in indentation)
+ if set(token) == {'▁'}:
+ return b' ' * len(token)
+ # Otherwise it's a space followed by content
+ return b' ' + token[1:].encode('utf-8')
+
+ return token.encode('utf-8')
+
+
+@cache
+def enbyte_raw(token: bytes) -> bytes:
+ """Turn a raw token directly into bytes.
+
+ Example:
+ --------
+ >>> enbyte_raw(b'hello')
+ b'hello'
+ """
+ return token
+
+
+def debyte_bytelevel(array: bytes) -> list[str]:
+ """Turn bytes into a list of corresponding code points for byte-level BPE.
+
+ Example:
+ --------
+ >>> debyte_bytelevel(b'\xe2\x88')
+ ['â', 'Ī']
+ """
+ byte_dict = bytes_to_unicode()
+ return [byte_dict[b] for b in array]
+
+
+def detect_vocab_type(tokenizer):
+ """
+ Detect the vocabulary type of a tokenizer.
+
+ Returns:
+ --------
+ VocabType: The detected vocabulary type
+ """
+ vocab = tokenizer.get_vocab()
+
+ # Check for byte fallback pattern (e.g., <0x0A> tokens)
+ if any(token.startswith('<0x') and token.endswith('>') for token in vocab):
+ return VocabType.BYTE_FALLBACK
+
+ # Check for tiktoken-based tokenizers
+ if hasattr(tokenizer, 'tokenizer') and 'tiktoken' in str(type(tokenizer.tokenizer)):
+ return VocabType.RAW
+
+ # Check filename pattern for tiktoken tokenizers
+ if (hasattr(tokenizer, 'vocab_files_names') and
+ 'vocab_file' in tokenizer.vocab_files_names and
+ 'tiktoken' in tokenizer.vocab_files_names['vocab_file']):
+ return VocabType.RAW
+
+    # Look for the "Ġ" or "▁" markers, which are common in byte-level BPE vocabularies
+ if any(('Ġ' in token or '▁' in token) for token in vocab):
+ return VocabType.BYTE_LEVEL
+
+ # Default to RAW type if no specific patterns are detected
+ return VocabType.RAW
+
+
+class ByteTokenizer:
+ """A class to convert tokenizers of different types to work at the byte level."""
+
+ def __init__(self, tokenizer, vocab_type=None):
+ self.tokenizer = tokenizer
+
+ # Detect vocab type if not provided
+ if vocab_type is None:
+ self.vocab_type = detect_vocab_type(tokenizer)
+ else:
+ self.vocab_type = vocab_type
+
+ # Select appropriate encoding function based on vocab type
+ if self.vocab_type == VocabType.BYTE_LEVEL:
+ self.enbyte_fn = enbyte_bytelevel
+ elif self.vocab_type == VocabType.BYTE_FALLBACK:
+ self.enbyte_fn = enbyte_bytefallback
+ else: # RAW
+ self.enbyte_fn = enbyte_raw
+
+ # Build vocabulary mappings
+ self.vocab = tokenizer.get_vocab()
+ self.byte_vocab = {}
+ self.vocab_byte = {}
+
+ # Create mappings for all vocabulary items
+ for token, token_id in self.vocab.items():
+ try:
+ byte_token = self.enbyte_fn(token)
+ self.byte_vocab[byte_token] = token_id
+ self.vocab_byte[token_id] = byte_token
+ except Exception as e:
+ # Skip problematic tokens but log them
+ print(f"Warning: Could not convert token '{token}' to bytes: {e}")
+
+ # Cache special token IDs as a set for faster lookups
+ self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))
+
+ @classmethod
+ def from_pretrained(cls, model_id, vocab_type=None):
+ """
+ Create a ByteTokenizer from a pre-trained model ID.
+
+ Parameters:
+ -----------
+ model_id: str
+ The HuggingFace model ID
+ vocab_type: VocabType, optional
+ The vocabulary type to use, if known
+
+ Returns:
+ --------
+ ByteTokenizer: A new ByteTokenizer instance
+ """
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ return cls(tokenizer, vocab_type)
+
+ def decode(self, token_ids: list[int], skip_special_tokens: bool = False) -> bytes:
+ """
+ Decode token_ids to bytes.
+
+ Parameters:
+ -----------
+ token_ids: list[int]
+ List of token IDs to decode
+ skip_special_tokens: bool, default=False
+ Whether to skip special tokens in the decoded output
+
+ Returns:
+ --------
+ bytes: The decoded bytes
+
+ Examples:
+ ---------
+ >>> tok = ByteTokenizer.from_pretrained('gpt2')
+ >>> tok.decode([19526, 254, 25001, 121, 28938, 245, 171, 120, 253]).decode('utf-8')
+ '你好吗?'
+ """
+ if not token_ids:
+ return b''
+
+ # Use a mutable bytearray for faster concatenation
+ result = bytearray()
+ vocab_byte = self.vocab_byte
+
+ # Fast path: no special token handling needed
+ if not skip_special_tokens or not self.special_token_ids:
+ for token_id in token_ids:
+ try:
+ # Try/except is faster than 'in' check for dictionary access
+ result.extend(vocab_byte[token_id])
+ except KeyError:
+ # Fall back to tokenizer for unknown tokens
+ text = self.tokenizer.decode([token_id])
+ result.extend(text.encode('utf-8'))
+ return bytes(result)
+
+ # Path with special token handling
+ special_token_ids = self.special_token_ids
+ for token_id in token_ids:
+ if token_id in special_token_ids:
+ continue
+
+ try:
+ result.extend(vocab_byte[token_id])
+ except KeyError:
+ text = self.tokenizer.decode([token_id])
+ result.extend(text.encode('utf-8'))
+
+ return bytes(result)
+
+ def encode(self, byte_text: bytes) -> list[int]:
+ """
+ Encode bytes to token_ids.
+
+ Parameters:
+ -----------
+ byte_text: bytes
+ Bytes to encode
+
+ Returns:
+ --------
+ list[int]: List of token IDs
+
+ Examples:
+ ---------
+ >>> tok = ByteTokenizer.from_pretrained('gpt2')
+ >>> tok.encode('你好吗?'.encode())
+ [19526, 254, 25001, 121, 28938, 245, 171, 120, 253]
+ """
+ # For byte-level tokenizers, we use a greedy algorithm
+ input_ids = []
+ original_byte_text = byte_text
+
+ # Handle RAW tokenizers differently
+ if self.vocab_type == VocabType.RAW:
+ # For RAW tokenizers, we can decode the bytes and use the original tokenizer
+ try:
+ text = byte_text.decode('utf-8')
+ return self.tokenizer.encode(text, add_special_tokens=False)
+ except UnicodeDecodeError:
+ # If we can't decode as UTF-8, fall back to our byte-level logic
+ pass
+
+ # Greedy tokenization for byte-level and byte-fallback tokenizers
+ while byte_text:
+ matched = False
+ # Try largest prefix first
+ for i in range(len(byte_text), 0, -1):
+ prefix = byte_text[:i]
+ if prefix in self.byte_vocab:
+ input_ids.append(self.byte_vocab[prefix])
+ byte_text = byte_text[i:]
+ matched = True
+ break
+
+ # If no match found for any prefix, add the first byte as an unknown token
+ # and continue with the rest
+ if not matched:
+ # Try to get the unknown token ID from the tokenizer
+ unk_token_id = self.tokenizer.unk_token_id
+ if unk_token_id is None:
+ # If no explicit unknown token, use a default
+ unk_token_id = 0
+
+ input_ids.append(unk_token_id)
+ byte_text = byte_text[1:]
+
+ return input_ids
+
+ def encode_batch(self, batch_byte_text: list[bytes]) -> list[list[int]]:
+ """
+ Encode a batch of bytes to token_ids.
+
+ Parameters:
+ -----------
+ batch_byte_text: list[bytes]
+ List of bytes sequences to encode
+
+ Returns:
+ --------
+ list[list[int]]: List of lists of token IDs
+ """
+ return [self.encode(text) for text in batch_byte_text]
+
+ def batched_decode(self, token_id_batches: List[List[int]], skip_special_tokens: bool = False) -> List[bytes]:
+ """
+ Decode multiple batches of token IDs.
+
+ Parameters:
+ -----------
+ token_id_batches: List[List[int]]
+ A list of batches of token IDs to decode
+
+ skip_special_tokens: bool, default=False
+ Whether to skip special tokens in the decoded output
+
+ Returns:
+ --------
+ List[bytes]: List of decoded byte sequences
+ """
+ # Pre-allocate result list for better performance
+ results = [None] * len(token_id_batches)
+
+ # Process all batches
+ for i, token_ids in enumerate(token_id_batches):
+ results[i] = self.decode(token_ids, skip_special_tokens)
+
+ return results
diff --git a/syncode/mask_store/fsm_set.py b/syncode/mask_store/fsm_set.py
new file mode 100644
index 00000000..ce19c123
--- /dev/null
+++ b/syncode/mask_store/fsm_set.py
@@ -0,0 +1,136 @@
+import time
+import interegular
+from typing import Any, Optional, Tuple, Iterable, Dict
+from syncode.mask_store.byte_fsm import ByteFSM
+import logging
+logger = logging.getLogger(__name__)
+
+class JointFSMState:
+ """
+ Represents the state of the FSM. It consists of the current terminal and the FSM state for the current terminal.
+ """
+ def __init__(self, terminal: str, state_id: int):
+ self.terminal = terminal
+ self.state_id = state_id
+ self._hash = hash((self.terminal, self.state_id)) # Pre-compute hash on creation
+
+ def __eq__(self, other: 'JointFSMState'):
+ return self.terminal == other.terminal and self.state_id == other.state_id
+
+ def __hash__(self):
+ return self._hash
+
+ def __repr__(self):
+ return f"({self.terminal}, {self.state_id})"
+
+
+class FSMSet:
+ """
+ Stores the FSM for each terminal and provides the method to consume the input string and get the FSM state.
+ Uses external ByteFSM for regex matching.
+ """
+ def __init__(self, terminals: Iterable['MockTerminalDef'], simplifications: Dict[str, str] = {}):
+ start_time = time.time()
+ self._terminals_to_byte_fsm: Dict[str, ByteFSM] = {} # Store ByteFSM instances
+ self.anything_else = interegular.fsm.anything_else
+ self._simplifications: Dict[str, str] = simplifications
+
+ # Initialize cache for initial states
+ self._initial_state_cache = {}
+ cnt_states = 0
+
+ for terminal in terminals:
+ if terminal.name in simplifications:
+ terminal_regex = simplifications[terminal.name]
+ else:
+ terminal_regex = terminal.pattern.to_regexp()
+
+ # Create a ByteFSM for each terminal pattern
+ # This handles the regex pattern matching
+ byte_fsm = ByteFSM(terminal_regex)
+ self._terminals_to_byte_fsm[terminal.name] = byte_fsm
+ cnt_states += len(byte_fsm.transitions)
+ logger.info(f"{len(terminals)} FSMs with {cnt_states} states initialized in {time.time() - start_time:.2f} seconds")
+
+ def states(self):
+ """Returns all possible DFA states for all terminals."""
+ states = []
+ for terminal_name, byte_fsm in self._terminals_to_byte_fsm.items():
+ # We need to get states from the ByteFSM's transitions dictionary
+ for state_id in byte_fsm.transitions:
+ states.append(JointFSMState(terminal_name, state_id))
+ return states
+
+ def initial(self, terminal: str):
+ """Get the initial state for a specific terminal (optimized with caching)."""
+ # Check if we've already computed this initial state
+ if terminal not in self._initial_state_cache:
+ # Only create the JointFSMState object once per terminal
+ self._initial_state_cache[terminal] = JointFSMState(
+ terminal,
+ self._terminals_to_byte_fsm[terminal].initial
+ )
+
+ # Return the cached version
+ return self._initial_state_cache[terminal]
+
+ def compute_fsm_states(self, input_bytes: bytes) -> Iterable[JointFSMState]:
+ """
+ Consume input_bytes and get the list of pairs of (terminal, state_id).
+ This denotes our current DFA state.
+
+ For each terminal's ByteFSM, attempts to consume all input bytes
+ and returns state after consumption. A terminal is included only if
+ its FSM can successfully process the entire input.
+
+ There is no requirement for the final state to be an accepting state.
+
+ Args:
+ input_bytes: The input bytes to consume
+
+ Returns:
+ A list of JointFSMState objects, each containing a terminal and its
+ corresponding state_id after consuming all input bytes.
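+
+        Illustrative example (with hypothetical terminals NAME=[a-z]+ and NUMBER=[0-9]+):
+        compute_fsm_states(b"12") would contain only the NUMBER state reached after
+        consuming b"12", since the NAME FSM has no transition on a digit.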
+ """
+ dfa_states = []
+
+ for terminal, byte_fsm in self._terminals_to_byte_fsm.items():
+ # Start from the initial state
+ current_state = byte_fsm.initial
+
+ # Process input byte by byte
+ valid_transition = True
+ for byte_val in input_bytes:
+ next_state = byte_fsm.get_next_state(current_state, byte_val)
+ if next_state is None:
+ valid_transition = False
+ break
+ current_state = next_state
+
+ # If we were able to process all bytes, add the terminal and final state
+ if valid_transition:
+ dfa_states.append(JointFSMState(terminal, current_state))
+ # Special case for empty input
+ elif not input_bytes:
+ dfa_states.append(JointFSMState(terminal, byte_fsm.initial))
+
+ return dfa_states
+
+ def is_final(self, dfa_state: JointFSMState) -> bool:
+ """
+ Returns True if the dfa state is a final state
+ """
+ byte_fsm = self._terminals_to_byte_fsm[dfa_state.terminal]
+ return dfa_state.state_id in byte_fsm.finals
+
+ def consume_prefix(self, fsm_state: JointFSMState, input_bytes: bytes) -> Tuple[bool, Optional[bytes]]:
+ """
+ Consume longest prefix of input_bytes that is accepted by dfa and return the remainder.
+ """
+ terminal = fsm_state.terminal
+ current_state = fsm_state.state_id
+ byte_fsm = self._terminals_to_byte_fsm[terminal]
+
+ success, remainder = byte_fsm.consume_prefix(input_bytes, current_state)
+ return success, remainder
+
\ No newline at end of file
diff --git a/syncode/mask_store/lookup_table.py b/syncode/mask_store/lookup_table.py
new file mode 100644
index 00000000..dc30b394
--- /dev/null
+++ b/syncode/mask_store/lookup_table.py
@@ -0,0 +1,182 @@
+from collections import defaultdict
+import copy
+import torch
+import regex
+from syncode.mask_store.mask_store import JointFSMState
+from syncode.parse_result import IndentationConstraint
+from typing import Any, Tuple, Iterable, Dict, Union
+import logging
+logger = logging.getLogger(__name__)
+
+
+class LookupTable:
+ """
+ Stores the overapproximate tokens
+ """
+ def __init__(
+ self,
+ vocab: Iterable[str],
+ eos_token_id: int,
+ special_token_ids: Iterable[int],
+ indent=False,
+ mode='grammar_mask'
+ ):
+ self._fsm_state_and_next_terminal_to_tokens: defaultdict = defaultdict(list)
+ self._overapprox_lookup: Dict[JointFSMState, Any] = {}
+ self._exact_lookup: dict = {}
+ self._mode = mode
+ self._vocab: Iterable[str] = vocab
+ self.indent = indent
+
+ # In the default mask, add all tokens that are special tokens except the EOS token
+ self._default_mask: torch.IntTensor = torch.zeros(len(vocab), dtype=torch.bool)
+ for token_id in special_token_ids:
+ if token_id != eos_token_id:
+ self._default_mask[token_id] = 1
+
+ if indent:
+ logger.info("Indentation mode enabled")
+ self._whitespace_tokens_map: defaultdict = defaultdict(list)
+ self._indentation_to_tokens_map: defaultdict = defaultdict(list)
+ self._create_indentation_to_tokens_map()
+
+ def incomplete_case_lookup(self, fsm_state: JointFSMState) -> Any:
+ assert isinstance(fsm_state, JointFSMState)
+ if self._mode == 'grammar_mask':
+ return self._overapprox_lookup[fsm_state]
+ elif self._mode == 'grammar_strict':
+ if fsm_state in self._exact_lookup:
+ return self._exact_lookup[fsm_state]
+ else:
+ return self._overapprox_lookup[fsm_state]
+ raise ValueError(f"Invalid mode: {self._mode}")
+
+ def store_overapprox_lookup(self, fsm_state: JointFSMState, mask: torch.Tensor):
+ assert isinstance(fsm_state, JointFSMState)
+ if fsm_state not in self._overapprox_lookup:
+ self._overapprox_lookup[fsm_state] = self._get_default_mask()
+ self._overapprox_lookup[fsm_state] |= mask
+
+ def complete_case_lookup(self, fsm_state: JointFSMState) -> Any:
+ assert isinstance(fsm_state, JointFSMState)
+ return self._exact_lookup[fsm_state]
+
+ def add_exact_lookup(self, fsm_state: JointFSMState, token):
+ assert isinstance(fsm_state, JointFSMState)
+ if fsm_state not in self._exact_lookup:
+ self._exact_lookup[fsm_state] = []
+ self._exact_lookup[fsm_state].append(token)
+
+ def fsm_state_and_next_terminal_to_tokens(self, fsm_state: JointFSMState, next_terminal) -> torch.Tensor:
+ assert isinstance(fsm_state, JointFSMState)
+ return self._fsm_state_and_next_terminal_to_tokens[(fsm_state, next_terminal)]
+
+ def fsm_state_and_next_terminal_to_tokens_store(self, fsm_state: JointFSMState, next_terminal, mask: torch.Tensor):
+ assert isinstance(fsm_state, JointFSMState)
+ self._fsm_state_and_next_terminal_to_tokens[(fsm_state, next_terminal)] = mask
+
+ def fsm_state_and_next_terminal_to_tokens_add(self, fsm_state: JointFSMState, next_terminal, token):
+ assert isinstance(fsm_state, JointFSMState)
+ self._fsm_state_and_next_terminal_to_tokens[(fsm_state, next_terminal)].append(token)
+
+ def _list_to_mask(self, tokens_idx_list) -> torch.Tensor:
+ indices = torch.tensor(tokens_idx_list)
+ tokens_mask = self._get_default_mask()
+ tokens_mask[indices] = 1
+ return tokens_mask
+
+ def convert_lookups_from_list_to_mask(self):
+ """
+ Converts the lookups from list of tokens to boolean tensor mask
+ """
+ for key, val in self._fsm_state_and_next_terminal_to_tokens.items():
+ m = self._list_to_mask(val)
+ self._fsm_state_and_next_terminal_to_tokens[key] = m
+ (fsm_state, _) = key
+ self.store_overapprox_lookup(fsm_state, m)
+
+ for key, val in self._exact_lookup.items():
+ self._exact_lookup[key] = self._list_to_mask(val)
+
+ # TODO: move this logic to the lookup table
+ if self.indent:
+ for key, val in self._whitespace_tokens_map.items():
+ self._whitespace_tokens_map[key] = self._list_to_mask(val)
+ for key, val in self._indentation_to_tokens_map.items():
+ self._indentation_to_tokens_map[key] = self._list_to_mask(val)
+
+ def _get_default_mask(self) -> torch.Tensor:
+ return self._default_mask.clone()
+
+ def _create_indentation_to_tokens_map(self):
+ """
+ We create a mapping from indentation to overapproximate tokens. This is useful for the indentation rule.
+ """
+ for token_idx, token in enumerate(self._vocab):
+ full_match, indent = self._get_indent_type(token)
+ if full_match:
+ self._whitespace_tokens_map[indent].append(token_idx)
+ else:
+ self._indentation_to_tokens_map[indent].append(token_idx)
+
+ def _get_indent_type(self, s: Union[str, bytes]) -> Tuple[bool, int]:
+ """
+ Determine the indentation type and level from a string or bytes input.
+
+ Args:
+ s (Union[str, bytes]): The input string or bytes to analyze
+
+ Returns:
+ Tuple[bool, int]: A tuple containing:
+ - bool: Whether the input is entirely whitespace
+ - int: The indentation level (spaces + 4*tabs)
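+
+        Illustrative examples (values follow from the rule above; one tab counts as four spaces):
+            b'    '  -> (True, 4)    # four spaces, entirely whitespace
+            b'\t'    -> (True, 4)    # a single tab
+            b'  x'   -> (False, 2)   # two leading spaces followed by content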
+ """
+ # Convert bytes to string if needed
+ if isinstance(s, bytes):
+ try:
+ s_str = s.decode('utf-8')
+ except UnicodeDecodeError:
+ # Handle decode errors by returning default values
+ return False, 0
+ else:
+ s_str = s
+
+ m = regex.match(r'[\t ]+', s_str, partial=True)
+ full_match = False
+        if m is not None:
+ start, end = m.start(), m.end()
+ if end == len(s_str):
+ full_match = True
+ return full_match, s_str[start: end].count(' ') + 4*s_str[start: end].count('\t')
+ return False, 0
+
+ def get_indentation_tokens(self, indent_constraint: IndentationConstraint, get_list=False):
+ """
+ Returns the tokens mask for the indentation constraint
+ """
+ out_mask = self._get_default_mask()
+ if indent_constraint.greater_than_indent_val is not None:
+ for indent in self._indentation_to_tokens_map.keys():
+ if indent > indent_constraint.greater_than_indent_val:
+ out_mask |= self._indentation_to_tokens_map[indent]
+
+ for indent in self._whitespace_tokens_map.keys(): # We are ok with any num of whitespace
+ out_mask |= self._whitespace_tokens_map[indent]
+
+ elif indent_constraint.accept_indents is not None:
+ for indent in indent_constraint.accept_indents:
+ if indent in self._indentation_to_tokens_map:
+ out_mask |= self._indentation_to_tokens_map[indent]
+
+ max_acceptable_indent = max(indent_constraint.accept_indents)
+ for indent in self._whitespace_tokens_map.keys(): # We are ok with num whitespace <= largest accepted indent
+ if indent <= max_acceptable_indent:
+ out_mask |= self._whitespace_tokens_map[indent]
+
+ if get_list: # This is useful for testing
+ return self._get_tokens_list(out_mask)
+ return out_mask
+
+ def _get_tokens_list(self, token_mask) -> Iterable[str]:
+ return [self._vocab[idx.item()] for idx in torch.where(token_mask == True)[0]]
+
\ No newline at end of file
diff --git a/syncode/mask_store/mask_store.py b/syncode/mask_store/mask_store.py
new file mode 100644
index 00000000..a5a76c81
--- /dev/null
+++ b/syncode/mask_store/mask_store.py
@@ -0,0 +1,418 @@
+from collections import defaultdict
+import os, pickle
+import time
+import torch
+import regex
+import syncode.common as common
+from tqdm import tqdm
+from syncode.mask_store.byte_tokenizer import ByteTokenizer
+from syncode.mask_store.fsm_set import JointFSMState, FSMSet
+from syncode.mask_store.lookup_table import LookupTable
+from syncode.parsers import create_base_parser
+from syncode.larkm.lexer import TerminalDef
+from syncode.parse_result import RemainderState, ParseResult
+from syncode.parsers.grammars.grammar import Grammar
+from typing import Iterable, Union
+from transformers import PreTrainedTokenizer
+import logging
+logger = logging.getLogger(__name__)
+
+
+class MaskStore:
+ """
+    We build a DFA that combines the DFAs of all terminals, and simulate it by consuming the input string with each terminal DFA.
+
+    There are 3 possible cases for the remainder string:
+
+    1. COMPLETE: The last token is complete (it cannot be extended further) and we know the type of the next terminal. We therefore compute all tokens such that consuming the token keeps that terminal's DFA in a live state or takes it to a final state.
+
+    2. INCOMPLETE: The remainder is incomplete and does not match any terminal regex. We therefore compute all tokens such that consuming the token keeps the current terminal DFA in a live state or takes it to a final state at some point.
+
+    3. MAYBE_COMPLETE: The remainder matches some terminal type, but the next token may extend the matched part. There are then two possibilities: i) the matched terminal type does not change, so we can use the next-terminal set computed under that assumption; ii) the matched terminal type changes, in which case we do not know the next-terminal set. We therefore compute all tokens such that consuming the token keeps the current terminal DFA in a live state or takes it to a final state at some point.
+ """
+ def __init__(self,
+ terminals: Iterable[TerminalDef],
+ tokenizer: PreTrainedTokenizer,
+ simplifications: dict={},
+ indent: bool=False,
+ mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
+ ignore_terminals: Iterable[str]=[],
+ parse_table=None
+ ):
+ self._vocab = MaskStore._get_vocab_from_tokenizer(tokenizer)
+ self._mode = mode
+
+ # Tokenizer for byte tokens
+ self.byte_tokenizer = ByteTokenizer(tokenizer)
+ self.special_token_ids = tokenizer.all_special_ids
+ self.eos_token_id = tokenizer.eos_token_id
+
+ # Create the FSMs for each terminal
+ self._fsms = FSMSet(terminals, simplifications)
+
+ # Check if whitespace is in ignore terminals
+ self._ignore_whitespace = self.set_ignore_whitespace(terminals, ignore_terminals)
+ logger.info(f"Ignore whitespace is {self._ignore_whitespace}")
+
+ # Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
+ self._lookup_table = LookupTable(
+ self._vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=indent, mode=mode
+ )
+ terminal_names = [terminal.name for terminal in terminals]
+
+        following_terminals_map = None
+        if parse_table is not None:
+            following_terminals_map = self._compute_following_terminals_map(terminal_names, parse_table)
+
+        # Create consume prefix cache
+        self._consume_prefix_cache = {}
+        self._store_token_masks(terminal_names, len(self._vocab), following_terminals_map)
+
+ self.indentation = indent
+
+ # NOTE: This should be called at the end of the constructor
+ self._lookup_table.convert_lookups_from_list_to_mask()
+
+ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_terminals: Iterable[str]) -> bool:
+ ignore_whitespace = False
+ for ig_name in ignore_terminals:
+ for terminal in terminals:
+ if terminal.name == ig_name:
+ if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
+                        ignore_whitespace = True  # at least one ignored terminal matches a whitespace character
+ return ignore_whitespace
+
+ @staticmethod
+ def init_mask_store(
+ grammar: Grammar,
+ tokenizer,
+ use_cache=True,
+ mode='grammar_strict',
+ indent=False
+ ):
+ '''
+ Loads the fsm for the given language and tokenizer. If the fsm is not cached, it is created and cached.
+ '''
+ tokenizer_name = type(tokenizer).__name__
+ fsm_dir = common.SYNCODE_CACHE + 'mask_stores/' + tokenizer_name + '/'
+ grammar_hash = grammar.hash()
+
+        # TODO: Hashing using the tokenizer vocab size, this may be problematic if we have two fine-tuned models with same tokenizer, same vocab size but different vocab
+ fsm_path = f'{fsm_dir}{mode}_{grammar_hash}_{tokenizer.vocab_size}.pkl'
+
+ if use_cache and os.path.exists(fsm_path):
+ try:
+ mask_store = pickle.load(open(fsm_path, 'rb'))
+ return mask_store
+ except Exception as e:
+ logger.warning(f"Error loading mask store: {e}")
+
+        logger.info(f"Using cache: {use_cache}; fsm path {fsm_path} exists: {os.path.exists(fsm_path)}")
+ logger.info(f"Creating mask store for {tokenizer_name} and {grammar}, may take more than 10 minutes. Caching at {os.path.abspath(fsm_path)}.")
+ base_parser = create_base_parser(grammar)
+ simplifications = grammar.simplifications()
+ os.makedirs(fsm_dir, exist_ok=True)
+
+ start_time = time.time()
+ mask_store = MaskStore(
+ base_parser.terminals,
+ tokenizer,
+ simplifications=simplifications,
+ mode=mode,
+ ignore_terminals=base_parser.ignore_tokens,
+ parse_table=base_parser.parser.parser._parse_table,
+ indent=indent
+ )
+ logger.info(f"Time taken to create mask store: {time.time() - start_time:.2f} seconds")
+
+ pickle.dump(mask_store, open(fsm_path, 'wb'))
+ return mask_store
+
+ def _compute_following_terminals_map(
+ self,
+ terminals: Iterable[str],
+ parse_table
+ ) -> defaultdict:
+ """
+ From terminals, filter out terminals that cannot follow the current terminal
+ according to the grammar.
+
+        If the parse table has Action[parser_state, cur_terminal] = 'Shift new_parser_state', then the next terminals
+ are the terminals that are legal in new_parser_state.
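+
+        Illustrative example (hypothetical parse table): if Action[s3, LPAR] is 'Shift s7'
+        and state s7 has actions for NAME and RPAR, then NAME and RPAR (together with all
+        ignore terminals) are recorded as possible followers of LPAR.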
+ """
+ following_terminals_map = defaultdict(set)
+ terminals_set = set(terminals)
+ cnt_seq_two_terms = 0
+
+ # We iterate through each cur_terminal:
+ for cur_terminal in terminals:
+ # Add all ignore terminals to the following terminals
+ for next_terminal in terminals:
+ if 'IGNORE' in next_terminal:
+ following_terminals_map[cur_terminal].add(next_terminal)
+
+ # We iterate through each parser_state:
+ for _, row in parse_table.states.items():
+ if cur_terminal in row:
+ action = row[cur_terminal]
+ # -> If we see a shift action to new_parser_state
+ if str(action[0]) == 'Shift':
+ new_parser_state = action[1]
+ for next_terminal in parse_table.states[new_parser_state]:
+ # Lark parse_table stores non-terminals and terminals together
+ if next_terminal in terminals_set:
+ # -> -> we add the terminals that are legal in new_parser_state
+ following_terminals_map[cur_terminal].add(next_terminal)
+ cnt_seq_two_terms += len(following_terminals_map[cur_terminal])
+
+        logger.info(f"Number of length-2 terminal sequences reduced from {len(terminals)*len(terminals)} to {cnt_seq_two_terms}")
+ return following_terminals_map
+
+
+    def _store_token_masks(self, terminals: Iterable[str], vocab_size: int, following_terminals_map: dict=None):
+ """
+ Stores the token masks for each fsm state and next terminals
+ """
+ all_fsm_states = self._fsms.states()
+ pbar = tqdm(total=len(all_fsm_states))
+
+ for fsm_state in all_fsm_states:
+ # Get the next terminals for the current fsm state
+            if following_terminals_map is not None and fsm_state.terminal in following_terminals_map:
+                following_terminals = following_terminals_map[fsm_state.terminal]
+ else:
+ following_terminals = terminals
+
+ # For each token, we check if it is a valid token for the current fsm state
+ for token_idx in range(vocab_size):
+ # If the token is EOS token, we add it to the final state with the terminal '$END'
+ if token_idx == self.eos_token_id:
+ # Add 2 length terminal sequences where the second terminal is '$END'
+ if self._fsms.is_final(fsm_state):
+ self._lookup_table.fsm_state_and_next_terminal_to_tokens_add(
+ fsm_state, '$END', token_idx)
+ else:
+ self._process_regular_tokens(
+ following_terminals, fsm_state, token_idx
+ )
+ pbar.update(1)
+ pbar.close()
+
+
+ def _process_regular_tokens(self, terminals, fsm_state: JointFSMState, token_idx: int):
+ token_bytes = self.byte_tokenizer.decode([token_idx])
+
+ # For COMPLETE case:
+ self._process_complete_case(fsm_state, token_idx, token_bytes)
+
+ # For INCOMPLETE case:
+ # Replace \t with 4 spaces
+ remainder = token_bytes.replace(b'\t', b' ')
+
+ is_valid, remainder = self._fsms.consume_prefix(fsm_state, remainder)
+ if is_valid:
+ if len(remainder) == 0:
+ # We reached a live state for the current terminal, thus we add the token in all sets of next terminals
+ for next_terminal in terminals:
+ self._lookup_table.fsm_state_and_next_terminal_to_tokens_add(
+ fsm_state, next_terminal, token_idx)
+ else:
+ remainder = self._remove_left_whitespace(fsm_state, remainder)
+
+                # We reached the final state while consuming the token, thus we consume the remainder with all next terminals
+ for next_terminal in terminals:
+ initial_state = self._fsms.initial(next_terminal)
+
+ if (initial_state, remainder) not in self._consume_prefix_cache:
+ # We use the cache to speed up the process only for the initial state
+ is_valid, remainder_new = self._fsms.consume_prefix(initial_state, remainder)
+ self._consume_prefix_cache[(initial_state, remainder)] = (is_valid, remainder_new)
+ else:
+ is_valid, remainder_new = self._consume_prefix_cache[(initial_state, remainder)]
+
+ if self._mode == 'grammar_mask':
+ if is_valid: # In the non-strict mode we overapproximate
+ # We reached a live state for the next terminal, thus we add the
+ # token in the overapproximate sets of next terminals
+ self._lookup_table.fsm_state_and_next_terminal_to_tokens_add(
+ fsm_state, next_terminal, token_idx)
+ elif self._mode == 'grammar_strict':
+ if is_valid and len(remainder_new) == 0:
+ # We reached a live state for the next terminal and the remainder
+ # is empty, thus we add the token in the exact set of next terminals
+ self._lookup_table.fsm_state_and_next_terminal_to_tokens_add(fsm_state, next_terminal, token_idx)
+ else:
+ raise ValueError(f"Invalid mode: {self._mode}")
+
+
+ def _process_complete_case(self, fsm_state, token_idx, token_bytes):
+ remainder = token_bytes.replace(b'\t', b' ')
+ remainder = self._remove_left_whitespace(fsm_state, remainder)
+
+ is_valid, remainder = self._fsms.consume_prefix(fsm_state, remainder)
+ if is_valid and len(remainder) == 0:
+ self._lookup_table.add_exact_lookup(fsm_state, token_idx)
+
+ def _remove_left_whitespace(
+ self,
+ fsm_state,
+ remainder: Union[str, bytes]
+ ) -> Union[str, bytes]:
+ """
+        Ignore all leading whitespace at the start of the terminal. This improves efficiency:
+        without it, if the model wants to generate ' def', SynCode would force it to generate ' ' and 'def' as separate steps.
+ """
+ if (self._fsms.initial(fsm_state.terminal) == fsm_state or self._fsms.is_final(fsm_state)) and self._ignore_whitespace:
+            # Both bytes and str support lstrip(), which removes all leading whitespace
+            remainder = remainder.lstrip()
+ return remainder
+
+ def _lookup_next_tokens_for_fsm_state(self, fsm_state: JointFSMState, next_terminal) -> torch.Tensor:
+ tokens = self._lookup_table.fsm_state_and_next_terminal_to_tokens(fsm_state, next_terminal)
+ if tokens == []:
+ # TODO: This is a hack. Do something better
+ return self._lookup_table._get_default_mask()
+ return tokens
+
+ @staticmethod
+ def _get_vocab_from_tokenizer(tokenizer, byte_string=False) -> Iterable[str]:
+ """
+ self.vocab is a list of readable token strings (e.g., ' hello' and '\n')
+ sorted by their token IDs (so self.vocab[0] is the first token, etc).
+
+ if `byte_string` is True, then the vocab is returned as byte strings.
+ """
+ vocab = [v for k, v in
+ sorted([(t_id, tokenizer.decode([t_id]))
+ for _, t_id in tokenizer.get_vocab().items()])]
+
+ # HACK: Is there a better way to know if a token has a prefix space?
+ if 'Llama' in tokenizer.__class__.__name__:
+ for i in range(len(vocab)):
+ t = vocab[i]
+ if 2*len(t) != len(tokenizer.decode([i, i], add_special_tokens=False)):
+ vocab[i] = ' ' + t
+ if t == '':
+ vocab[i] = ' '
+
+ if byte_string:
+ vocab = [t.encode('utf-8') for t in vocab]
+
+ return vocab
+
+
+ def _lookup_next_tokens(
+ self,
+ fsm_states: Iterable[JointFSMState],
+ remainder_state: RemainderState,
+ accept_sequences: Iterable
+ ) -> torch.Tensor:
+ """
+ Lookup the next tokens for the current fsm states and remainder state and accept sequences.
+ """
+ accept_token_mask = self._lookup_table._get_default_mask()
+
+ # Case when the final string may be incomplete
+ for fsm_state in fsm_states:
+ for accept_sequence in accept_sequences:
+ if accept_sequence[0] == '$END':
+ accept_token_mask[self.eos_token_id] = 1
+
+ if fsm_state.terminal != accept_sequence[0]:
+ continue
+
+ if remainder_state == RemainderState.COMPLETE:
+ assert len(accept_sequence) == 1 # Since we only store length 1 accept sequences in this case
+ accept_token_mask |= self._lookup_table.complete_case_lookup(fsm_state)
+
+ if remainder_state == RemainderState.INCOMPLETE:
+ accept_token_mask |= self._lookup_table.incomplete_case_lookup(fsm_state)
+
+ if remainder_state == RemainderState.MAYBE_COMPLETE:
+ if len(accept_sequence) == 1:
+ # mode='grammar_strict': incomplete_case_lookup is the same as complete_case_lookup
+ # mode='grammar_mask': incomplete_case_lookup is the overapproximate lookup
+ accept_token_mask |= self._lookup_table.incomplete_case_lookup(fsm_state)
+ elif len(accept_sequence) == 2:
+ accept_token_mask |= self._lookup_next_tokens_for_fsm_state(fsm_state, accept_sequence[1])
+ elif len(accept_sequence) == 3:
+ # If the DFA state is a final state we can jump to the start of next terminal
+ if self._fsms.is_final(fsm_state):
+ ignore_init_state = self._fsms.initial(accept_sequence[1])
+ accept_token_mask |= self._lookup_next_tokens_for_fsm_state(ignore_init_state, accept_sequence[2])
+ else:
+ raise ValueError(f"Invalid accept sequence: {accept_sequence}")
+ return accept_token_mask
+
+ def get_fsm_states(self, r: ParseResult) -> Iterable[JointFSMState]:
+ """
+ Returns the DFA state for the current partial code
+ """
+ cur_incomplete_string = r.remainder
+ if cur_incomplete_string is None:
+ return []
+
+ cur_fsm_states = self._fsms.compute_fsm_states(cur_incomplete_string)
+ return cur_fsm_states
+
+ def get_accept_mask(
+ self,
+ r: ParseResult,
+ get_list=False
+ ) -> torch.Tensor:
+ """
+ Returns the mask for the acceptable tokens for the current partial code
+
+ Args:
+ r (ParseResult): The parse result
+ get_list (bool, optional): If True, returns the list of tokens instead of the mask. Defaults to False.
+ """
+ cur_incomplete_string = r.remainder
+
+        if cur_incomplete_string is None:
+            return torch.ones(len(self._vocab), dtype=torch.bool)
+        assert type(cur_incomplete_string) == bytes
+
+ cur_fsm_states = self._fsms.compute_fsm_states(cur_incomplete_string)
+ accept_token_mask = self._lookup_next_tokens(
+ cur_fsm_states,
+ r.remainder_state,
+ r.accept_sequences
+ )
+
+ if self.indentation and r.next_ac_indents is not None:
+ indent_ac_token = self._lookup_table.get_indentation_tokens(r.next_ac_indents)
+ accept_token_mask &= indent_ac_token
+
+ if get_list: # This is useful for testing
+ return self._get_tokens_list(accept_token_mask)
+ return accept_token_mask
+
+ def is_valid_prefix(self, r: ParseResult) -> bool:
+ """
+ Check if r.remainder is a valid prefix for accept sequences in r
+ """
+ cur_incomplete_string = r.remainder
+
+ cur_fsm_states = self._fsms.compute_fsm_states(cur_incomplete_string)
+ for accept_sequence in r.accept_sequences:
+ for fsm_state in cur_fsm_states:
+ if fsm_state.terminal == accept_sequence[0]:
+ return True
+ return False
+
+ def _list_to_mask(self, tokens_idx_list) -> torch.Tensor:
+ indices = torch.tensor(tokens_idx_list)
+ tokens_mask = self._lookup_table._get_default_mask()
+ tokens_mask[indices] = 1
+ return tokens_mask
+
+ def _get_tokens_list(self, token_mask) -> Iterable[str]:
+ return [self._vocab[idx.item()] for idx in torch.where(token_mask == True)[0]]
diff --git a/syncode/parsers/incremental_parser.py b/syncode/parsers/incremental_parser.py
index 0bad9ed5..7a36c75e 100644
--- a/syncode/parsers/incremental_parser.py
+++ b/syncode/parsers/incremental_parser.py
@@ -5,12 +5,15 @@
from syncode.parse_result import ParseResult, RemainderState
from syncode.larkm.lexer import Token
from typing import Optional, Any, Tuple, Iterable
+import logging
+logger = logging.getLogger(__name__)
+
class IncrementalParser:
"""
This is the base class for all incremental parsers.
"""
- def __init__(self, base_parser, logger: Optional[common.Logger]=None, ignore_whitespace=False) -> None:
+ def __init__(self, base_parser, ignore_whitespace=False) -> None:
self.cur_pos = 0 # Current cursor position in the lexer tokens list
self.lexer_pos = 0 # Current lexer position in the code
self.dedent_queue: list = []
@@ -19,7 +22,6 @@ def __init__(self, base_parser, logger: Optional[common.Logger]=None, ignore_whi
# Initialize the parser
self.base_parser = base_parser
- self.logger = logger if logger is not None else common.EmptyLogger()
self.interactive = self.base_parser.parse_interactive('')
self.parsed_lexer_tokens: list = []
diff --git a/syncode/parsers/python_parser.py b/syncode/parsers/python_parser.py
index f2a7732e..4d16ee2b 100644
--- a/syncode/parsers/python_parser.py
+++ b/syncode/parsers/python_parser.py
@@ -7,13 +7,15 @@
from syncode.parsers.incremental_parser import IncrementalParser
from syncode.parse_result import IndentationConstraint, ParseResult, RemainderState
from typing import Optional, Iterable
+import logging
+logger = logging.getLogger(__name__)
class PythonIncrementalParser(IncrementalParser):
"""
This class implements an incremental parser for Python code.
"""
- def __init__(self, base_parser, indenter, logger:Optional[common.Logger]=None, partial_code=None,**kwargs):
- super().__init__(base_parser, logger=logger, **kwargs)
+ def __init__(self, base_parser, indenter, partial_code=None,**kwargs):
+ super().__init__(base_parser, **kwargs)
if partial_code is not None: # extract indentation type from partial code
indenter.tab_len = self._get_indentation(partial_code) # NOTE: tab_len is useful when \t and spaces are used for indentation in same code
@@ -146,7 +148,7 @@ def _lex_code(self, code: str) -> Iterable[Token]:
# Perform postlexing indentation
if token.type == indenter.NL_type:
- lexer_tokens += indenter._handle_NL(token, self.logger)
+ lexer_tokens += indenter._handle_NL(token)
else:
lexer_tokens.append(token)
if token.type in indenter.OPEN_PAREN_types:
@@ -174,7 +176,7 @@ class PythonIndenter(Indenter):
DEDENT_type = "_DEDENT"
tab_len = 4
- def _handle_NL(self, token: Token, logger=None) -> Iterator[Token]:
+ def _handle_NL(self, token: Token) -> Iterator[Token]:
'''
This is taken from Lark library and modified to handle the case when there is a LONG_STRING comment in the _NL token.
'''
@@ -186,8 +188,7 @@ def _handle_NL(self, token: Token, logger=None) -> Iterator[Token]:
try:
indent_str = m.group(1).rsplit('\n', 1)[1] # Tabs and spaces
except IndexError:
- if logger is not None:
- logger.log(f'Could not find the indentation for LONG_STRING comment in the token: {token}')
+ logger.error(f'Could not find the indentation for LONG_STRING comment in the token: {token}')
indent_str = ''
indent = indent_str.count(' ') + indent_str.count('\t') * self.tab_len
diff --git a/tests/mask_store/test_byte_fsm.py b/tests/mask_store/test_byte_fsm.py
new file mode 100644
index 00000000..ec75917a
--- /dev/null
+++ b/tests/mask_store/test_byte_fsm.py
@@ -0,0 +1,178 @@
+import re
+import unittest
+import sys, os
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+from syncode.mask_store.byte_fsm import ByteFSM
+
+class TestByteFSM(unittest.TestCase):
+ """Test suite for the ByteFSM class."""
+
+ def test_basic_patterns(self):
+ """Test basic regex pattern matching with various inputs."""
+ test_patterns = {
+ "abc": [
+ ("abc", True),
+ ("abcd", False),
+ ("ab", False)
+ ],
+ "a[0-9]c": [
+ ("a0c", True),
+ ("a5c", True),
+ ("abc", False),
+ ("a12c", False)
+ ],
+ "a+b*c": [
+ ("ac", True),
+ ("abc", True),
+ ("aac", True),
+ ("aabbc", True),
+ ("bc", False)
+ ],
+ "cat|dog": [
+ ("cat", True),
+ ("dog", True),
+ ("bat", False)
+ ],
+ "hello[0-9]+": [
+ ("hello123", True),
+ ("hello", False)
+ ],
+ "(?:(?:(?:\\/\\/[^\n]*|(\r?\n[\t ]*)+))+|/\\*'\\ \\.\\*\\?\\ '\\*/|;)":
+ [
+ ("//comment", True),
+ # ("/*comment*/", True), TODO: this should be fixed in GO grammar
+ (";", True),
+ ("\n //comment", True),
+ ("/*comment", False),
+ ],
+ "😊": [
+ ("😊", True),
+ ("a", False)
+ ],
+ }
+
+ for pattern, test_cases in test_patterns.items():
+ with self.subTest(pattern=pattern):
+ byte_fsm = ByteFSM(pattern)
+ re_pattern = re.compile(f"^{pattern}$")
+
+ # Log FSM properties for information
+ print(f"FSM states: {byte_fsm.num_states}")
+ print(f"FSM alphabet size: {byte_fsm.alphabet_size}")
+ print(f"FSM transitions: {byte_fsm.num_transitions}")
+
+ for test_str, expected in test_cases:
+ with self.subTest(test_str=test_str):
+ # Check Python's re
+ python_match = bool(re_pattern.match(test_str))
+ self.assertEqual(python_match, expected,
+ f"Python regex gave unexpected result for pattern '{pattern}' and input '{test_str}'")
+
+ # Check ByteFSM
+ byte_match = byte_fsm.accepts(test_str)
+ self.assertEqual(byte_match, expected,
+ f"ByteFSM gave unexpected result for pattern '{pattern}' and input '{test_str}'")
+
+ def test_consume_prefix(self):
+ """Test the consume_prefix functionality for various regex patterns."""
+ prefix_test_cases = [
+ ("abc[0-9]+", [
+ ("abc123def", (True, b"def")),
+ ("abc", (True, b"")), # Live state
+ ("xyz", (False, None)),
+ ("abc123", (True, b"")),
+ ("abc123456xyz", (True, b"xyz")),
+ ("ab", (True, b"")) # Live state
+ ]),
+ ("a+b*c", [
+ ("acdef", (True, b"def")),
+ ("abbcxyz", (True, b"xyz")),
+ ("aaabcdef", (True, b"def")),
+ ("def", (False, None)),
+ ("a", (True, b"")), # Live state
+ ("ab", (True, b"")), # Live state
+ ("aaaabc", (True, b""))
+ ]),
+ ("cat|dog", [
+ ("caterpillar", (True, b"erpillar")),
+ ("doghouse", (True, b"house")),
+ ("catalog", (True, b"alog")),
+ ("ca", (True, b"")), # Live state
+ ("donut", (False, None)),
+ ("mouse", (False, None))
+ ]),
+ ("ab?c", [
+ ("abcdef", (True, b"def")),
+ ("acxyz", (True, b"xyz")),
+ ("abc", (True, b"")),
+ ("ac", (True, b"")),
+ ("abd", (False, None))
+ ]),
+ ("😊+", [
+ ("😊hello", (True, b"hello")),
+ ("😊😊world", (True, b"world")),
+ ("hello😊", (False, None)),
+ ("😊", (True, b""))
+ ]),
+ ("[a-z]+@[a-z]+\\.(com|org)", [
+ ("user@example.com/page", (True, b"/page")),
+ ("admin@site.org?query=1", (True, b"?query=1")),
+ ("user@example.net", (False, None)),
+ ("user@", (True, b"")), # Live state
+ ("invalid", (True, b"")) # Live state for [a-z]+
+ ])
+ ]
+
+ for pattern, test_cases in prefix_test_cases:
+ with self.subTest(pattern=pattern):
+ byte_fsm = ByteFSM(pattern)
+
+ for test_str, expected in test_cases:
+ with self.subTest(test_str=test_str):
+ success, remainder = byte_fsm.consume_prefix(test_str)
+
+ self.assertEqual(success, expected[0],
+ f"Success flag incorrect for pattern '{pattern}' and input '{test_str}'")
+ self.assertEqual(remainder, expected[1],
+ f"Remainder incorrect for pattern '{pattern}' and input '{test_str}'")
+
+
+ def test_identifier_regex(self):
+ """Test that the identifier regex pattern '[a-zA-Z_][a-zA-Z0-9_]*' works with ByteFSM."""
+ # Create a ByteFSM for the identifier pattern
+ identifier_fsm = ByteFSM(r'[a-zA-Z_][a-zA-Z0-9_]*')
+
+ # Test with valid identifiers
+ valid_identifiers = [b'abc', b'x', b'_var', b'abc123', b'ABC_123']
+ for ident in valid_identifiers:
+ with self.subTest(identifier=ident):
+ self.assertTrue(identifier_fsm.accepts(ident), f"Failed to accept valid identifier: {ident}")
+ final_state = identifier_fsm.try_consume_all(ident)
+ self.assertIsNotNone(final_state, f"Failed to match valid identifier: {ident}")
+
+ # Test with invalid identifiers
+ invalid_identifiers = [b'123abc', b' abc', b'']
+ for ident in invalid_identifiers:
+ with self.subTest(identifier=ident):
+ self.assertFalse(identifier_fsm.accepts(ident), f"Incorrectly accepted invalid identifier: {ident}")
+ final_state = identifier_fsm.try_consume_all(ident)
+ self.assertIsNone(final_state, f"Incorrectly matched invalid identifier: {ident}")
+
+ def test_fsm_properties(self):
+ """Test that FSM properties return valid values."""
+ test_patterns = ["abc", "a[0-9]c", "a+b*c", "cat|dog", "😊", "hello[0-9]+"]
+
+ for pattern in test_patterns:
+ with self.subTest(pattern=pattern):
+ byte_fsm = ByteFSM(pattern)
+
+ self.assertGreater(byte_fsm.num_states, 0,
+ f"FSM for '{pattern}' should have at least one state")
+ self.assertGreater(byte_fsm.alphabet_size, 0,
+ f"FSM for '{pattern}' should have a non-empty alphabet")
+ self.assertGreater(byte_fsm.num_transitions, 0,
+ f"FSM for '{pattern}' should have at least one transition")
+
+if __name__ == "__main__":
+ unittest.main()
+
\ No newline at end of file
diff --git a/tests/mask_store/test_byte_tokenizer.py b/tests/mask_store/test_byte_tokenizer.py
new file mode 100644
index 00000000..13df1bb4
--- /dev/null
+++ b/tests/mask_store/test_byte_tokenizer.py
@@ -0,0 +1,347 @@
+import unittest
+from unittest.mock import MagicMock
+import time
+import random, sys, os
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+from syncode.mask_store.byte_tokenizer import ByteTokenizer
+from syncode.mask_store.byte_tokenizer import VocabType, detect_vocab_type, bytes_to_unicode
+
+class TestByteTokenizer(unittest.TestCase):
+ """Test cases for the ByteTokenizer class with different tokenizer types."""
+
+ def create_mock_tokenizer(self, vocab, vocab_type):
+ """Create a mock tokenizer of the specified type with the given vocabulary."""
+ mock_tokenizer = MagicMock()
+ mock_tokenizer.get_vocab.return_value = vocab
+
+ # Set up the appropriate properties based on vocab_type
+ if vocab_type == VocabType.RAW:
+ # For tiktoken-style tokenizers
+ mock_tokenizer.tokenizer = "tiktoken.Encoding"
+ mock_tokenizer.vocab_files_names = {"vocab_file": "tiktoken_vocab.json"}
+ elif vocab_type == VocabType.BYTE_FALLBACK:
+ # For LLaMA-2 style tokenizers
+ mock_tokenizer.tokenizer = "ByteFallbackTokenizer"
+ else: # BYTE_LEVEL
+ # For GPT-2 style tokenizers
+ mock_tokenizer.tokenizer = "ByteLevelTokenizer"
+
+ # Set up decode method for testing
+ def mock_decode(token_ids, **kwargs):
+ result = ""
+ for token_id in token_ids:
+ for token, tid in vocab.items():
+ if tid == token_id:
+ result += token
+ break
+ return result
+
+ mock_tokenizer.decode = mock_decode
+
+ # Set up encode method for testing
+ def mock_encode(text, **kwargs):
+ # Simplified encoding - just matching exact tokens
+ result = []
+
+ if vocab_type == VocabType.RAW:
+ # For RAW tokenizers, we need to encode the text as bytes
+ text = text.encode('utf-8')
+
+ remaining = text
+ while remaining:
+ matched = False
+ for token, token_id in sorted(vocab.items(), key=lambda x: len(x[0]), reverse=True):
+ if remaining.startswith(token):
+ result.append(token_id)
+ remaining = remaining[len(token):]
+ matched = True
+ break
+ if not matched:
+ # Skip one character if no match
+ remaining = remaining[1:]
+ return result
+
+ mock_tokenizer.encode = mock_encode
+ mock_tokenizer.unk_token_id = 0
+
+ return mock_tokenizer
+
+ def test_raw_tokenizer(self):
+ """Test ByteTokenizer with a RAW (tiktoken-style) tokenizer."""
+ # Create mock vocabulary for a raw tokenizer
+ vocab = {
+ b"hello": 1,
+ b"world": 2,
+ b"!": 3,
+ b"\xE4\xBD\xA0": 4, # 你
+ b"\xE5\xA5\xBD": 5, # 好
+ b"\xE5\x90": 6, # first two bytes of 吗
+ b"\x97": 7, # last byte of 吗
+ }
+
+ mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.RAW)
+
+ # Test encoding
+ input_bytes = b"hello world!"
+ expected_ids = [1, 2, 3] # hello, world, !
+ # Mocking - we'll just check if the encode method was called correctly
+ mock_tokenizer.encode.return_value = expected_ids
+
+ # Test decoding
+ token_ids = [4, 5, 6] # 你, 好, 吗 (first two bytes).
+ mock_tokenizer.decode.return_value = "你好吗?"
+ result = byte_tokenizer.decode(token_ids)
+ self.assertEqual(result, b"\xE4\xBD\xA0\xE5\xA5\xBD\xE5\x90")
+
+ def test_byte_fallback_tokenizer(self):
+ """Test ByteTokenizer with a BYTE_FALLBACK (Llama-2-style) tokenizer."""
+ # Create mock vocabulary for a byte fallback tokenizer
+ vocab = {
+ "hello": 1,
+ "▁world": 2, # Space-prefixed token
+ "<0x21>": 3, # Byte fallback for !
+ "<0xE4>": 4, # First byte of 你 in UTF-8
+ "<0xBD>": 5, # Second byte of 你 in UTF-8
+ "<0xA0>": 6, # Third byte of 你 in UTF-8
+ }
+
+ mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.BYTE_FALLBACK)
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.BYTE_FALLBACK)
+
+ # Test encoding/decoding of byte fallback tokens
+ self.assertEqual(byte_tokenizer.enbyte_fn("<0x21>"), b"!")
+ self.assertEqual(byte_tokenizer.enbyte_fn("▁world"), b" world")
+
+ # Verify byte_vocab mapping
+ self.assertEqual(byte_tokenizer.byte_vocab[b"!"], 3)
+ self.assertEqual(byte_tokenizer.byte_vocab[b" world"], 2)
+
+ def test_byte_level_tokenizer(self):
+ """Test ByteTokenizer with a BYTE_LEVEL (GPT-2-style) tokenizer."""
+ # Create a simplified byte-to-unicode mapping for testing
+ byte_to_unicode = bytes_to_unicode()
+ unicode_to_byte = {v: k for k, v in byte_to_unicode.items()}
+
+ # Create mock vocabulary with encoded characters
+ # 'Ġ' (U+0120) represents space in GPT-2 tokenizer
+ vocab = {
+ "hello": 1,
+ "Ġworld": 2, # Space-prefixed token in byte-level encoding
+ byte_to_unicode[ord("!")]: 3, # Encoded !
+ }
+
+ mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.BYTE_LEVEL)
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.BYTE_LEVEL)
+
+ # Test encoding byte-level tokens
+ # The byte representation of 'Ġ' followed by 'world'
+ self.assertEqual(byte_tokenizer.enbyte_fn("Ġworld")[0], ord(' '))
+
+ # Test that we can decode a sequence
+ token_ids = [1, 2, 3] # hello, Ġworld, !
+ mock_tokenizer.decode.return_value = "hello world!"
+ byte_result = byte_tokenizer.decode(token_ids)
+
+ # The actual bytes might be different due to the encoding,
+ # but decoding to UTF-8 should give us the original text
+ try:
+ text_result = byte_result.decode('utf-8')
+ self.assertIn("hello", text_result)
+ self.assertIn("world", text_result)
+ except UnicodeDecodeError:
+ # If we can't decode, that's also acceptable for this mock test
+ pass
+
+ def test_batched_decoding(self):
+ """Test batched decoding capabilities."""
+ vocab = {
+ b"hello": 1,
+ b"world": 2,
+ b"!": 3,
+            b"<s>": 4,   # special token
+            b"</s>": 5,  # special token
+ }
+
+ mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
+        mock_tokenizer.all_special_ids = [4, 5] # Mark <s> and </s> as special tokens
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.RAW)
+
+ # Test batched decoding
+ token_batches = [
+            [4, 1, 2], # <s> hello world
+            [4, 1, 2, 3] # <s> hello world !
+ ]
+ batch_results = byte_tokenizer.batched_decode(token_batches)
+ self.assertEqual(len(batch_results), 2)
+
+ # Test batched decoding with skip_special_tokens
+ batch_results_skipped = byte_tokenizer.batched_decode(token_batches, skip_special_tokens=True)
+ self.assertEqual(len(batch_results_skipped), 2)
+
+ def test_auto_detection(self):
+ """Test automatic detection of tokenizer type."""
+ # Test RAW detection
+ raw_vocab = {"hello": 1, "world": 2}
+ raw_tokenizer = self.create_mock_tokenizer(raw_vocab, VocabType.RAW)
+ self.assertEqual(detect_vocab_type(raw_tokenizer), VocabType.RAW)
+
+ # Test BYTE_FALLBACK detection
+ fallback_vocab = {"hello": 1, "<0x0A>": 2}
+ fallback_tokenizer = self.create_mock_tokenizer(fallback_vocab, VocabType.BYTE_FALLBACK)
+ self.assertEqual(detect_vocab_type(fallback_tokenizer), VocabType.BYTE_FALLBACK)
+
+ # Test BYTE_LEVEL detection
+ bytelevel_vocab = {"hello": 1, "Ġworld": 2}
+ bytelevel_tokenizer = self.create_mock_tokenizer(bytelevel_vocab, VocabType.BYTE_LEVEL)
+ # Make sure our mock tokenizer correctly returns the vocabulary with the Ġ character
+ self.assertIn("Ġworld", bytelevel_tokenizer.get_vocab())
+ self.assertEqual(detect_vocab_type(bytelevel_tokenizer), VocabType.BYTE_LEVEL)
+
+ def test_decoding_performance(self):
+ """Test basic decoding performance."""
+ # Create a larger vocabulary for more realistic testing
+ vocab = {bytes(f"token{i}".encode('utf-8')): i for i in range(1000)}
+ # Add some special tokens
+        vocab[b"<s>"] = 1000
+        vocab[b"</s>"] = 1001
+
+ mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
+ mock_tokenizer.all_special_ids = [1000, 1001]
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.RAW)
+
+ # Generate random token sequences of different lengths
+        sequence_lengths = [10, 100, 1000, 10000]
+ sequences = {}
+
+ for length in sequence_lengths:
+ sequences[length] = [random.randint(1, 999) for _ in range(length)]
+
+ # Test single decode performance
+ for length, sequence in sequences.items():
+ # Warm-up run
+ byte_tokenizer.decode(sequence)
+
+ # Actual timed run
+ start_time = time.time()
+ repetitions = max(1, 1000 // length) # More repetitions for shorter sequences
+ for _ in range(repetitions):
+ byte_tokenizer.decode(sequence)
+ elapsed = time.time() - start_time
+
+ # Calculate tokens per second
+ tokens_per_second = (length * repetitions) / elapsed
+ self.assertIsNotNone(tokens_per_second) # Simple assertion to check execution
+
+ # Test with special token handling
+ special_sequence = sequences[1000].copy()
+ # Insert special tokens randomly
+ for _ in range(50):
+ pos = random.randint(0, len(special_sequence) - 1)
+ special_sequence[pos] = 1000 if random.random() < 0.5 else 1001
+
+ # Warm-up run
+ byte_tokenizer.decode(special_sequence, skip_special_tokens=True)
+
+ # Actual timed run
+ start_time = time.time()
+ for _ in range(10):
+ byte_tokenizer.decode(special_sequence, skip_special_tokens=True)
+ elapsed = time.time() - start_time
+ self.assertGreater(elapsed, 0) # Simple assertion to check execution
+
+ # Test batched decode performance
+ batch_sizes = [10, 50, 100]
+
+ for batch_size in batch_sizes:
+ # Create batch of same-length sequences
+ batch = [sequences[100] for _ in range(batch_size)]
+
+ # Warm-up run
+ byte_tokenizer.batched_decode(batch)
+
+ # Actual timed run
+ start_time = time.time()
+ for _ in range(5): # Run multiple times for more stable measurement
+ byte_tokenizer.batched_decode(batch)
+ elapsed = time.time() - start_time
+
+ # Calculate tokens per second
+ tokens_per_second = (100 * batch_size * 5) / elapsed
+ self.assertGreater(tokens_per_second, 0) # Simple assertion to check execution
+
+ def test_real_tokenizers(self):
+ """Test ByteTokenizer with real HuggingFace tokenizers."""
+ # Skip test if transformers is not available
+ try:
+ import transformers
+ from transformers import AutoTokenizer
+
+ # Test strings with different characteristics
+ test_strings = [
+ "Hello, world!",
+ "This is a test of ByteTokenizer with different languages.",
+ "Let's try some emojis: 🚀🔥🌍",
+ "And some CJK characters: 你好, 안녕하세요, こんにちは"
+ ]
+
+ models = [
+ "google/gemma-2-2b-it",
+ "meta-llama/Llama-3.1-8B-Instruct"
+ ]
+
+ # Try to load at least one model for testing
+ for model_name in models:
+ try:
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ byte_tokenizer = ByteTokenizer(tokenizer)
+
+ # Test at least one string
+ test_str = test_strings[0]
+ token_ids = tokenizer.encode(test_str, add_special_tokens=False)
+ bytes_result = byte_tokenizer.decode(token_ids)
+
+ # Simple assertion that we got some bytes back
+ self.assertIsInstance(bytes_result, bytes)
+
+                    # One successfully loaded tokenizer is enough
+                    break
+                except Exception:
+                    continue
+            else:
+                # No tokenizer could be loaded; skip rather than silently pass
+                self.skipTest("Could not load any of the test tokenizers")
+
+ except (ImportError, ConnectionError):
+ # Skip the test if no tokenizers are available
+ self.skipTest("Transformers library not available or no internet connection")
+
+ def test_roundtrip_encoding_decoding(self):
+ """Test encoding and decoding round-trip."""
+ # Create a simple vocabulary for testing
+ raw_vocab = {
+ b"hello": 1,
+ b" ": 2,
+ b"world": 3,
+ b"!": 4,
+ }
+
+ mock_tokenizer = self.create_mock_tokenizer(raw_vocab, VocabType.RAW)
+ byte_tokenizer = ByteTokenizer(mock_tokenizer, VocabType.RAW)
+
+ # Test string
+ test_str = "hello world!"
+
+ # Encode with the mock tokenizer
+ token_ids = mock_tokenizer.encode(test_str)
+
+ # Decode with ByteTokenizer
+ decoded_bytes = byte_tokenizer.decode(token_ids)
+
+ # Check round-trip
+ try:
+ decoded_str = decoded_bytes.decode('utf-8')
+ self.assertEqual(test_str, decoded_str)
+ except UnicodeDecodeError:
+ self.fail("Unicode decode error in round-trip test")
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/mask_store/test_fsm_set.py b/tests/mask_store/test_fsm_set.py
new file mode 100644
index 00000000..d7d890e8
--- /dev/null
+++ b/tests/mask_store/test_fsm_set.py
@@ -0,0 +1,212 @@
+import unittest
+from typing import Any, Optional, Tuple, Iterable, Dict
+import sys, os
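+# Make the repository root importable when this file is run directly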
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+from syncode.mask_store.fsm_set import FSMSet, JointFSMState
+
+# Mock classes for testing
+class MockTerminalDef:
+ """Mock version of TerminalDef for testing purposes"""
+ def __init__(self, name, pattern):
+ self.name = name
+ self.pattern = pattern
+
+ def __repr__(self):
+ return f"TerminalDef({self.name}, {self.pattern})"
+
+class MockRegexPattern:
+ """Mock version of RegexPattern for testing purposes"""
+ def __init__(self, pattern):
+ self.pattern = pattern
+
+ def to_regexp(self):
+ return self.pattern
+
+ def __repr__(self):
+ return f"RegexPattern({self.pattern})"
+
+
+class TestFSMSet(unittest.TestCase):
+ def setUp(self):
+ # Create some test terminals with different regex patterns
+ self.terminals = [
+ MockTerminalDef("IDENTIFIER", MockRegexPattern("[a-zA-Z_][a-zA-Z0-9_]*")),
+ MockTerminalDef("NUMBER", MockRegexPattern("[0-9]+")),
+ MockTerminalDef("STRING", MockRegexPattern('"[^"]*"')),
+ MockTerminalDef("WHITESPACE", MockRegexPattern("[ \t\n\r]+")),
+ MockTerminalDef("OPERATOR", MockRegexPattern("[+\\-*/=<>!]")),
+ MockTerminalDef("KEYWORD", MockRegexPattern("(if|else|while|for|return)")),
+ MockTerminalDef("EMOJI", MockRegexPattern("😊+")), # Test UTF-8 handling
+ ]
+
+        # Create the FSMSet instance under test
+ self.dfas = FSMSet(self.terminals)
+
+ def test_compute_fsm_states_simple_inputs(self):
+ """Test compute_fsm_states with simple inputs matching specific terminals"""
+ test_cases = [
+ (b"abc123", ["IDENTIFIER"]),
+ (b"123", ["NUMBER"]),
+ (b'"hello"', ["STRING"]),
+ (b" \t\n", ["WHITESPACE"]),
+ (b"+", ["OPERATOR"]),
+ (b"if", ["KEYWORD", "IDENTIFIER"]), # both KEYWORD and IDENTIFIER match 'if'
+ ("😊".encode('utf-8'), ["EMOJI"]),
+ ]
+
+ for input_bytes, expected_terminals in test_cases:
+ states = self.dfas.compute_fsm_states(input_bytes)
+ terminal_names = [state.terminal for state in states]
+
+ # Check that all expected terminals are matched
+ for terminal in expected_terminals:
+ self.assertIn(terminal, terminal_names,
+ f"Expected terminal {terminal} to match input '{input_bytes}'")
+
+ # Check that no unexpected terminals are matched
+ for terminal in terminal_names:
+ self.assertIn(terminal, expected_terminals,
+ f"Unexpected terminal {terminal} matched input '{input_bytes}'")
+
+ def test_compute_fsm_states_state_tracking(self):
+ """Test that compute_fsm_states correctly tracks the state of each FSM"""
+ # Create a mock ByteFSM with known states for testing
+ class MockByteFSM:
+ def __init__(self, states, transitions, finals):
+ self.states = states
+ self.transitions = transitions
+ self.finals = finals
+ self.initial = states[0]
+
+ def get_next_state(self, state, byte_val):
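+                # Return the next state for (state, byte_val), or None if no transition exists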
+ key = (state, byte_val)
+ return self.transitions.get(key, None)
+
+ # Create a simple FSM for testing
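+        # This FSM accepts exactly "abc": 0 -a-> 1 -b-> 2 -c-> 3, with state 3 final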
+ states = [0, 1, 2, 3]
+ transitions = {
+ (0, ord('a')): 1,
+ (1, ord('b')): 2,
+ (2, ord('c')): 3,
+ }
+ finals = {3}
+
+        # Temporarily replace the FSMSet's terminal-to-FSM map with the mock
+ original_terminals = self.dfas._terminals_to_byte_fsm.copy()
+ self.dfas._terminals_to_byte_fsm = {
+ "TEST": MockByteFSM(states, transitions, finals)
+ }
+
+ try:
+ # Test with different inputs
+ test_cases = [
+ (b"a", 1), # Should reach state 1
+ (b"ab", 2), # Should reach state 2
+ (b"abc", 3), # Should reach state 3
+ (b"", 0), # Should stay in initial state 0
+ (b"x", None), # Should not match and return no states
+ ]
+
+ for input_bytes, expected_state in test_cases:
+ states = self.dfas.compute_fsm_states(input_bytes)
+
+ if expected_state is None:
+ self.assertEqual(len(states), 0,
+ f"Expected no matches for input '{input_bytes}'")
+ else:
+ self.assertEqual(len(states), 1,
+ f"Expected exactly one match for input '{input_bytes}'")
+ self.assertEqual(states[0].terminal, "TEST",
+ f"Expected 'TEST' terminal for input '{input_bytes}'")
+ self.assertEqual(states[0].state_id, expected_state, # Changed to state_id
+ f"Expected state {expected_state} for input '{input_bytes}' but got {states[0].state_id}")
+ finally:
+ # Restore original terminals
+ self.dfas._terminals_to_byte_fsm = original_terminals
+
+ def test_consume_prefix(self):
+ """Test consume_prefix with various inputs and starting states"""
+ test_cases = [
+ # (terminal_name, input_bytes, expected_result)
+ ("IDENTIFIER", b"abc123 rest", (True, b" rest")),
+ ("IDENTIFIER", b"123", (False, None)), # IDENTIFIER doesn't match digits first
+ ("NUMBER", b"123abc", (True, b"abc")),
+ ("NUMBER", b"abc", (False, None)), # NUMBER doesn't match letters
+ ("STRING", b'"hello" rest', (True, b" rest")),
+ ("STRING", b'hello"', (False, None)), # STRING needs opening quote
+ ("WHITESPACE", b" \t\nrest", (True, b"rest")),
+ ("OPERATOR", b"+rest", (True, b"rest")),
+ ("KEYWORD", b"if(x)", (True, b"(x)")),
+ ("KEYWORD", b"ifdef", (True, b"def")), # Matches 'if' and leaves 'def'
+ ("EMOJI", "😊😊 text".encode('utf-8'), (True, b" text")),
+ ]
+
+ for terminal_name, input_bytes, expected_result in test_cases:
+ initial_state = self.dfas.initial(terminal_name)
+ result = self.dfas.consume_prefix(initial_state, input_bytes)
+ self.assertEqual(result, expected_result,
+ f"Failed for terminal {terminal_name} with input '{input_bytes}'")
+
+ def test_consume_prefix_with_non_initial_states(self):
+ """Test consume_prefix starting from non-initial states"""
+ # This is a more complex test that requires following transitions
+ # For example, for IDENTIFIER after consuming 'a', we'd be in a non-initial state
+
+ # First, manually transition to a non-initial state
+ terminal_name = "IDENTIFIER"
+ byte_fsm = self.dfas._terminals_to_byte_fsm[terminal_name]
+
+ # Get the state after consuming 'a'
+ next_state = byte_fsm.get_next_state(byte_fsm.initial, ord('a'))
+ if next_state is not None:
+ non_initial_state = JointFSMState(terminal_name, next_state)
+
+ # Now test consume_prefix from this state
+ result = self.dfas.consume_prefix(non_initial_state, b"bc123 rest")
+ self.assertEqual(result, (True, b" rest"))
+
+ def test_byte_regex_handling(self):
+ """Test that byte regex patterns are correctly handled"""
+ # This is a more advanced test to ensure UTF-8 encoding is handled correctly
+ emoji_terminal = MockTerminalDef("EMOJI_TEST", MockRegexPattern("😊+"))
+ dfas = FSMSet([emoji_terminal])
+
+ # Test states computation
+ states = dfas.compute_fsm_states("😊😊".encode('utf-8'))
+ self.assertEqual(len(states), 1)
+ self.assertEqual(states[0].terminal, "EMOJI_TEST")
+
+ # Test consume_prefix
+ initial_state = dfas.initial("EMOJI_TEST")
+ result = dfas.consume_prefix(initial_state, "😊😊rest".encode('utf-8'))
+ self.assertEqual(result, (True, b"rest"))
+
+ # Specific test for UTF-8 bug
+ def test_utf8_bug(self):
+ """Test specifically targeting the UTF-8 bug found in the original implementation"""
+ # Create a terminal that matches emoji
+ emoji_terminal = MockTerminalDef("EMOJI", MockRegexPattern("😊+"))
+ dfas = FSMSet([emoji_terminal])
+
+ # Test with the emoji pattern
+ emoji_initial = dfas.initial("EMOJI")
+        result = dfas.consume_prefix(emoji_initial, "😊😊 text".encode('utf-8'))
+        self.assertEqual(result, (True, b" text"), f"Unexpected emoji consume_prefix result: {result}")
+        print(f"Emoji test: {result}") # Debug output
+
+ # More debug info
+ emoji_fsm = dfas._terminals_to_byte_fsm["EMOJI"]
+ print(f"ByteFSM states: {len(emoji_fsm.transitions)}")
+ print(f"ByteFSM alphabet size: {emoji_fsm.alphabet_size}")
+ print(f"ByteFSM transitions: {emoji_fsm.num_transitions}")
+
+ # Try using simple patterns too for comparison
+ simple_terminal = MockTerminalDef("SIMPLE", MockRegexPattern("a+"))
+ simple_dfas = FSMSet([simple_terminal])
+ simple_initial = simple_dfas.initial("SIMPLE")
+        simple_result = simple_dfas.consume_prefix(simple_initial, b"aaa text")
+        self.assertEqual(simple_result, (True, b" text"))
+        print(f"Simple test: {simple_result}") # Debug output
+
+# Run the tests
+if __name__ == "__main__":
+ unittest.main(argv=[''], exit=False)
+ print("Done running tests")
diff --git a/tests/mask_store/test_lookup_table.py b/tests/mask_store/test_lookup_table.py
new file mode 100644
index 00000000..581e3458
--- /dev/null
+++ b/tests/mask_store/test_lookup_table.py
@@ -0,0 +1,482 @@
+import os
+import unittest
+import time
+import torch
+import sys
+from typing import List, Dict, Any, Tuple, Iterable
+import copy
+import logging
+logger = logging.getLogger(__name__)
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+
+# Import the actual classes
+from syncode.mask_store.mask_store import JointFSMState
+from syncode.parse_result import IndentationConstraint
+from syncode.mask_store.lookup_table import LookupTable
+
+class TestLookupTable(unittest.TestCase):
+ def setUp(self):
+ # Create a small vocabulary for basic testing
+ self.small_vocab = ["token1", "token2", " ", "\t", " ", "\t\t", " token", "\ttoken", "token ", "token\t"]
+ self.eos_token_id = 0
+ self.special_token_ids = [0, 1]
+
+ # Create some FSM states for testing - using actual JointFSMState as defined
+ self.state1 = JointFSMState("terminal1", 1)
+ self.state2 = JointFSMState("terminal2", 2)
+
+ # Create lookup table instance
+ self.lookup = LookupTable(
+ vocab=self.small_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ # For performance tests, create a larger vocabulary
+ self.large_vocab_size = 10000
+ self.large_vocab = [f"token{i}" for i in range(self.large_vocab_size)]
+
+ # Add some whitespace tokens for indentation testing
+ for i in range(50):
+ self.large_vocab.append(" " * i)
+ self.large_vocab.append("\t" * i)
+ self.large_vocab.append(" " * i + "token")
+ self.large_vocab.append("\t" * i + "token")
+
+ # Create more FSM states for performance testing
+ self.many_states = [JointFSMState(f"terminal{i % 100}", i) for i in range(10000)]
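+        # Shared pool of FSM states drawn on by the performance tests below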
+
+ def time_function(self, func_name, func, *args, **kwargs):
+ """Helper to time a function execution and log it"""
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ end_time = time.time()
+ time_taken = end_time - start_time
+ logger.info(f"{func_name} execution time: {time_taken:.6f}s")
+ return result, time_taken
+
+ def test_initialization(self):
+ """Test initialization of the LookupTable"""
+ # Check vocabulary
+ self.assertEqual(self.lookup._vocab, self.small_vocab)
+
+ # Check default mask
+ expected_default_mask = torch.zeros(len(self.small_vocab), dtype=torch.bool)
+ expected_default_mask[1] = 1 # Only special token that isn't EOS
+
+ self.assertTrue(torch.equal(self.lookup._default_mask, expected_default_mask))
+
+ # Check indentation mode
+ self.assertTrue(self.lookup.indent)
+
+ # Check that the whitespace tokens map was created
+ self.assertTrue(isinstance(self.lookup._whitespace_tokens_map, dict))
+
+ # Check that the indentation tokens map was created
+ self.assertTrue(isinstance(self.lookup._indentation_to_tokens_map, dict))
+
+ def test_add_exact_lookup(self):
+ """Test adding tokens to exact lookup"""
+ # Add some tokens
+ self.lookup.add_exact_lookup(self.state1, 2) # Add token at index 2
+ self.lookup.add_exact_lookup(self.state1, 3) # Add token at index 3
+ self.lookup.add_exact_lookup(self.state2, 4) # Add token at index 4
+
+ # Check that the tokens were added to the exact lookup
+ self.assertIn(2, self.lookup._exact_lookup[self.state1])
+ self.assertIn(3, self.lookup._exact_lookup[self.state1])
+ self.assertIn(4, self.lookup._exact_lookup[self.state2])
+
+ # Check that the tokens were not added to the wrong states
+ self.assertNotIn(2, self.lookup._exact_lookup.get(self.state2, []))
+ self.assertNotIn(4, self.lookup._exact_lookup.get(self.state1, []))
+
+ def test_fsm_state_and_next_terminal(self):
+ """Test adding and accessing tokens by FSM state and next terminal"""
+ next_terminal1 = "next_terminal1"
+ next_terminal2 = "next_terminal2"
+
+ # Initialize and add tokens
+ self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state1, next_terminal1)] = []
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state1, next_terminal1, 2)
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state1, next_terminal1, 3)
+
+ self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state2, next_terminal2)] = []
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state2, next_terminal2, 4)
+
+ # Check that the tokens were added correctly
+ self.assertIn(2, self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state1, next_terminal1)])
+ self.assertIn(3, self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state1, next_terminal1)])
+ self.assertIn(4, self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state2, next_terminal2)])
+
+ # Convert to masks
+ self.lookup.convert_lookups_from_list_to_mask()
+
+ # Check that the tokens were converted to masks correctly
+ state1_terminal1_mask = self.lookup.fsm_state_and_next_terminal_to_tokens(self.state1, next_terminal1)
+ self.assertTrue(state1_terminal1_mask[2])
+ self.assertTrue(state1_terminal1_mask[3])
+ self.assertFalse(state1_terminal1_mask[4])
+
+ state2_terminal2_mask = self.lookup.fsm_state_and_next_terminal_to_tokens(self.state2, next_terminal2)
+ self.assertTrue(state2_terminal2_mask[4])
+ self.assertFalse(state2_terminal2_mask[2])
+ self.assertFalse(state2_terminal2_mask[3])
+
+ def test_overapprox_lookup(self):
+ """Test storing and retrieving overapproximate lookups"""
+ # Create masks
+ mask1 = torch.zeros(len(self.small_vocab), dtype=torch.bool)
+ mask1[5] = 1
+ mask1[6] = 1
+
+ mask2 = torch.zeros(len(self.small_vocab), dtype=torch.bool)
+ mask2[7] = 1
+ mask2[8] = 1
+
+ # Store the masks
+ self.lookup.store_overapprox_lookup(self.state1, mask1)
+ self.lookup.store_overapprox_lookup(self.state2, mask2)
+
+ # Check that the masks were stored correctly
+ stored_mask1 = self.lookup.incomplete_case_lookup(self.state1)
+ self.assertTrue(stored_mask1[5])
+ self.assertTrue(stored_mask1[6])
+
+ stored_mask2 = self.lookup.incomplete_case_lookup(self.state2)
+ self.assertTrue(stored_mask2[7])
+ self.assertTrue(stored_mask2[8])
+
+ # Store another mask for state1 and check that it's ORed with the existing mask
+ mask1_additional = torch.zeros(len(self.small_vocab), dtype=torch.bool)
+ mask1_additional[9] = 1
+
+ self.lookup.store_overapprox_lookup(self.state1, mask1_additional)
+
+ updated_mask1 = self.lookup.incomplete_case_lookup(self.state1)
+ self.assertTrue(updated_mask1[5])
+ self.assertTrue(updated_mask1[6])
+ self.assertTrue(updated_mask1[9])
+
+ def test_get_indent_type(self):
+ """Test the _get_indent_type method"""
+ # Test with whitespace-only strings
+ self.assertEqual(self.lookup._get_indent_type(" "), (True, 4)) # 4 spaces
+ self.assertEqual(self.lookup._get_indent_type("\t\t"), (True, 8)) # 2 tabs (4 spaces each)
+ self.assertEqual(self.lookup._get_indent_type("\t "), (True, 6)) # 1 tab (4 spaces) + 2 spaces
+
+ # Test with strings that start with whitespace
+ self.assertEqual(self.lookup._get_indent_type(" token"), (False, 2)) # 2 spaces + token
+ self.assertEqual(self.lookup._get_indent_type("\ttoken"), (False, 4)) # 1 tab + token
+
+ # Test with strings that don't start with whitespace
+ self.assertEqual(self.lookup._get_indent_type("token"), (False, 0))
+ self.assertEqual(self.lookup._get_indent_type("token "), (False, 0))
+
+ def test_get_indentation_tokens(self):
+ """Test the get_indentation_tokens method"""
+ # Create indentation constraints for testing
+ constraint1 = IndentationConstraint(greater_than_indent_val=2)
+ constraint2 = IndentationConstraint(accept_indents=[2, 4])
+
+ # Convert all lookups to masks first
+ self.lookup.convert_lookups_from_list_to_mask()
+
+ # Test with greater_than_indent_val
+ tokens1 = self.lookup.get_indentation_tokens(constraint1, get_list=True)
+
+ # Check that tokens with indentation > 2 are included
+ # This depends on the specific tokens in the vocabulary, so we'll just check that some tokens are returned
+ self.assertTrue(len(tokens1) > 0)
+
+ # Test with accept_indents
+ tokens2 = self.lookup.get_indentation_tokens(constraint2, get_list=True)
+
+ # Check that tokens with indentation in [2, 4] are included
+ self.assertTrue(len(tokens2) > 0)
+
+ # Check that the tokens returned are valid tokens from the vocabulary
+ for token in tokens1:
+ self.assertIn(token, self.small_vocab)
+
+ for token in tokens2:
+ self.assertIn(token, self.small_vocab)
+
+ def test_list_to_mask_conversion(self):
+ """Test conversion from token lists to masks"""
+ # Add some tokens to the exact lookup
+ self.lookup.add_exact_lookup(self.state1, 2)
+ self.lookup.add_exact_lookup(self.state1, 3)
+
+ # Add some tokens to the fsm_state_and_next_terminal_to_tokens
+ next_terminal = "terminal"
+ self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state1, next_terminal)] = []
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state1, next_terminal, 4)
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state1, next_terminal, 5)
+
+ # Convert to masks
+ self.lookup.convert_lookups_from_list_to_mask()
+
+ # Check that the tokens were converted to masks correctly
+ exact_mask = self.lookup.complete_case_lookup(self.state1)
+ self.assertTrue(exact_mask[2])
+ self.assertTrue(exact_mask[3])
+ self.assertFalse(exact_mask[4])
+ self.assertFalse(exact_mask[5])
+
+ fsm_mask = self.lookup.fsm_state_and_next_terminal_to_tokens(self.state1, next_terminal)
+ self.assertFalse(fsm_mask[2])
+ self.assertFalse(fsm_mask[3])
+ self.assertTrue(fsm_mask[4])
+ self.assertTrue(fsm_mask[5])
+
+ # Check that the overapprox lookup was also updated
+ overapprox_mask = self.lookup.incomplete_case_lookup(self.state1)
+ self.assertTrue(torch.any(overapprox_mask)) # Should have some tokens set
+
+ def test_complete_workflow(self):
+ """Test a complete workflow to ensure all components work together"""
+ # Add tokens to exact lookup
+ self.lookup.add_exact_lookup(self.state1, 2)
+ self.lookup.add_exact_lookup(self.state1, 3)
+
+ # Add tokens to fsm_state_and_next_terminal_to_tokens
+ next_terminal = "terminal"
+ self.lookup._fsm_state_and_next_terminal_to_tokens[(self.state1, next_terminal)] = []
+ self.lookup.fsm_state_and_next_terminal_to_tokens_add(self.state1, next_terminal, 4)
+
+ # Add a mask to overapprox lookup
+ mask = torch.zeros(len(self.small_vocab), dtype=torch.bool)
+ mask[5] = 1
+ self.lookup.store_overapprox_lookup(self.state1, mask)
+
+ # Convert to masks
+ self.lookup.convert_lookups_from_list_to_mask()
+
+ # Check exact lookup
+ exact_mask = self.lookup.complete_case_lookup(self.state1)
+ self.assertTrue(exact_mask[2])
+ self.assertTrue(exact_mask[3])
+
+ # Check fsm_state_and_next_terminal_to_tokens
+ fsm_mask = self.lookup.fsm_state_and_next_terminal_to_tokens(self.state1, next_terminal)
+ self.assertTrue(fsm_mask[4])
+
+ # Check overapprox lookup
+ overapprox_mask = self.lookup.incomplete_case_lookup(self.state1)
+ self.assertTrue(overapprox_mask[5])
+ self.assertTrue(torch.any(overapprox_mask & fsm_mask)) # Should share some tokens
+
+ def test_performance_initialization(self):
+ """Test the performance of initialization"""
+ def initialize_lookup():
+ return LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ # Run 5 times to get a reliable measurement
+ total_time = 0
+ runs = 5
+ for i in range(runs):
+ _, time_taken = self.time_function(f"Initialization run {i+1}/{runs}", initialize_lookup)
+ total_time += time_taken
+
+ logger.info(f"Average initialization time ({runs} runs): {total_time/runs:.6f}s")
+
+ def test_performance_token_addition(self):
+ """Test the performance of adding tokens - 100x more tokens"""
+ # Create an instance for testing
+ lookup = LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ def add_tokens():
+ # Add 100,000 tokens instead of 1,000
+ for i in range(100000):
+ state = self.many_states[i % 10000]
+ token_id = i % self.large_vocab_size
+ lookup.add_exact_lookup(state, token_id)
+
+ _, time_taken = self.time_function("Adding 100,000 tokens", add_tokens)
+
+ def test_performance_conversion(self):
+ """Test the performance of converting lists to masks - with much more data"""
+ # Create an instance for testing
+ lookup = LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ # Add many more tokens to the lookup
+ logger.info("Preparing for conversion test - adding tokens...")
+ token_count = 100000 # 100x more tokens
+ for i in range(token_count):
+ state = self.many_states[i % 10000]
+ token_id = i % self.large_vocab_size
+ lookup.add_exact_lookup(state, token_id)
+
+ # Also add to fsm_state_and_next_terminal_to_tokens occasionally
+ if i % 5 == 0:
+ next_terminal = f"terminal{i % 1000}" # More varied terminals
+ if (state, next_terminal) not in lookup._fsm_state_and_next_terminal_to_tokens:
+ lookup._fsm_state_and_next_terminal_to_tokens[(state, next_terminal)] = []
+ lookup.fsm_state_and_next_terminal_to_tokens_add(state, next_terminal, token_id)
+
+ logger.info(f"Added {token_count} tokens, now converting to masks...")
+
+ # Time the conversion
+ _, time_taken = self.time_function("Converting large list to masks", lookup.convert_lookups_from_list_to_mask)
+
+ def test_performance_lookup(self):
+ """Test the performance of token lookups - with many more lookups"""
+ # Create and populate an instance for testing
+ lookup = LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ # Add tokens to the lookup - more tokens for a more realistic test
+ logger.info("Preparing for lookup test - adding tokens...")
+ for i in range(10000): # 10x more tokens
+ state = self.many_states[i % 1000] # Use 1000 different states
+ token_id = i % self.large_vocab_size
+ lookup.add_exact_lookup(state, token_id)
+
+ # Convert to masks
+ lookup.convert_lookups_from_list_to_mask()
+ logger.info("Converted to masks, now performing lookups...")
+
+ def lookup_tokens():
+ results = []
+ # Perform many more lookups - 10,000 instead of 100
+ num_lookups = 10000
+ for i in range(num_lookups):
+ state = self.many_states[i % 1000]
+ try:
+ results.append(lookup.complete_case_lookup(state))
+ except KeyError:
+ pass # Some states might not have tokens
+ return results
+
+ _, time_taken = self.time_function(f"Looking up tokens for 10,000 states", lookup_tokens)
+
+ def test_performance_indentation(self):
+ """Test the performance of indentation token lookup - with more constraints"""
+ # Create an instance for testing
+ lookup = LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ # Convert lookups
+ lookup.convert_lookups_from_list_to_mask()
+
+ def get_indentation_tokens():
+ # Create more varied constraints
+ constraints = []
+
+ # Add greater_than_indent_val constraints
+ for i in range(20): # 4x more constraints
+ constraints.append(IndentationConstraint(greater_than_indent_val=i))
+
+ # Add accept_indents constraints with various sizes
+ for i in range(20):
+ # Create accept_indents with different sizes
+ constraints.append(IndentationConstraint(accept_indents=list(range(i, i+3))))
+ constraints.append(IndentationConstraint(accept_indents=list(range(i, i+5))))
+ constraints.append(IndentationConstraint(accept_indents=[i, i+2, i+4, i+6]))
+
+ results = []
+ # Process each constraint multiple times to simulate repeated lookups
+ for _ in range(5): # 5x repetition
+ for constraint in constraints:
+ results.append(lookup.get_indentation_tokens(constraint))
+ return results
+
+ _, time_taken = self.time_function("Indentation token lookup with 100 constraints, repeated 5 times", get_indentation_tokens)
+
+ def test_performance_overall_workflow(self):
+ """Test a complete workflow to measure overall performance - much larger workload"""
+ def workflow():
+ # Initialize
+ lookup = LookupTable(
+ vocab=self.large_vocab,
+ eos_token_id=self.eos_token_id,
+ special_token_ids=self.special_token_ids,
+ indent=True
+ )
+
+ num_tokens = 50000
+ logger.info(f"Adding {num_tokens} tokens to various collections...")
+
+ for i in range(num_tokens):
+ state = self.many_states[i % 10000]
+ token_id = i % self.large_vocab_size
+ lookup.add_exact_lookup(state, token_id)
+
+ # Also add to fsm_state_and_next_terminal_to_tokens
+ if i % 5 == 0:
+ next_terminal = f"terminal{i % 1000}" # More varied terminals
+ if (state, next_terminal) not in lookup._fsm_state_and_next_terminal_to_tokens:
+ lookup._fsm_state_and_next_terminal_to_tokens[(state, next_terminal)] = []
+ lookup.fsm_state_and_next_terminal_to_tokens_add(state, next_terminal, token_id)
+
+ # Also add to overapprox lookup occasionally
+ if i % 7 == 0:
+ mask = torch.zeros(len(self.large_vocab), dtype=torch.bool)
+ mask[token_id] = 1
+ lookup.store_overapprox_lookup(state, mask)
+
+ # Convert lookups
+ logger.info("Converting lookups from lists to masks...")
+ lookup.convert_lookups_from_list_to_mask()
+
+ # Do many more lookups
+ logger.info("Performing lookups...")
+ results = []
+ lookup_count = 1000
+ for i in range(lookup_count):
+ state = self.many_states[i % 10000]
+ try:
+ results.append(lookup.complete_case_lookup(state))
+ except KeyError:
+ pass
+
+ # Fix for the KeyError issue - use try/except for incomplete_case_lookup as well
+ try:
+ results.append(lookup.incomplete_case_lookup(state))
+ except KeyError:
+ # If the state doesn't exist in the _overapprox_lookup, just skip it
+ pass
+
+ # Check indentation with more constraints
+ logger.info("Processing indentation constraints...")
+ for i in range(30):
+ constraint = IndentationConstraint(greater_than_indent_val=i)
+ results.append(lookup.get_indentation_tokens(constraint))
+
+ if i < 25: # Add some accept_indents constraints too
+ constraint = IndentationConstraint(accept_indents=list(range(i, i+5)))
+ results.append(lookup.get_indentation_tokens(constraint))
+
+ logger.info("Workflow complete.")
+ return results
+
+ _, time_taken = self.time_function("Complete large-scale workflow", workflow)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/mask_store/test_mask_store_go.py b/tests/mask_store/test_mask_store_go.py
new file mode 100644
index 00000000..a6364804
--- /dev/null
+++ b/tests/mask_store/test_mask_store_go.py
@@ -0,0 +1,48 @@
+import sys
+import os
+import time
+import unittest
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+import syncode.common as common
+from syncode.parsers.incremental_parser import ParseResult
+from syncode.parse_result import AcceptSequence, RemainderState
+from syncode.mask_store.mask_store import MaskStore
+from syncode.parsers.grammars.grammar import Grammar
+from tests.test_utils import CustomAssertMixin
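+# CustomAssertMixin supplies the assert*WithLimit helpers used below for assertions over long token lists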
+
+
+# Initialize these outside the test class if they're shared across tests
+model = 'Qwen/Qwen2.5-1.5B-Instruct'
+tokenizer = common.load_tokenizer(model)
+mask_store = MaskStore.init_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
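+# Building the mask store with use_cache=False can take a while for a full model vocabulary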
+
+class TestDFAMask(unittest.TestCase, CustomAssertMixin):
+ def test_dfa_mask(self):
+ r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
+ mask_store.get_accept_mask(r, get_list=True)
+ result_list = mask_store.get_accept_mask(r, get_list=True)
+ for token in [' +', ' +=', ' ++']:
+ self.assertInWithLimit(token, result_list, f"{token} not found in result list")
+
+ def test_dfa_mask2(self):
+ r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
+ result_list = mask_store.get_accept_mask(r, get_list=True)
+ self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")
+
+ def test_dfa_mask3(self):
+ r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
+ result_list = mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit(":=", result_list, ":= not found in result list")
+
+ def test_dfa_mask4(self):
+ r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
+ self.assertInWithLimit("\t", mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
+
+ def test_dfa_mask5(self):
+ r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
+ self.assertInWithLimit("\t", mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")
+
+ def test_dfa_mask6(self):
+ r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
+ self.assertInWithLimit(" {", mask_store.get_accept_mask(r, get_list=True), "Opening brace not found in result list")
diff --git a/tests/mask_store/test_mask_store_python.py b/tests/mask_store/test_mask_store_python.py
new file mode 100644
index 00000000..7bfc032a
--- /dev/null
+++ b/tests/mask_store/test_mask_store_python.py
@@ -0,0 +1,260 @@
+import unittest
+import sys, os
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
+import time
+from tests.test_utils import CustomAssertMixin
+import syncode.common as common
+from syncode.parsers.incremental_parser import ParseResult
+from syncode.parse_result import AcceptSequence, IndentationConstraint, RemainderState
+from syncode.mask_store.mask_store import MaskStore
+from syncode.parsers import create_parser
+from syncode.parsers.grammars.grammar import Grammar
+import logging
+logger = logging.getLogger(__name__)
+
+
+class TestDFAMaskLlama(unittest.TestCase, CustomAssertMixin):
+ model = 'meta-llama/Llama-2-7b-hf'
+ tokenizer = common.load_tokenizer(model)
+ mask_store = MaskStore.init_mask_store(
+ grammar=Grammar('python'),
+ tokenizer=tokenizer,
+ use_cache=True,
+ indent=True,
+ mode="grammar_strict"
+ )
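+    # The tokenizer and mask store are class attributes: built once at class definition and shared by all tests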
+
+ @classmethod
+ def setUpClass(cls):
+ """Configure logging before any tests run."""
+ super().setUpClass()
+ common.setup_logging()
+
+ def test_strict(self):
+ query_start_time = time.time()
+ r = ParseResult({AcceptSequence(['DEC_NUMBER', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
+        self.mask_store.get_accept_mask(r) # Time the boolean-mask query
+ time_taken_for_mask_query = time.time() - query_start_time
+
+ query_start_time = time.time()
+ r = ParseResult({AcceptSequence(['DEC_NUMBER', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
+ self.mask_store.get_accept_mask(r, get_list=True)
+ time_taken_for_list_query = time.time() - query_start_time
+ logger.info(f"Time taken for mask query: {time_taken_for_mask_query} and list query: {time_taken_for_list_query}")
+
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit(' +', ac_list)
+ self.assertNotIn(' +=', ac_list) # In strict mode this should not be present
+ self.assertNotIn(' ++', ac_list) # In strict mode this should not be present
+
+
+ def test_dfa_mask2(self):
+ r = ParseResult({AcceptSequence(['NAME'])}, b'\n"""comment"""\n', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertGreaterWithLimit(len(ac_list), 0, ac_list)
+
+ def test_dfa_mask3(self):
+ r = ParseResult({AcceptSequence(['STRING', 'FOR'])}, b'"Return only negative numbers in the list. Note that this is not the same as the negative of the list. ', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertGreaterWithLimit(len(ac_list), 0, ac_list)
+
+ def test_dfa_mask4(self):
+ r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, b'upper', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit('(', ac_list)
+
+ def test_dfa_mask5(self):
+ s = b'\n\t""" Check if in given list of numbers, are any two numbers closer to each other than\n\tgiven threshold.\n\t>>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n\tFalse\n\t>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n\tTrue\n\t"""\n'
+ r = ParseResult({AcceptSequence(['_NL', 'NAME'])}, s, RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertTrueWithLimit(all([t in ac_list for t in ['\t', '\n', '""', '#', "''", "'", '"']]), ['\t', '\n', '""', '#', "''", "'", '"'], ac_list)
+
+ def test_dfa_mask6(self):
+ r = ParseResult({AcceptSequence(['DEC_NUMBER', 'COLON'])}, b'2', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertFalse(any([t in ac_list for t in ['+', '#', '-', '*']]))
+
+ def test_dfa_mask7(self):
+ r = ParseResult({AcceptSequence(['LPAR'])}, b'', RemainderState.COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertEqual(len([ac for ac in ac_list if 'num' in ac]), 0)
+ self.assertGreater(len([ac for ac in ac_list if '(' in ac]), 0)
+
+ def test_dfa_mask8(self):
+ r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, b'print', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertNotIn(' num', ac_list)
+ self.assertInWithLimit('num', ac_list)
+ self.assertInWithLimit('(', ac_list)
+
+ def test_special_token(self):
+ r = ParseResult({AcceptSequence(['_NL'])}, b'', RemainderState.COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit('', ac_list) # special token should always be in the list
+
+ def test_eos_token(self):
+ # EOS token should be in the list if $END is in the accept sequence
+ r = ParseResult({AcceptSequence(['$END'])}, b'', RemainderState.COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+        self.assertInWithLimit('</s>', ac_list)
+
+ # EOS token should be in the list if $END is in the accept sequence
+ r = ParseResult({AcceptSequence(['RPAR', '$END'])}, b')', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+        self.assertInWithLimit('</s>', ac_list)
+
+ # EOS token should not be in the list if $END is not in the accept sequence
+ r = ParseResult({AcceptSequence(['NAME'])}, b'', RemainderState.COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+        self.assertNotIn('</s>', ac_list)
+
+ def test_dfa_mask13(self):
+ r = ParseResult({AcceptSequence(['NAME']), AcceptSequence(['RETURN', 'NAME'])}, b'return', RemainderState.MAYBE_COMPLETE, next_ac_indents=None)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit("ing", ac_list)
+ self.assertInWithLimit(" x", ac_list)
+
+ def test_indent(self):
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[1]), get_list=True)
+ self.assertTrue(all(t in ac_list for t in [' int', ' ']))
+ self.assertFalse(' ' in ac_list)
+
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[2]), get_list=True)
+ self.assertTrue(all(t in ac_list for t in [' ', ' ']))
+
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[4]), get_list=True)
+ self.assertTrue(all(t in ac_list for t in ['\t', ' ', ' ', ' ', ' ', ' ']))
+
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=0), get_list=True)
+ self.assertInWithLimit(' int', ac_list)
+
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=1), get_list=True)
+ self.assertFalse(' int' in ac_list)
+ self.assertTrue(' ' in ac_list)
+
+ ac_list = self.mask_store._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=3), get_list=True)
+ self.assertIn(' ', ac_list)
+
+ def test_dfa_mask_with_indent(self):
+ r = ParseResult({AcceptSequence(['NAME'])}, b'int', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[0]))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit('int', ac_list)
+
+ r = ParseResult({AcceptSequence(['IF'])}, b'', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[1]))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit(' if', ac_list)
+
+ r = ParseResult({AcceptSequence(['NAME'])}, b'int', RemainderState.COMPLETE, IndentationConstraint(greater_than_indent_val=0))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit(' int', ac_list)
+
+ r = ParseResult({AcceptSequence(['NAME'])}, b'', RemainderState.COMPLETE, IndentationConstraint(greater_than_indent_val=1))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertNotIn(' int', ac_list)
+
+ r = ParseResult({AcceptSequence(['IF'])}, b'', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[0]))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit('if', ac_list)
+
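+        # greater_than_indent_val=-1 puts no effective lower bound, so any indentation is acceptable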
+ r = ParseResult({AcceptSequence(['_NL', 'RETURN'])}, b'\n\t\t', RemainderState.MAYBE_COMPLETE, IndentationConstraint(greater_than_indent_val=-1))
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit('return', ac_list)
+
+
+ @unittest.skip("Skipping the correctness comparison test.")
+ def test_indentation(self):
+ from mxeval.data import get_data
+ mbpp = get_data("mbxp", "python")
+ p = create_parser('python')
+ self.assertEqual(p._get_indentation(mbpp['MBPP/1']["prompt"]), 4)
+ self.assertEqual(p._get_indentation(mbpp['MBPP/2']["prompt"]), 2)
+ self.assertEqual(p._get_indentation(mbpp['MBPP/8']["prompt"]), 1)
+
+
+ @unittest.skip("Skipping the correctness comparison test.")
+ def test_simplifications(self):
+ import regex
+ simplifications = Grammar('python').simplifications()
+
+ # COMMENT
+ reg = simplifications['COMMENT']
+ self.assertIsNotNone(regex.match(reg, '# Hello'))
+ self.assertIsNotNone(regex.match(reg, '""" Foo \n Bar """'))
+ self.assertIsNotNone(regex.match(reg, "''' Foo \n Bar '''"))
+
+ # LONG_STRING
+ reg = simplifications['LONG_STRING']
+ self.assertIsNotNone(regex.match(reg, '""" Foo \n Bar """'))
+ self.assertIsNotNone(regex.match(reg, "''' Foo \n Bar '''"))
+ self.assertIsNone(regex.match(reg, '""" Foo \n Bar '))
+ self.assertIsNone(regex.match(reg, "''' Foo \n Bar "))
+
+ # STRING
+ reg = simplifications['STRING']
+ self.assertIsNotNone(regex.match(reg, '"Foo"'))
+ self.assertIsNotNone(regex.match(reg, "'Foo'"))
+ self.assertIsNone(regex.match(reg, '"Foo'))
+ self.assertIsNone(regex.match(reg, "'Foo"))
+
+ # _NL
+ reg = simplifications['_NL']
+ self.assertIsNotNone(regex.match(reg, '\n'))
+ self.assertIsNotNone(regex.match(reg, '\n\n'))
+ self.assertIsNotNone(regex.match(reg, '\n""" Foo \n Bar """'))
+ self.assertIsNotNone(regex.match(reg, '\n# Hello!'))
+
+
+class TestDFAMaskCodegen(unittest.TestCase, CustomAssertMixin):
+ model = 'Salesforce/codegen-350M-multi'
+ tokenizer = common.load_tokenizer(model)
+ mask_store = MaskStore.init_mask_store(
+ grammar=Grammar('python'),
+ tokenizer=tokenizer,
+ use_cache=True,
+ indent=True,
+ mode="grammar_mask"
+ )
+
+ def test_overapprox(self):
+ r = ParseResult({AcceptSequence(['DEC_NUMBER', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
+ self.mask_store.get_accept_mask(r, get_list=True)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit(' +', ac_list)
+ self.assertInWithLimit(' +=', ac_list) # In overapprox mode this should be present
+ self.assertInWithLimit(' ++', ac_list) # In overapprox mode this should be present
+
+ def test_dfa_mask10(self):
+ ac_list = self.mask_store.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, b"'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
+ self.assertInWithLimit(" '.", ac_list)
+
+ def test_dfa_mask11(self):
+ ac_list = self.mask_store.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, b"'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
+ self.assertInWithLimit(" '.", ac_list)
+
+ def test_dfa_mask12(self):
+ r = ParseResult({AcceptSequence(['_NL', 'IF'])}, b'\n\t\t', RemainderState.MAYBE_COMPLETE, next_ac_indents=None)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertInWithLimit("if", ac_list)
+
+ def test_dfa_mask13(self):
+ r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, b'print', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertNotIn(' num', ac_list)
+ self.assertInWithLimit('num', ac_list)
+ self.assertInWithLimit('()', ac_list)
+
+ def test_dfa_mask14(self):
+ r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, b'upper', RemainderState.MAYBE_COMPLETE)
+ ac_list = self.mask_store.get_accept_mask(r, get_list=True)
+ self.assertTrueWithLimit(all([t in ac_list for t in ['()', '(']]), ['()', '('], ac_list)
+
+if __name__ == '__main__':
+ run_codegen, run_llama = True, True
+
+ if run_llama:
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestDFAMaskLlama)
+ unittest.TextTestRunner().run(suite)
+
+ if run_codegen:
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestDFAMaskCodegen)
+ unittest.TextTestRunner().run(suite)
diff --git a/tests/test_dfa_mask_go.py b/tests/test_dfa_mask_go.py
deleted file mode 100644
index 0c4491f4..00000000
--- a/tests/test_dfa_mask_go.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import sys
-import os
-import time
-import unittest
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
-import syncode.common as common
-from syncode.parsers.incremental_parser import ParseResult
-from syncode.parse_result import AcceptSequence, RemainderState
-from syncode.dfa_mask_store import DFAMaskStore
-from syncode.parsers.grammars.grammar import Grammar
-
-# Initialize these outside the test class if they're shared across tests
-model = 'deepseek-ai/deepseek-coder-6.7b-instruct'
-# model = 'Llama-7b'
-tokenizer = common.load_tokenizer(model)
-dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())
-
-class TestDFAMask(unittest.TestCase):
- def test_dfa_mask(self):
- query_start_time = time.time()
- r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
- dfa_mask.get_accept_mask(r) # This is just to run the function, assuming you're checking time
- # self.assertLess(time.time() - query_start_time, 0.02, "Mask query took too long")
-
- query_start_time = time.time()
- r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
- dfa_mask.get_accept_mask(r, get_list=True)
- # self.assertLess(time.time() - query_start_time, 10**-4, "List query took too long")
- result_list = dfa_mask.get_accept_mask(r, get_list=True)
- for token in [' +', ' +=', ' ++']:
- self.assertIn(token, result_list, f"{token} not found in result list")
-
- def test_dfa_mask2(self):
- r = ParseResult({AcceptSequence(['EOS'])}, '\n // 1.', RemainderState.MAYBE_COMPLETE)
- result_list = dfa_mask.get_accept_mask(r, get_list=True)
- self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")
-
- def test_dfa_mask3(self):
- r = ParseResult({AcceptSequence(['__ANON_14'])}, '', RemainderState.COMPLETE)
- result_list = dfa_mask.get_accept_mask(r, get_list=True)
- # Uncomment the following line if you want to assert presence of specific tokens
- # self.assertIn(":=", result_list, ":= not found in result list")
-
- def test_dfa_mask4(self):
- r = ParseResult({AcceptSequence(['__IGNORE_0'])}, '', RemainderState.COMPLETE)
- self.assertIn("\t", dfa_mask.get_accept_mask(r, get_list=True), "Tab character not found in result list")
-
- def test_dfa_mask5(self):
- r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, '{', RemainderState.MAYBE_COMPLETE)
- self.assertIn("\t", dfa_mask.get_accept_mask(r, get_list=True), "Tab character not found in result list")
-
- def test_dfa_mask6(self):
- r = ParseResult({AcceptSequence(['NAME'])}, 'for', RemainderState.MAYBE_COMPLETE)
- self.assertIn(" {", dfa_mask.get_accept_mask(r, get_list=True), "Opening brace not found in result list")
-
diff --git a/tests/test_dfa_mask_python.py b/tests/test_dfa_mask_python.py
deleted file mode 100644
index b59111df..00000000
--- a/tests/test_dfa_mask_python.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import unittest
-import sys, os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
-import time
-import syncode.common as common
-from syncode.parsers.incremental_parser import ParseResult
-from syncode.parse_result import AcceptSequence, IndentationConstraint, RemainderState
-from syncode.dfa_mask_store import DFAMaskStore
-from syncode.parsers import create_parser
-from syncode.parsers.grammars.grammar import Grammar
-
-class TestDFAMaskLlama(unittest.TestCase):
-
- model = 'Llama-7b'
- tokenizer = common.load_tokenizer(model)
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('python'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())
-
- def test_dfa_mask(self):
- query_start_time = time.time()
- r = ParseResult({AcceptSequence(['DEC_NUMBER', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
- self.dfa_mask.get_accept_mask(r) # Assuming dfa_mask is accessible
- time_taken_for_mask_query = time.time() - query_start_time
-
- query_start_time = time.time()
- r = ParseResult({AcceptSequence(['DEC_NUMBER', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
- self.dfa_mask.get_accept_mask(r, get_list=True)
- time_taken_for_list_query = time.time() - query_start_time
-
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertTrue(all(t in ac_list for t in [' +', ' +=', ' ++']))
-
- def test_dfa_mask2(self):
- r = ParseResult({AcceptSequence(['NAME'])}, '\n"""comment"""\n', RemainderState.MAYBE_COMPLETE)
- self.assertGreater(len(self.dfa_mask.get_accept_mask(r, get_list=True)), 0)
-
- def test_dfa_mask3(self):
- r = ParseResult({AcceptSequence(['STRING', 'FOR'])}, '"Return only negative numbers in the list. Note that this is not the same as the negative of the list. ', RemainderState.MAYBE_COMPLETE)
- self.assertGreater(len(self.dfa_mask.get_accept_mask(r, get_list=True)), 0)
-
- def test_dfa_mask4(self):
- r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, 'upper', RemainderState.MAYBE_COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertTrue(all([t in ac_list for t in ['()', '(']]))
-
- def test_dfa_mask5(self):
- s = '\n\t""" Check if in given list of numbers, are any two numbers closer to each other than\n\tgiven threshold.\n\t>>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n\tFalse\n\t>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n\tTrue\n\t"""\n'
- r = ParseResult({AcceptSequence(['_NL', 'NAME'])}, s, RemainderState.MAYBE_COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertTrue(all([t in ac_list for t in ['\t', '\n', '""', '#', "''", "'", '"']]))
-
- def test_dfa_mask6(self):
- r = ParseResult({AcceptSequence(['DEC_NUMBER', 'COLON'])}, '2', RemainderState.MAYBE_COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertFalse(any([t in ac_list for t in ['+', '#', '-', '*']]))
-
- def test_dfa_mask7(self):
- r = ParseResult({AcceptSequence(['LPAR'])}, '', RemainderState.COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertEqual(len([ac for ac in ac_list if 'num' in ac]), 0)
- self.assertGreater(len([ac for ac in ac_list if '(' in ac]), 0)
-
- def test_dfa_mask8(self):
- r = ParseResult({AcceptSequence(['NAME', 'LPAR'])}, 'print', RemainderState.MAYBE_COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertNotIn(' num', ac_list)
- self.assertIn('num', ac_list)
- self.assertIn('()', ac_list)
-
- def test_dfa_mask9(self):
- r = ParseResult({AcceptSequence(['_NL'])}, '', RemainderState.COMPLETE)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn('', ac_list) # special token should always be in the list
-
- def test_dfa_mask13(self):
- r = ParseResult({AcceptSequence(['NAME']), AcceptSequence(['RETURN', 'NAME'])}, 'return', RemainderState.MAYBE_COMPLETE, next_ac_indents=None)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn("ing", ac_list)
- self.assertIn(" x", ac_list)
-
- def test_indent(self):
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[1]), get_list=True)
- self.assertTrue(all(t in ac_list for t in [' int', ' ']))
- self.assertFalse(' ' in ac_list)
-
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[2]), get_list=True)
- self.assertTrue(all(t in ac_list for t in [' ', ' ']))
-
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(accept_indents=[4]), get_list=True)
- self.assertTrue(all(t in ac_list for t in ['\t', ' ', ' ', ' ', ' ', ' ']))
-
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=0), get_list=True)
- self.assertIn(' int', ac_list)
-
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=1), get_list=True)
- self.assertFalse(' int' in ac_list)
- self.assertTrue(' ' in ac_list)
-
- ac_list = self.dfa_mask._lookup_table.get_indentation_tokens(IndentationConstraint(greater_than_indent_val=3), get_list=True)
- self.assertIn(' ', ac_list)
-
- def test_dfa_mask_with_indent(self):
- r = ParseResult({AcceptSequence(['NAME'])}, 'int', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[0]))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn('int', ac_list)
-
- r = ParseResult({AcceptSequence(['IF'])}, '', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[1]))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn(' if', ac_list)
-
- r = ParseResult({AcceptSequence(['NAME'])}, 'int', RemainderState.COMPLETE, IndentationConstraint(greater_than_indent_val=0))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn(' int', ac_list)
-
- r = ParseResult({AcceptSequence(['NAME'])}, '', RemainderState.COMPLETE, IndentationConstraint(greater_than_indent_val=1))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertNotIn(' int', ac_list)
-
- r = ParseResult({AcceptSequence(['IF'])}, '', RemainderState.COMPLETE, IndentationConstraint(accept_indents=[0]))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn('if', ac_list)
-
- r = ParseResult({AcceptSequence(['_NL', 'RETURN'])}, '\n\t\t', RemainderState.MAYBE_COMPLETE, IndentationConstraint(greater_than_indent_val=-1))
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn('return', ac_list)
-
-
- @unittest.skip("Skipping the correctness comparison test.")
- def test_indentation(self):
- from mxeval.data import get_data
- mbpp = get_data("mbxp", "python")
- p = create_parser('python')
- self.assertEqual(p._get_indentation(mbpp['MBPP/1']["prompt"]), 4)
- self.assertEqual(p._get_indentation(mbpp['MBPP/2']["prompt"]), 2)
- self.assertEqual(p._get_indentation(mbpp['MBPP/8']["prompt"]), 1)
-
-
- @unittest.skip("Skipping the correctness comparison test.")
- def test_simplifications(self):
- import regex
- simplifications = Grammar('python').simplifications()
-
- # COMMENT
- reg = simplifications['COMMENT']
- self.assertIsNotNone(regex.match(reg, '# Hello'))
- self.assertIsNotNone(regex.match(reg, '""" Foo \n Bar """'))
- self.assertIsNotNone(regex.match(reg, "''' Foo \n Bar '''"))
-
- # LONG_STRING
- reg = simplifications['LONG_STRING']
- self.assertIsNotNone(regex.match(reg, '""" Foo \n Bar """'))
- self.assertIsNotNone(regex.match(reg, "''' Foo \n Bar '''"))
- self.assertIsNone(regex.match(reg, '""" Foo \n Bar '))
- self.assertIsNone(regex.match(reg, "''' Foo \n Bar "))
-
- # STRING
- reg = simplifications['STRING']
- self.assertIsNotNone(regex.match(reg, '"Foo"'))
- self.assertIsNotNone(regex.match(reg, "'Foo'"))
- self.assertIsNone(regex.match(reg, '"Foo'))
- self.assertIsNone(regex.match(reg, "'Foo"))
-
- # _NL
- reg = simplifications['_NL']
- self.assertIsNotNone(regex.match(reg, '\n'))
- self.assertIsNotNone(regex.match(reg, '\n\n'))
- self.assertIsNotNone(regex.match(reg, '\n""" Foo \n Bar """'))
- self.assertIsNotNone(regex.match(reg, '\n# Hello!'))
-
-
-class TestDFAMaskCodegen(unittest.TestCase):
-
- model = 'Salesforce/codegen-350M-multi'
- tokenizer = common.load_tokenizer(model)
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('python'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())
-
- def test_dfa_mask10(self):
- ac_list = self.dfa_mask.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, "'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
- self.assertIn(" '.", ac_list)
-
- def test_dfa_mask11(self):
- ac_list = self.dfa_mask.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, "'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
- self.assertIn(" '.", ac_list)
-
- def test_dfa_mask12(self):
- r = ParseResult({AcceptSequence(['_NL', 'IF'])}, '\n\t\t', RemainderState.MAYBE_COMPLETE, next_ac_indents=None)
- ac_list = self.dfa_mask.get_accept_mask(r, get_list=True)
- self.assertIn("if", ac_list)
-
-
-
-
-class TestDFAMaskWizard(unittest.TestCase):
-
- model = 'WizardLM/WizardCoder-1B-V1.0'
- tokenizer = common.load_tokenizer(model)
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('python'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())
-
- def test_dfa_mask13(self):
- ac_list = self.dfa_mask.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, "'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
- self.assertIn(" '.", ac_list)
-
- def test_dfa_mask14(self):
- ac_list = self.dfa_mask.get_accept_mask(ParseResult({AcceptSequence(['STRING'])}, "'", RemainderState.INCOMPLETE, next_ac_indents=None), get_list=True)
- self.assertIn(" '.", ac_list)
-
-
-if __name__ == '__main__':
-
- run_codegen, run_llama, run_wizard = True, True, True
-
- if run_llama:
- suite = unittest.TestLoader().loadTestsFromTestCase(TestDFAMaskLlama)
- unittest.TextTestRunner().run(suite)
-
- if run_codegen:
- suite = unittest.TestLoader().loadTestsFromTestCase(TestDFAMaskCodegen)
- unittest.TextTestRunner().run(suite)
-
- if run_wizard:
- suite = unittest.TestLoader().loadTestsFromTestCase(TestDFAMaskWizard)
- unittest.TextTestRunner().run(suite)
diff --git a/tests/test_misc.py b/tests/test_misc.py
index 3ed1e021..54ce28db 100644
--- a/tests/test_misc.py
+++ b/tests/test_misc.py
@@ -4,7 +4,7 @@
import torch
-from syncode.dfa_mask_store import DFAMaskStore
+from syncode.mask_store.mask_store import MaskStore
from syncode.grammar_decoder import SyncodeLogitsProcessor
# Adjusting the path so the modules can be imported correctly
@@ -32,8 +32,9 @@ def test_mask_store_misc(self):
tokenizer = common.load_tokenizer(model)
inc_parser = create_parser(grammar)
r = inc_parser.get_acceptable_next_terminals("234 * 327 = 76518")
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
- mask = dfa_mask.get_accept_mask(r, get_list=True)
+ r.remainder = r.remainder.encode('utf-8')
+ mask_store = MaskStore.init_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False)
+ mask = mask_store.get_accept_mask(r, get_list=True)
self.assertNotIn(' (', mask)
@staticmethod
@@ -59,8 +60,9 @@ def test_mask_store_misc2(self):
tokenizer = common.load_tokenizer(model)
inc_parser = create_parser(grammar)
r = inc_parser.get_acceptable_next_terminals("I")
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
- mask = dfa_mask.get_accept_mask(r, get_list=True)
+ r.remainder = r.remainder.encode('utf-8')
+ mask_store = MaskStore.init_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False)
+ mask = mask_store.get_accept_mask(r, get_list=True)
self.assertIn(' have', mask)
def test_mask_store_misc3(self):
@@ -69,8 +71,9 @@ def test_mask_store_misc3(self):
tokenizer = common.load_tokenizer(model)
inc_parser = create_parser(grammar)
r = inc_parser.get_acceptable_next_terminals("I have been working there for 5 years.")
- dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
- mask = dfa_mask.get_accept_mask(r, get_list=True)
+ r.remainder = r.remainder.encode('utf-8')
+ mask_store = MaskStore.init_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False)
+ mask = mask_store.get_accept_mask(r, get_list=True)
self.assertIn(' I', mask)
def test_grammar_decoder_empty(self):
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..23d58dab
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,28 @@
+# Custom assertion methods to handle long lists
+class CustomAssertMixin:
+ def assertInWithLimit(self, member, container, msg=None):
+ """Assert that member is in container, with limited output if container is too long."""
+ if member not in container:
+ if len(container) <= 50:
+ self.assertIn(member, container, msg)
+ else:
+ sample = list(container)[:10]
+ self.fail(f"{repr(member)} not found in list with {len(container)} items. First 10 items: {sample}")
+
+ def assertTrueWithLimit(self, expr, items, container, msg=None):
+ """Assert that expression is true, with limited output for long containers."""
+ if not expr:
+ if len(container) <= 50:
+ self.assertTrue(expr, msg)
+ else:
+ sample = list(container)[:10]
+ self.fail(f"Assertion failed for items {items}. List has {len(container)} items. First 10 items: {sample}")
+
+ def assertGreaterWithLimit(self, a, b, container, msg=None):
+ """Assert a > b with limited output for the container."""
+ if not (a > b):
+ if len(container) <= 50:
+ self.assertGreater(a, b, msg)
+ else:
+ sample = list(container)[:10]
+ self.fail(f"{a} not greater than {b}. List has {len(container)} items. First 10 items: {sample}")
\ No newline at end of file
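
For reference, a minimal, hypothetical usage sketch of the CustomAssertMixin introduced in tests/test_utils.py above (not part of the patch itself): the test class, the token list, and the import path are illustrative assumptions, not code from the repository.

# Hypothetical usage sketch for CustomAssertMixin (not part of the patch);
# the class name, token list, and import path below are assumptions.
import unittest

from tests.test_utils import CustomAssertMixin  # assumes tests/ is importable as a package


class ExampleMaskTest(CustomAssertMixin, unittest.TestCase):
    def test_large_accept_list(self):
        # Stand-in for a long accept list, e.g. one returned by a mask store.
        accept_list = ['token_%d' % i for i in range(10_000)] + [' if']

        # These pass here; on failure they would report only the list length
        # and the first 10 items instead of dumping all 10,000 entries.
        self.assertInWithLimit(' if', accept_list)
        self.assertGreaterWithLimit(len(accept_list), 0, accept_list)


if __name__ == '__main__':
    unittest.main()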