Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Cache DFA mask store
- name: Cache mask store
uses: actions/cache@v3
with:
path: /home/runner/work/syncode/syncode/cache/mask_stores/
Expand All @@ -39,3 +39,7 @@ jobs:
python3 -m unittest tests.test_language_model
python3 -m unittest tests.test_lr_parser
python3 -m unittest tests.test_syncode
python3 -m unittest tests.mask_store.test_byte_fsm
python3 -m unittest tests.mask_store.test_fsm_set
python3 -m unittest tests.mask_store.test_byte_tokenizer
python3 -m unittest tests.mask_store.test_lookup_table
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ syncode/core/__pycache__
.vscode/
tmp*
cache/
.ipynb_checkpoints/
.ipynb_checkpoints/
*.prof
3 changes: 1 addition & 2 deletions notebooks/tests/builtin_grammar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -17,7 +17,6 @@
"\n",
"device = 'cuda'\n",
"model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"# model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)"
Expand Down
3 changes: 1 addition & 2 deletions notebooks/tests/lexer_ambiguity.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -36,7 +36,6 @@
"source": [
"from syncode.infer import Syncode\n",
"\n",
"# Load the unconstrained original model\n",
"model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"\n",
"trying = \"\"\" \n",
Expand Down
122 changes: 122 additions & 0 deletions notebooks/tests/non_ascii.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/codex/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 10.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating DFA mask store for PreTrainedTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/PreTrainedTokenizerFast/grammar_strict_4470738745_128000.pkl.\n",
"Ignore whitespace tokens is False\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 16/16 [00:03<00:00, 4.32it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken to create mask store: 4.161165714263916 seconds\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from syncode.infer import Syncode\n",
"\n",
"grammar = r\"\"\"\n",
" start: \"∀∃∀∃∀\" \n",
" \"\"\"\n",
"\n",
"syn_llm = Syncode(model=\"meta-llama/Llama-3.1-8B-Instruct\", grammar=grammar, new_mask_store=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Syncode augmented LLM output:\n",
"∀∃∀∃∀\n",
"\n"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"p = \"You are an expert in writing print something random.\"\n",
" \n",
"output = syn_llm.infer(p)[0]\n",
"print(f\"Syncode augmented LLM output:\\n{output}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "codex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions syncode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from syncode.infer import Syncode
from grammar_decoder import SyncodeLogitsProcessor
from parsers.grammars import Grammar
import common

common.setup_logging()
59 changes: 42 additions & 17 deletions syncode/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

Expand All @@ -8,23 +10,6 @@
SYNCODE_CACHE = os.environ['SYNCODE_CACHE'] if 'SYNCODE_CACHE' in os.environ else 'cache/'
HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None

def get_vocab_from_tokenizer(tokenizer):
    """
    Return the tokenizer's vocabulary as a list of decoded token strings,
    ordered by token ID (so vocab[0] is the string for token ID 0, etc.).
    """
    # vocab is a list of readable token strings (e.g., ' hello' and '\n')
    # sorted by their token IDs.
    vocab = [v for k, v in
        sorted([(t_id, tokenizer.decode([t_id]))
        for _, t_id in tokenizer.get_vocab().items()])]

    # HACK: Is there a better way to know if a token has a prefix space?
    # Heuristic: decode() for Llama-family tokenizers appears to strip a
    # leading space from a lone token; if decoding the token twice ([i, i])
    # is not exactly twice the length of the single decode, assume a prefix
    # space was stripped and re-add it. TODO(review): confirm this holds
    # for all Llama tokenizer versions.
    if 'Llama' in tokenizer.__class__.__name__:
        for i in range(len(vocab)):
            t = vocab[i]
            if 2*len(t) != len(tokenizer.decode([i, i], add_special_tokens=False)):
                vocab[i] = ' ' + t
            # An empty decode is treated as a bare space token.
            if t == '':
                vocab[i] = ' '

    return vocab

def load_model(model_name, device, quantize):
if model_name == 'test':
Expand Down Expand Up @@ -53,6 +38,46 @@ def get_output_path(model_name, grammar, dataset, num_samples, mode):
os.makedirs(out_dir, exist_ok=True)
return out_dir,out_path

# This is the setup for Python logging
def setup_logging(level=None):
    """
    Configure the root logger to emit formatted records on stdout.

    Safe to call repeatedly: any handlers already attached to the root
    logger are removed first, so repeated calls never produce duplicate
    output.

    Args:
        level: Explicit logging level. When None, the LOG_LEVEL
            environment variable is consulted (case-insensitive),
            falling back to INFO for missing or unrecognized names.

    Returns:
        The configured root logger.
    """
    if level is None:
        # Resolve the level name from the environment; unknown names
        # silently fall back to INFO.
        level_name = os.environ.get('LOG_LEVEL', 'INFO')
        level = getattr(logging, level_name.upper(), logging.INFO)

    root = logging.getLogger()

    # Detach any pre-existing handlers so reconfiguration is idempotent.
    for existing in list(root.handlers):
        root.removeHandler(existing)

    root.setLevel(level)

    # Single stdout handler with the project's record format.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(
        logging.Formatter('[%(asctime)s-%(name)s] - %(message)s')
    )
    root.addHandler(stream_handler)

    return root


class Logger:
"""
Logger class for logging the output of the model
Expand Down
Loading