diff --git a/README.md b/README.md
index f37b7ce..9cbf02a 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,8 @@ The package doesn't have the dataset, it is stored on our [HuggingFace page](htt
## Latest News 📣
+* [2026/04] Added support for inference with a custom async function, see [`inference_function()`](./llmsql/inference/inference_function.py#inference_function) for details.
+
* [2026/03] Fully functional CLI commands for inference and evaluation. See this [guide](./llmsql/_cli/README.md).
-* [2026/03] Added support for API inference, for now only for OpenAI-compatable APIs, see [`inference_api()` function](./llmsql/inference/inference_api.py#inference_api)
+* [2026/03] Added support for API inference, for now only for OpenAI-compatible APIs, see the [`inference_api()` function](./llmsql/inference/inference_api.py#inference_api).
@@ -58,7 +60,11 @@ Modern LLMs are already strong at producing SQL queries without finetuning.
We therefore recommend that most users:
1. **Run inference** directly on the full benchmark:
- - Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works both with HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct` and model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`. The api inference is also supported, see [`inference_api()`](./llmsql/inference/inference_api.py#inference_api)
+   - Use one of the supported inference backends:
+     * [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) - generates SQL predictions with your model using the Hugging Face `transformers` backend.
+     * [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py) - generates SQL predictions using the vLLM backend. Both functions accept either an HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct`, or a model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`.
+     * [`inference_api()`](./llmsql/inference/inference_api.py#inference_api) - runs inference against an OpenAI-compatible API.
+     * [`inference_function()`](./llmsql/inference/inference_function.py#inference_function) - runs inference with your own custom async function.
- Evaluate results against the benchmark with the [`llmsql.evaluate`](./llmsql/evaluation/evaluator.py) function.
2. **Optional finetuning**:
diff --git a/docs/docs/inference.rst b/docs/docs/inference.rst
index 1a33533..524573f 100644
--- a/docs/docs/inference.rst
+++ b/docs/docs/inference.rst
@@ -20,6 +20,14 @@ Inference API Reference
---
+
+.. automodule:: llmsql.inference.inference_function
+ :members:
+ :undoc-members:
+
+---
+
+
.. raw:: html
diff --git a/llmsql/__init__.py b/llmsql/__init__.py
index 800a066..a876d29 100644
--- a/llmsql/__init__.py
+++ b/llmsql/__init__.py
@@ -30,8 +30,18 @@ def __getattr__(name: str): # type: ignore
from .inference.inference_api import inference_api
return inference_api
+ elif name == "inference_function":
+ from .inference.inference_function import inference_function
+
+ return inference_function
raise AttributeError(f"module {__name__} has no attribute {name!r}")
-__all__ = ["evaluate", "inference_vllm", "inference_transformers", "inference_api"]
+__all__ = [
+ "evaluate",
+ "inference_vllm",
+ "inference_transformers",
+ "inference_api",
+ "inference_function",
+]
diff --git a/llmsql/inference/README.md b/llmsql/inference/README.md
index 1aba3f7..276768b 100644
--- a/llmsql/inference/README.md
+++ b/llmsql/inference/README.md
@@ -5,6 +5,7 @@ LLMSQL provides two inference backends for **Text-to-SQL generation** with large
* **Transformers** — runs inference using the standard Hugging Face `transformers` pipeline.
* **vLLM** — runs inference using the high-performance [vLLM](https://github.com/vllm-project/vllm) backend.
* **API** — runs inference against an OpenAI-compatible Chat Completions API with configurable base URL and rate limiting.
+* **Custom Function** — runs inference with your own async callable while preserving LLMSQL prompt building and output format.
-Both backends load benchmark questions and table schemas, build prompts (with few-shot examples), and generate SQL queries in parallel batches.
+All backends load benchmark questions and table schemas, build prompts (with few-shot examples), and generate SQL queries in parallel batches.
@@ -94,7 +95,25 @@ results = inference_api(
)
```
+---
+
+
+### Option 4 — Using your own async inference function
+
+```python
+from llmsql import inference_function
+
+async def get_answer(input_prompt, **kwargs):
+ # call your engine/API/router and return a SQL string
+ return "SELECT 1"
+
+results = inference_function(
+ inference_function=get_answer,
+ requests_per_minute=60,
+ function_kwargs={"temperature": 0.0},
+ output_file="test_output_function.jsonl",
+)
+```
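+
+The callable is awaited as `await fn(prompt, question=..., table=..., **function_kwargs)`, so it can also inspect the benchmark question and table schema when building its answer. A minimal sketch with a hypothetical `table_aware_answer` helper:
+
+```python
+from llmsql import inference_function
+
+async def table_aware_answer(prompt, question=None, table=None, **kwargs):
+    # `table` is the benchmark table dict; fall back to a trivial query if it is missing
+    if table is None:
+        return "SELECT 1"
+    first_column = table["header"][0]
+    return f"SELECT {first_column} FROM {table['table_id']}"
+
+results = inference_function(
+    inference_function=table_aware_answer,
+    output_file="test_output_function.jsonl",
+)
+```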
---
## Command-Line Interface (CLI)
diff --git a/llmsql/inference/inference_function.py b/llmsql/inference/inference_function.py
new file mode 100644
index 0000000..c12ad26
--- /dev/null
+++ b/llmsql/inference/inference_function.py
@@ -0,0 +1,186 @@
+"""
+LLMSQL Custom Function Inference
+================================
+
+This module provides ``inference_function()`` for text-to-SQL generation using
+an arbitrary user-provided async inference callable.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Awaitable, Callable
+import inspect
+import time
+from typing import Any, Literal
+
+from dotenv import load_dotenv
+import nest_asyncio
+from tqdm.asyncio import tqdm
+
+from llmsql.config.config import DEFAULT_LLMSQL_VERSION, get_repo_id
+from llmsql.loggers.logging_config import log
+from llmsql.utils.inference_utils import (
+ _maybe_download,
+ _setup_seed,
+ resolve_workdir_path,
+)
+from llmsql.utils.utils import (
+ build_all_requests,
+ choose_prompt_builder,
+ load_jsonl,
+ overwrite_jsonl,
+ save_jsonl_lines,
+)
+
+load_dotenv()
+
+
+class _AsyncRateLimiter:
+    """Async rate limiter that enforces a minimum interval between request starts."""
+
+ def __init__(self, requests_per_minute: float | None) -> None:
+ if requests_per_minute is not None and requests_per_minute <= 0:
+ raise ValueError("requests_per_minute must be > 0 when provided.")
+ self._interval: float | None = (
+ 60.0 / requests_per_minute if requests_per_minute is not None else None
+ )
+ self._next_allowed: float = 0.0
+ self._lock = asyncio.Lock()
+
+ async def acquire(self) -> None:
+ if self._interval is None:
+ return
+
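+        # The lock is held while sleeping, so concurrent callers queue up and
+        # request starts stay at least one interval apart.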
+ async with self._lock:
+ now = time.monotonic()
+ wait = self._next_allowed - now
+ if wait > 0:
+ await asyncio.sleep(wait)
+ self._next_allowed = time.monotonic() + self._interval
+
+
+async def _inference_function_async(
+ *,
+ inference_callable: Callable[..., Awaitable[str]],
+ requests_per_minute: float | None,
+ function_kwargs: dict[str, Any],
+ questions: list[dict[str, Any]],
+ tables: dict[str, Any],
+ prompt_builder: Any,
+ output_file: str,
+) -> list[dict[str, str]]:
+ limiter = _AsyncRateLimiter(requests_per_minute)
+ all_results: list[dict[str, str]] = []
+ write_lock = asyncio.Lock()
+
+ prompts = build_all_requests(questions, tables, prompt_builder)
+
+ async def process_question(q: dict[str, Any], prompt: str) -> dict[str, str]:
+ await limiter.acquire()
+
+ completion = await inference_callable(
+ prompt,
+ question=q,
+ table=tables[q["table_id"]],
+ **function_kwargs,
+ )
+
+ result = {
+ "question_id": q.get("question_id", q.get("id", "")),
+ "completion": str(completion),
+ }
+
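+        # Serialize appends so concurrent tasks don't interleave writes to the output file.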
+ async with write_lock:
+ save_jsonl_lines(output_file, [result])
+
+ return result
+
+ tasks = [process_question(q, p) for q, p in zip(questions, prompts, strict=False)]
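+    # Gather results as tasks finish; completion order may differ from the original question order.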
+ for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Generating"):
+ all_results.append(await coro)
+
+ return all_results
+
+
+def inference_function(
+ *,
+ inference_function: Callable[..., Awaitable[str]],
+ requests_per_minute: float | None = None,
+ function_kwargs: dict[str, Any] | None = None,
+ version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION,
+ output_file: str = "llm_sql_predictions.jsonl",
+ workdir_path: str | None = None,
+ limit: int | float | None = None,
+ num_fewshots: int = 5,
+ seed: int = 42,
+) -> list[dict[str, str]]:
+ """Run SQL generation using a user-provided async callable.
+
+ The callable is awaited as:
+ ``await inference_function(prompt, question=..., table=..., **function_kwargs)``.
+
+ If your callable needs sampling parameters, pass them through ``function_kwargs``.
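+
+    A minimal usage sketch::
+
+        async def answer(prompt, **kwargs):
+            return "SELECT 1"
+
+        results = inference_function(inference_function=answer, limit=10)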
+ """
+ _setup_seed(seed=seed)
+
+ if not callable(inference_function):
+ raise TypeError("`inference_function` must be callable.")
+
+ function_kwargs = function_kwargs or {}
+ workdir = resolve_workdir_path(workdir_path)
+
+ repo_id = get_repo_id(version)
+ questions_path = _maybe_download(repo_id, "questions.jsonl", workdir)
+ tables_path = _maybe_download(repo_id, "tables.jsonl", workdir)
+
+ questions = load_jsonl(questions_path)
+ tables_list = load_jsonl(tables_path)
+ tables = {t["table_id"]: t for t in tables_list}
+
+ if limit is not None:
+ if isinstance(limit, float):
+ if not (0.0 < limit <= 1.0):
+ raise ValueError(
+ f"When a float, `limit` must be between 0.0 and 1.0, got {limit}."
+ )
+ limit = max(1, int(len(questions) * limit))
+ if not isinstance(limit, int) or limit < 1:
+ raise ValueError(
+ f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}."
+ )
+ questions = questions[:limit]
+
+ prompt_builder = choose_prompt_builder(num_fewshots)
+
+ overwrite_jsonl(output_file)
+
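+    # Wrap the user callable so that a sync function (returning a non-awaitable)
+    # fails fast with a clear TypeError.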
+ async def _validated_callable(*args: Any, **kwargs: Any) -> str:
+ out = inference_function(*args, **kwargs)
+ if inspect.isawaitable(out):
+ return str(await out)
+ raise TypeError("`inference_function` must return an awaitable value.")
+
+ coro = _inference_function_async(
+ inference_callable=_validated_callable,
+ requests_per_minute=requests_per_minute,
+ function_kwargs=function_kwargs,
+ questions=questions,
+ tables=tables,
+ prompt_builder=prompt_builder,
+ output_file=output_file,
+ )
+
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is not None and loop.is_running():
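+        # Already inside a running event loop (e.g. Jupyter): allow nested
+        # run_until_complete via nest_asyncio.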
+ nest_asyncio.apply(loop)
+ all_results = loop.run_until_complete(coro)
+ else:
+ all_results = asyncio.run(coro)
+
+ log.info(f"Generation completed. {len(all_results)} results saved to {output_file}")
+ return all_results
diff --git a/llmsql/utils/utils.py b/llmsql/utils/utils.py
index fc9b1ab..0d62e84 100644
--- a/llmsql/utils/utils.py
+++ b/llmsql/utils/utils.py
@@ -2,6 +2,8 @@
import json
from pathlib import Path
+from transformers import AutoTokenizer
+
from llmsql.loggers.logging_config import log
from llmsql.prompts.prompts import (
build_prompt_0shot,
@@ -65,7 +67,7 @@ def build_all_requests(
questions: list[dict],
tables: dict,
prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str],
- tokenizer=None,
+    tokenizer: AutoTokenizer | None = None,
use_chat_template: bool = True,
) -> list[str]:
"""
diff --git a/tests/inference/test_inference_function.py b/tests/inference/test_inference_function.py
new file mode 100644
index 0000000..1ccf74e
--- /dev/null
+++ b/tests/inference/test_inference_function.py
@@ -0,0 +1,197 @@
+"""Tests for the async inference_function implementation."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from pathlib import Path
+
+import pytest
+
+from llmsql.inference.inference_function import inference_function
+
+
+def _write_jsonl(path: Path, rows: list[dict]) -> None:
+ path.write_text("\n".join(json.dumps(r) for r in rows) + "\n")
+
+
+def _make_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]:
+ questions = [
+ {"question_id": "q1", "question": "What is 1+1?", "table_id": "t1"},
+ {"question_id": "q2", "question": "What is 2+2?", "table_id": "t1"},
+ ]
+ tables = [{"table_id": "t1", "header": ["col"], "types": ["text"], "rows": [["x"]]}]
+ qpath = tmp_path / "questions.jsonl"
+ tpath = tmp_path / "tables.jsonl"
+ outpath = tmp_path / "out.jsonl"
+ _write_jsonl(qpath, questions)
+ _write_jsonl(tpath, tables)
+ return qpath, tpath, outpath
+
+
+def test_runs_with_custom_async_callable(tmp_path):
+ _make_fixtures(tmp_path)
+
+ async def fake_infer(prompt, **kwargs):
+ assert isinstance(prompt, str)
+ assert kwargs["temperature"] == 0.0
+ assert "question" in kwargs
+ assert "table" in kwargs
+ return "SELECT 1"
+
+ results = inference_function(
+ inference_function=fake_infer,
+ function_kwargs={"temperature": 0.0},
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ )
+
+ assert len(results) == 2
+ assert all(r["completion"] == "SELECT 1" for r in results)
+
+
+def test_limit_works(tmp_path):
+ _make_fixtures(tmp_path)
+
+ async def fake_infer(prompt, **kwargs):
+ return "SELECT 1"
+
+ results = inference_function(
+ inference_function=fake_infer,
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ limit=1,
+ )
+
+ assert len(results) == 1
+
+
+def test_rejects_non_async_result(tmp_path):
+ _make_fixtures(tmp_path)
+
+ def bad_infer(prompt, **kwargs):
+ return "not-awaitable"
+
+ with pytest.raises(TypeError, match="must return an awaitable"):
+ inference_function(
+ inference_function=bad_infer, # type: ignore[arg-type]
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ )
+
+
+def test_rate_limiter_rejects_non_positive_rpm():
+ from llmsql.inference.inference_function import _AsyncRateLimiter
+
+ with pytest.raises(ValueError, match="requests_per_minute must be > 0"):
+ _AsyncRateLimiter(0)
+
+ with pytest.raises(ValueError):
+ _AsyncRateLimiter(-5)
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_waits(monkeypatch):
+ from llmsql.inference.inference_function import _AsyncRateLimiter
+
+ limiter = _AsyncRateLimiter(requests_per_minute=60) # 1 request/sec
+
+ sleep_calls = []
+
+ async def fake_sleep(duration):
+ sleep_calls.append(duration)
+
+ monkeypatch.setattr(asyncio, "sleep", fake_sleep)
+
+ # First call should not sleep
+ await limiter.acquire()
+ # Second call should trigger wait
+ await limiter.acquire()
+
+ assert len(sleep_calls) == 1
+ assert sleep_calls[0] > 0
+
+
+def test_rejects_non_callable_inference_function(tmp_path):
+ _make_fixtures(tmp_path)
+
+ with pytest.raises(TypeError, match="must be callable"):
+ inference_function(
+ inference_function="not-a-function", # type: ignore
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ )
+
+
+@pytest.mark.parametrize("bad_limit", [0.0, -0.1, 1.5])
+def test_limit_float_out_of_range(tmp_path, bad_limit):
+ _make_fixtures(tmp_path)
+
+ async def fake_infer(prompt, **kwargs):
+ return "SELECT 1"
+
+ with pytest.raises(ValueError, match="must be between 0.0 and 1.0"):
+ inference_function(
+ inference_function=fake_infer,
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ limit=bad_limit,
+ )
+
+
+@pytest.mark.parametrize("bad_limit", [0, -1, "foo"])
+def test_limit_invalid_type_or_value(tmp_path, bad_limit):
+ _make_fixtures(tmp_path)
+
+ async def fake_infer(prompt, **kwargs):
+ return "SELECT 1"
+
+
+ with pytest.raises(ValueError, match="must be a positive integer"):
+ inference_function(
+ inference_function=fake_infer,
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ limit=bad_limit, # type: ignore[arg-type]
+ )
+
+
+def test_runs_inside_existing_event_loop(monkeypatch, tmp_path):
+ _make_fixtures(tmp_path)
+
+ async def fake_infer(prompt, **kwargs):
+ return "SELECT 1"
+
+ applied = {"called": False}
+
+ def fake_apply(loop):
+ applied["called"] = True
+
+ monkeypatch.setattr("nest_asyncio.apply", fake_apply)
+
+ # Create a real loop we control
+ real_loop = asyncio.new_event_loop()
+
+ class FakeLoop:
+ def is_running(self):
+ return True
+
+ def run_until_complete(self, coro):
+ return real_loop.run_until_complete(coro)
+
+ monkeypatch.setattr(asyncio, "get_running_loop", lambda: FakeLoop())
+
+ try:
+ results = inference_function(
+ inference_function=fake_infer,
+ output_file=str(tmp_path / "out.jsonl"),
+ workdir_path=str(tmp_path),
+ )
+ finally:
+ real_loop.close()
+
+ assert applied["called"] is True
+ assert len(results) > 0
diff --git a/tests/test_init.py b/tests/test_init.py
index 4d8a473..34d59ba 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -26,6 +26,12 @@ def test_lazy_import_inference_transformers(self) -> None:
assert inference_transformers is not None
+ def test_lazy_import_inference_function(self) -> None:
+ """Test that inference_function can be imported via lazy loading."""
+ from llmsql import inference_function
+
+ assert inference_function is not None
+
def test_invalid_attribute_raises_error(self) -> None:
"""Test that accessing invalid attribute raises AttributeError."""
import llmsql