Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "syncode"
version = "0.4.0"
description = "Grammar-guided code generation tool"
readme = "README.md"
authors = [
{name = "Shubham Ugare", email = "shubhamugare@gmail.com"}
]
license = {text = "MIT"}
classifiers = [
"Programming Language :: Python :: 3",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"fire",
"interegular",
"regex==2023.8.8",
"torch",
"tqdm",
"transformers==4.44.0",
"datasets",
"jsonschema",
]

[project.urls]
"Homepage" = "https://github.com/shubhamugare/syncode"
"Bug Tracker" = "https://github.com/shubhamugare/syncode/issues"
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ regex==2023.8.8
torch
tqdm
transformers==4.44.0
mxeval @ git+https://github.com/shubhamugare/mxeval.git
datasets
jsonschema
20 changes: 14 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

# Read the content of the requirements.txt file
with open('requirements.txt', 'r', encoding='utf-8') as f:
requirements = f.read().splitlines()
# Read the content of the requirements.txt file without mxeval
requirements = [
"fire",
"interegular",
"regex==2023.8.8",
"torch",
"tqdm",
"transformers==4.44.0",
"datasets",
"jsonschema"
]

setuptools.setup(
name="syncode",
version="0.1",
version="0.4.0",
author="Shubham Ugare",
author_email="shubhamugare@gmail.com",
description="This package provides the tool for grammar augmented LLM generation.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/shubhamugare/syncode",
url="https://github.com/uiuc-focal-lab/syncode",
include_package_data=True,
packages=setuptools.find_packages(),
install_requires=requirements,
classifiers=[
"Programming Language :: Python :: 3",
"Intended Audience :: Science/Research",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
Expand Down
2 changes: 1 addition & 1 deletion syncode/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datasets import load_dataset
from mxeval.data import get_data, get_examples
from syncode.evaluation.mxeval.data import get_data, get_examples

class Dataset:
"""
Expand Down
2 changes: 1 addition & 1 deletion syncode/evaluation/code_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional
from syncode import common
from syncode.evaluation.mxeval_evaluation import check_corectness
from mxeval.data import write_jsonl
from syncode.evaluation.mxeval.data import write_jsonl


class CodeEval:
Expand Down
2 changes: 1 addition & 1 deletion syncode/evaluation/fol_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
import re
from typing import Optional
from mxeval.data import write_jsonl
from syncode.evaluation.mxeval_evaluation import write_jsonl
from tqdm import tqdm
import signal
from syncode.parsers import create_base_parser
Expand Down
8 changes: 6 additions & 2 deletions syncode/evaluation/json_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tqdm import tqdm
from typing import Optional
from mxeval.data import write_jsonl
from syncode.evaluation.mxeval.data import write_jsonl
import ast
import json
from jsonschema import validate, ValidationError
Expand All @@ -18,7 +18,8 @@ def run_json_eval(
out_path: Optional[str],
debug_task_id: Optional[int] = None,
logger=common.EmptyLogger(),
prompt_type='original'
prompt_type='original',
num_tasks=None
):
problems = syncode.dataset.problems
if syncode.grammar_decoder is not None:
Expand All @@ -27,6 +28,9 @@ def run_json_eval(

if debug_task_id is not None:
problems = [problems[debug_task_id]]

if num_tasks is not None:
problems = problems[:num_tasks]

samples = []
outputs = []
Expand Down
2 changes: 1 addition & 1 deletion syncode/evaluation/math_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tqdm import tqdm
from syncode import common
from syncode.evaluation.mxeval_evaluation import compute_pass_at_k
from mxeval.data import write_jsonl
from syncode.evaluation.mxeval_evaluation import write_jsonl


class MathEval:
Expand Down
Empty file.
109 changes: 109 additions & 0 deletions syncode/evaluation/mxeval/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Iterable, Dict
import gzip
import json
import os


# Directory containing this module; the bundled evaluation datasets are
# expected to live under "<ROOT>/data/...".
ROOT = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): this name is reused — it first holds the metadata *path*,
# then is rebound to the parsed metadata *dict* loaded from that path.
MULTILINGUAL_HUMANEVAL_METADATA = os.path.join(ROOT, "data", "multilingual_humaneval", "metadata.json")
# Loading happens at import time, so importing this module fails if the
# data files are not packaged alongside it.
with open(MULTILINGUAL_HUMANEVAL_METADATA, "r", encoding="utf-8") as fr:
    MULTILINGUAL_HUMANEVAL_METADATA = json.load(fr)
# Path to the Python split of multilingual HumanEval; HUMAN_EVAL is kept
# as an alias (used as the default evaluation set elsewhere).
HUMAN_EVAL_PYTHON = os.path.join(ROOT, "data", "multilingual_humaneval", MULTILINGUAL_HUMANEVAL_METADATA["python"])
HUMAN_EVAL = HUMAN_EVAL_PYTHON


def read_problems(evalset_file: str = HUMAN_EVAL_PYTHON) -> Dict[str, Dict]:
    """Load a jsonl problem file and return its tasks indexed by ``task_id``."""
    problems = {}
    for task in stream_jsonl(evalset_file):
        problems[task["task_id"]] = task
    return problems


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """Yield one parsed JSON object per non-blank line of *filename*.

    Files ending in ``.gz`` are transparently decompressed; all other
    files are read as plain text.
    """
    def _parsed(handle):
        # Lines that are entirely whitespace are skipped rather than parsed.
        for raw_line in handle:
            if raw_line.strip():
                yield json.loads(raw_line)

    if filename.endswith(".gz"):
        with open(filename, "rb") as raw_fp:
            with gzip.open(raw_fp, 'rt') as text_fp:
                yield from _parsed(text_fp)
    else:
        with open(filename, "r") as text_fp:
            yield from _parsed(text_fp)


def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """Serialize *data* to *filename* as jsonl, one JSON object per line.

    Overwrites by default; pass ``append=True`` to add to an existing
    file. A ``.gz`` suffix produces gzip-compressed output.
    """
    mode = 'ab' if append else 'wb'
    target = os.path.expanduser(filename)
    if target.endswith(".gz"):
        with open(target, mode) as fp:
            # Appending whole gzip members to an existing archive is valid:
            # readers concatenate the decompressed streams.
            with gzip.GzipFile(fileobj=fp, mode='wb') as gzfp:
                for record in data:
                    gzfp.write((json.dumps(record) + "\n").encode('utf-8'))
    else:
        with open(target, mode) as fp:
            for record in data:
                fp.write((json.dumps(record) + "\n").encode('utf-8'))


def get_metadata(dataset, metadata_type="problem"):
    """Return ``(metadata, datadir)`` for a supported dataset.

    *metadata_type* chooses between problem metadata ("problem") and
    few-shot example metadata ("example"); *datadir* is the on-disk
    directory the metadata file was read from.
    """
    assert metadata_type in ["problem", "example"]
    assert dataset in ["mbxp", "multi-humaneval", "mathqa-x"], f"Unsupported dataset {dataset}"
    # Public dataset name -> directory under ROOT/data.
    directory_name = {
        "mbxp": "mbxp",
        "multi-humaneval": "multilingual_humaneval",
        "mathqa-x": "multilingual_mathqa",
    }[dataset]
    metadata_filename = {
        "problem": "metadata.json",
        "example": "metadata_examples.json",
    }[metadata_type]
    datadir = os.path.join(ROOT, "data", directory_name)
    with open(os.path.join(datadir, metadata_filename), "r") as f:
        metadata = json.load(f)
    return metadata, datadir


def get_supported_langs(dataset):
    """List the language keys available in *dataset*'s problem metadata."""
    problem_metadata = get_metadata(dataset, metadata_type="problem")[0]
    return list(problem_metadata)


def get_data(dataset="mbxp", language="python"):
metadata, datadir = get_metadata(dataset, metadata_type="problem")
if language.lower() not in metadata:
raise ValueError(f"Language {language} not found in metadata file")
datafile = metadata[language.lower()]
print(f"Loading {dataset} | language = {language}")
return read_problems(os.path.join(datadir, datafile))


# due to similar format, examples from mbxp are sufficient to be used
# for few-shot prompting in multi-humaneval
def get_examples(dataset="mbxp", language="python", num_examples=None):
    """Fetch few-shot examples for *language* from *dataset* (mbxp only).

    With ``num_examples=None`` the raw example stream is returned lazily.
    Otherwise the first *num_examples* examples are materialized into a
    list, each with its "prompt" field replaced by the matching problem
    prompt. NOTE(review): that prompt substitution happens only in the
    ``num_examples`` branch — confirm the asymmetry is intended.
    """
    assert dataset in ["mbxp"], f"No fewshot examples in dataset {dataset}"
    metadata, datadir = get_metadata(dataset=dataset, metadata_type="example")
    lang_key = language.lower()
    if lang_key not in metadata:
        raise ValueError(f"Language {language} not found in metadata file")
    examples_file = metadata[lang_key]
    print(f"Loading examples from {dataset} | language = {language}")
    if num_examples is None:
        # Hand the jsonl generator back untouched.
        return stream_jsonl(os.path.join(datadir, examples_file))
    problems = get_data(dataset=dataset, language=language)
    # Recursive call with num_examples=None yields the full lazy stream.
    full_stream = get_examples(dataset=dataset, language=language)
    collected = []
    for position, example in enumerate(full_stream):
        if position == num_examples:
            break
        example["prompt"] = problems[example["task_id"]]["prompt"]
        collected.append(example)
    return collected
Loading