diff --git a/README.md b/README.md
index f37b7ce..9cbf02a 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,8 @@ The package doesn't have the dataset, it is stored on our [HuggingFace page](htt
 
 ## Latest News 📣
 
+* [2026/04] Added support for custom async inference functions, see [`inference_function()`](./llmsql/inference/inference_function.py#inference_function).
+
 * [2026/03] Fully functional CLI commands for inference and evaluation. See this [guide](./llmsql/_cli/README.md).
 
 * [2026/03] Added support for API inference, for now only for OpenAI-compatable APIs, see [`inference_api()` function](./llmsql/inference/inference_api.py#inference_api)
@@ -58,7 +60,11 @@ Modern LLMs are already strong at producing SQL queries without finetuning.
 We therefore recommend that most users:
 
 1. **Run inference** directly on the full benchmark:
-   - Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works both with HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct` and model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`. The api inference is also supported, see [`inference_api()`](./llmsql/inference/inference_api.py#inference_api)
+   - Use one of the supported inference frameworks:
+     * [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) - generates SQL predictions with your model using the Hugging Face `transformers` backend.
+     * [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py) - runs the same generation on the vLLM backend. Both functions accept either an HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct`, or a model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`.
+     * [`inference_api()`](./llmsql/inference/inference_api.py#inference_api) - runs inference against an OpenAI-compatible API.
+     * [`inference_function()`](./llmsql/inference/inference_function.py#inference_function) - runs inference through a custom async function that you provide.
    - Evaluate results against the benchmark with the [`llmsql.evaluate`](./llmsql/evaluation/evaluator.py) function.
 
 2. **Optional finetuning**:
diff --git a/docs/docs/inference.rst b/docs/docs/inference.rst
index 1a33533..524573f 100644
--- a/docs/docs/inference.rst
+++ b/docs/docs/inference.rst
@@ -20,6 +20,14 @@ Inference API Reference
 
 ---
 
+
+.. automodule:: llmsql.inference.inference_function
+   :members:
+   :undoc-members:
+
+---
+
+
 .. raw:: html
diff --git a/llmsql/__init__.py b/llmsql/__init__.py
index 800a066..a876d29 100644
--- a/llmsql/__init__.py
+++ b/llmsql/__init__.py
@@ -30,8 +30,18 @@ def __getattr__(name: str):  # type: ignore
         from .inference.inference_api import inference_api
 
         return inference_api
+    elif name == "inference_function":
+        from .inference.inference_function import inference_function
+
+        return inference_function
     raise AttributeError(f"module {__name__} has no attribute {name!r}")
 
 
-__all__ = ["evaluate", "inference_vllm", "inference_transformers", "inference_api"]
+__all__ = [
+    "evaluate",
+    "inference_vllm",
+    "inference_transformers",
+    "inference_api",
+    "inference_function",
+]
 
diff --git a/llmsql/inference/README.md b/llmsql/inference/README.md
index 1aba3f7..276768b 100644
--- a/llmsql/inference/README.md
+++ b/llmsql/inference/README.md
@@ -5,6 +5,7 @@ LLMSQL provides two inference backends for **Text-to-SQL generation** with large
 * **Transformers** — runs inference using the standard Hugging Face `transformers` pipeline.
 * **vLLM** — runs inference using the high-performance [vLLM](https://github.com/vllm-project/vllm) backend.
 * **API** — runs inference against an OpenAI-compatible Chat Completions API with configurable base URL and rate limiting.
+* **Custom Function** — runs inference with your own async callable while preserving LLMSQL prompt building and output format.
 
 Both backends load benchmark questions and table schemas, build prompts (with few-shot examples), and generate SQL queries in parallel batches.
 
@@ -94,7 +95,25 @@ results = inference_api(
 )
 ```
 
+---
+
+
+### Option 4 — Using your own async inference function
+
+```python
+from llmsql import inference_function
+
+async def get_answer(input_prompt, **kwargs):
+    # call your engine/API/router and return a SQL string
+    return "SELECT 1"
+results = inference_function(
+    inference_function=get_answer,
+    requests_per_minute=60,
+    function_kwargs={"temperature": 0.0},
+    output_file="test_output_function.jsonl",
+)
+```
 
 ---
 
 ## Command-Line Interface (CLI)
diff --git a/llmsql/inference/inference_function.py b/llmsql/inference/inference_function.py
new file mode 100644
index 0000000..c12ad26
--- /dev/null
+++ b/llmsql/inference/inference_function.py
@@ -0,0 +1,186 @@
+"""
+LLMSQL Custom Function Inference
+================================
+
+This module provides ``inference_function()`` for text-to-SQL generation using
+an arbitrary user-provided async inference callable.
+""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +import inspect +import time +from typing import Any, Literal + +from dotenv import load_dotenv +import nest_asyncio +from tqdm.asyncio import tqdm + +from llmsql.config.config import DEFAULT_LLMSQL_VERSION, get_repo_id +from llmsql.loggers.logging_config import log +from llmsql.utils.inference_utils import ( + _maybe_download, + _setup_seed, + resolve_workdir_path, +) +from llmsql.utils.utils import ( + build_all_requests, + choose_prompt_builder, + load_jsonl, + overwrite_jsonl, + save_jsonl_lines, +) + +load_dotenv() + + +class _AsyncRateLimiter: + """Token-bucket style async rate limiter with request-start spacing.""" + + def __init__(self, requests_per_minute: float | None) -> None: + if requests_per_minute is not None and requests_per_minute <= 0: + raise ValueError("requests_per_minute must be > 0 when provided.") + self._interval: float | None = ( + 60.0 / requests_per_minute if requests_per_minute is not None else None + ) + self._next_allowed: float = 0.0 + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + if self._interval is None: + return + + async with self._lock: + now = time.monotonic() + wait = self._next_allowed - now + if wait > 0: + await asyncio.sleep(wait) + self._next_allowed = time.monotonic() + self._interval + + +async def _inference_function_async( + *, + inference_callable: Callable[..., Awaitable[str]], + requests_per_minute: float | None, + function_kwargs: dict[str, Any], + questions: list[dict[str, Any]], + tables: dict[str, Any], + prompt_builder: Any, + output_file: str, +) -> list[dict[str, str]]: + limiter = _AsyncRateLimiter(requests_per_minute) + all_results: list[dict[str, str]] = [] + write_lock = asyncio.Lock() + + prompts = build_all_requests(questions, tables, prompt_builder) + + async def process_question(q: dict[str, Any], prompt: str) -> dict[str, str]: + await limiter.acquire() + + completion = await inference_callable( + prompt, + question=q, + table=tables[q["table_id"]], + **function_kwargs, + ) + + result = { + "question_id": q.get("question_id", q.get("id", "")), + "completion": str(completion), + } + + async with write_lock: + save_jsonl_lines(output_file, [result]) + + return result + + tasks = [process_question(q, p) for q, p in zip(questions, prompts, strict=False)] + for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Generating"): + all_results.append(await coro) + + return all_results + + +def inference_function( + *, + inference_function: Callable[..., Awaitable[str]], + requests_per_minute: float | None = None, + function_kwargs: dict[str, Any] | None = None, + version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, + output_file: str = "llm_sql_predictions.jsonl", + workdir_path: str | None = None, + limit: int | float | None = None, + num_fewshots: int = 5, + seed: int = 42, +) -> list[dict[str, str]]: + """Run SQL generation using a user-provided async callable. + + The callable is awaited as: + ``await inference_function(prompt, question=..., table=..., **function_kwargs)``. + + If your callable needs sampling parameters, pass them through ``function_kwargs``. 
+    """
+    _setup_seed(seed=seed)
+
+    if not callable(inference_function):
+        raise TypeError("`inference_function` must be callable.")
+
+    function_kwargs = function_kwargs or {}
+    workdir = resolve_workdir_path(workdir_path)
+
+    repo_id = get_repo_id(version)
+    questions_path = _maybe_download(repo_id, "questions.jsonl", workdir)
+    tables_path = _maybe_download(repo_id, "tables.jsonl", workdir)
+
+    questions = load_jsonl(questions_path)
+    tables_list = load_jsonl(tables_path)
+    tables = {t["table_id"]: t for t in tables_list}
+
+    if limit is not None:
+        if isinstance(limit, float):
+            if not (0.0 < limit <= 1.0):
+                raise ValueError(
+                    f"When a float, `limit` must be between 0.0 and 1.0, got {limit}."
+                )
+            limit = max(1, int(len(questions) * limit))
+        if not isinstance(limit, int) or limit < 1:
+            raise ValueError(
+                f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}."
+            )
+        questions = questions[:limit]
+
+    prompt_builder = choose_prompt_builder(num_fewshots)
+
+    overwrite_jsonl(output_file)
+
+    async def _validated_callable(*args: Any, **kwargs: Any) -> str:
+        out = inference_function(*args, **kwargs)
+        if inspect.isawaitable(out):
+            return str(await out)
+        raise TypeError("`inference_function` must return an awaitable value.")
+
+    coro = _inference_function_async(
+        inference_callable=_validated_callable,
+        requests_per_minute=requests_per_minute,
+        function_kwargs=function_kwargs,
+        questions=questions,
+        tables=tables,
+        prompt_builder=prompt_builder,
+        output_file=output_file,
+    )
+
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+
+    if loop is not None and loop.is_running():
+        nest_asyncio.apply(loop)
+        all_results = loop.run_until_complete(coro)
+    else:
+        all_results = asyncio.run(coro)
+
+    log.info(f"Generation completed. {len(all_results)} results saved to {output_file}")
+    return all_results
diff --git a/llmsql/utils/utils.py b/llmsql/utils/utils.py
index fc9b1ab..0d62e84 100644
--- a/llmsql/utils/utils.py
+++ b/llmsql/utils/utils.py
@@ -2,6 +2,8 @@
 import json
 from pathlib import Path
 
+from transformers import AutoTokenizer
+
 from llmsql.loggers.logging_config import log
 from llmsql.prompts.prompts import (
     build_prompt_0shot,
@@ -65,7 +67,7 @@ def build_all_requests(
     questions: list[dict],
     tables: dict,
     prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str],
-    tokenizer=None,
+    tokenizer: AutoTokenizer | None = None,
     use_chat_template: bool = True,
 ) -> list[str]:
     """
diff --git a/tests/inference/test_inference_function.py b/tests/inference/test_inference_function.py
new file mode 100644
index 0000000..1ccf74e
--- /dev/null
+++ b/tests/inference/test_inference_function.py
@@ -0,0 +1,197 @@
+"""Tests for the async inference_function implementation."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from pathlib import Path
+
+import pytest
+
+from llmsql.inference.inference_function import inference_function
+
+
+def _write_jsonl(path: Path, rows: list[dict]) -> None:
+    path.write_text("\n".join(json.dumps(r) for r in rows) + "\n")
+
+
+def _make_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]:
+    questions = [
+        {"question_id": "q1", "question": "What is 1+1?", "table_id": "t1"},
+        {"question_id": "q2", "question": "What is 2+2?", "table_id": "t1"},
+    ]
+    tables = [{"table_id": "t1", "header": ["col"], "types": ["text"], "rows": [["x"]]}]
+    qpath = tmp_path / "questions.jsonl"
+    tpath = tmp_path / "tables.jsonl"
+    outpath = tmp_path / "out.jsonl"
+    _write_jsonl(qpath, questions)
+    _write_jsonl(tpath, tables)
+    return qpath, tpath, outpath
+
+
+def test_runs_with_custom_async_callable(tmp_path):
+    _make_fixtures(tmp_path)
+
+    async def fake_infer(prompt, **kwargs):
+        assert isinstance(prompt, str)
+        assert kwargs["temperature"] == 0.0
+        assert "question" in kwargs
+        assert "table" in kwargs
+        return "SELECT 1"
+
+    results = inference_function(
+        inference_function=fake_infer,
+        function_kwargs={"temperature": 0.0},
+        output_file=str(tmp_path / "out.jsonl"),
+        workdir_path=str(tmp_path),
+    )
+
+    assert len(results) == 2
+    assert all(r["completion"] == "SELECT 1" for r in results)
+
+
+def test_limit_works(tmp_path):
+    _make_fixtures(tmp_path)
+
+    async def fake_infer(prompt, **kwargs):
+        return "SELECT 1"
+
+    results = inference_function(
+        inference_function=fake_infer,
+        output_file=str(tmp_path / "out.jsonl"),
+        workdir_path=str(tmp_path),
+        limit=1,
+    )
+
+    assert len(results) == 1
+
+
+def test_rejects_non_async_result(tmp_path):
+    _make_fixtures(tmp_path)
+
+    def bad_infer(prompt, **kwargs):
+        return "not-awaitable"
+
+    with pytest.raises(TypeError, match="must return an awaitable"):
+        inference_function(
+            inference_function=bad_infer,  # type: ignore[arg-type]
+            output_file=str(tmp_path / "out.jsonl"),
+            workdir_path=str(tmp_path),
+        )
+
+
+def test_rate_limiter_rejects_non_positive_rpm():
+    from llmsql.inference.inference_function import _AsyncRateLimiter
+
+    with pytest.raises(ValueError, match="requests_per_minute must be > 0"):
+        _AsyncRateLimiter(0)
+
+    with pytest.raises(ValueError):
+        _AsyncRateLimiter(-5)
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_waits(monkeypatch):
+    from llmsql.inference.inference_function import _AsyncRateLimiter
+
+    limiter = _AsyncRateLimiter(requests_per_minute=60)  # 1 request/sec
+
+    sleep_calls = []
+
+    async def fake_sleep(duration):
+        sleep_calls.append(duration)
+
+    monkeypatch.setattr(asyncio, "sleep", fake_sleep)
+
+    # First call should not sleep
+    await limiter.acquire()
+    # Second call should trigger wait
+    await limiter.acquire()
+
+    assert len(sleep_calls) == 1
+    assert sleep_calls[0] > 0
+
+
+def test_rejects_non_callable_inference_function(tmp_path):
+    _make_fixtures(tmp_path)
+
+    with pytest.raises(TypeError, match="must be callable"):
+        inference_function(
+            inference_function="not-a-function",  # type: ignore
+            output_file=str(tmp_path / "out.jsonl"),
+            workdir_path=str(tmp_path),
+        )
+
+
+@pytest.mark.parametrize("bad_limit", [0.0, -0.1, 1.5])
+def test_limit_float_out_of_range(tmp_path, bad_limit):
+    _make_fixtures(tmp_path)
+
+    async def fake_infer(prompt, **kwargs):
+        return "SELECT 1"
+
+    with pytest.raises(ValueError, match="must be between 0.0 and 1.0"):
+        inference_function(
+            inference_function=fake_infer,
+            output_file=str(tmp_path / "out.jsonl"),
+            workdir_path=str(tmp_path),
+            limit=bad_limit,
+        )
+
+
+@pytest.mark.parametrize("bad_limit", [0, -1, "foo", None])
+def test_limit_invalid_type_or_value(tmp_path, bad_limit):
+    _make_fixtures(tmp_path)
+
+    async def fake_infer(prompt, **kwargs):
+        return "SELECT 1"
+
+    if bad_limit is None:
+        # None is valid (means no limit), skip
+        return
+
+    with pytest.raises(ValueError, match="must be a positive integer"):
+        inference_function(
+            inference_function=fake_infer,
+            output_file=str(tmp_path / "out.jsonl"),
+            workdir_path=str(tmp_path),
+            limit=bad_limit,  # type: ignore[arg-type]
+        )
+
+
+def test_runs_inside_existing_event_loop(monkeypatch, tmp_path):
+    _make_fixtures(tmp_path)
+
+    async def fake_infer(prompt, **kwargs):
+        return "SELECT 1"
+
+    applied = {"called": False}
+
+    def fake_apply(loop):
+        applied["called"] = True
+
+    monkeypatch.setattr("nest_asyncio.apply", fake_apply)
+
+    # Create a real loop we control
+    real_loop = asyncio.new_event_loop()
+
+    class FakeLoop:
+        def is_running(self):
+            return True
+
+        def run_until_complete(self, coro):
+            return real_loop.run_until_complete(coro)
+
+    monkeypatch.setattr(asyncio, "get_running_loop", lambda: FakeLoop())
+
+    try:
+        results = inference_function(
+            inference_function=fake_infer,
+            output_file=str(tmp_path / "out.jsonl"),
+            workdir_path=str(tmp_path),
+        )
+    finally:
+        real_loop.close()
+
+    assert applied["called"] is True
+    assert len(results) > 0
diff --git a/tests/test_init.py b/tests/test_init.py
index 4d8a473..34d59ba 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -26,6 +26,12 @@ def test_lazy_import_inference_transformers(self) -> None:
 
         assert inference_transformers is not None
 
+    def test_lazy_import_inference_function(self) -> None:
+        """Test that inference_function can be imported via lazy loading."""
+        from llmsql import inference_function
+
+        assert inference_function is not None
+
     def test_invalid_attribute_raises_error(self) -> None:
         """Test that accessing invalid attribute raises AttributeError."""
         import llmsql
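
For context, here is a minimal usage sketch of the new backend (not part of the diff above). It assumes only the `inference_function()` signature introduced in `llmsql/inference/inference_function.py`, which awaits the callable as `await inference_function(prompt, question=..., table=..., **function_kwargs)`; the callable name `my_backend` and the output path are illustrative.

```python
import asyncio

from llmsql import inference_function


async def my_backend(prompt: str, *, question=None, table=None, **kwargs) -> str:
    # `question` is the benchmark row and `table` its schema dict, as passed by
    # _inference_function_async(); extra kwargs come from `function_kwargs`.
    await asyncio.sleep(0)  # stand-in for a real async engine or API call
    return "SELECT 1"


# Downloads the benchmark questions/tables into the workdir on first run.
results = inference_function(
    inference_function=my_backend,
    requests_per_minute=120,               # optional client-side rate limiting
    function_kwargs={"temperature": 0.0},  # forwarded to my_backend as **kwargs
    limit=0.01,                            # float limit: run on 1% of the benchmark
    output_file="my_predictions.jsonl",
)
print(f"{len(results)} predictions written")
```

This follows the same path the tests exercise: prompts are built with `build_all_requests()`, each completion is appended to `output_file` as JSONL, and the call works both inside and outside an already-running event loop (via `nest_asyncio`).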