diff --git a/llmsql/inference/inference_api.py b/llmsql/inference/inference_api.py
index c01f0f9..d9dfabb 100644
--- a/llmsql/inference/inference_api.py
+++ b/llmsql/inference/inference_api.py
@@ -27,6 +27,7 @@
 from llmsql.loggers.logging_config import log
 from llmsql.utils.inference_utils import _maybe_download, _setup_seed
 from llmsql.utils.utils import (
+    build_all_requests,
     choose_prompt_builder,
     load_jsonl,
     overwrite_jsonl,
@@ -114,13 +115,10 @@ async def _inference_api_async(
 
     async with aiohttp.ClientSession(headers=headers) as session:
 
-        async def process_question(q: dict[str, Any]) -> dict[str, str]:
-            tbl = tables[q["table_id"]]
-            example_row = tbl["rows"][0] if tbl["rows"] else []
-            prompt = prompt_builder(
-                q["question"], tbl["header"], tbl["types"], example_row
-            )
+        # Pre-build all prompts using the shared function
+        prompts = build_all_requests(questions, tables, prompt_builder)
 
+        async def process_question(q: dict[str, Any], prompt: str) -> dict[str, str]:
             payload = {
                 "model": model_name,
                 "messages": [
@@ -152,7 +150,7 @@ async def process_question(q: dict[str, Any]) -> dict[str, str]:
             return result
 
-        tasks = [process_question(q) for q in questions]
+        tasks = [process_question(q, p) for q, p in zip(questions, prompts)]
 
         for coro in tqdm(
             asyncio.as_completed(tasks),
             total=len(tasks),
diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py
index f661c6b..a6c00cf 100644
--- a/llmsql/inference/inference_vllm.py
+++ b/llmsql/inference/inference_vllm.py
@@ -56,6 +56,7 @@
 from llmsql.loggers.logging_config import log
 from llmsql.utils.inference_utils import _maybe_download, _setup_seed
 from llmsql.utils.utils import (
+    build_all_requests,
     choose_prompt_builder,
     load_jsonl,
     overwrite_jsonl,
@@ -201,37 +202,27 @@ def inference_vllm(
 
     sampling_params = SamplingParams(**sampling_params_args)
 
+    # --- build all requests ---
+    prompts = build_all_requests(
+        questions,
+        tables,
+        prompt_builder,
+        tokenizer=tokenizer if use_chat_template else None,
+        use_chat_template=bool(use_chat_template),
+    )
+
     # --- main inference loop ---
     all_results: list[dict[str, str]] = []
     total = len(questions)
 
     for batch_start in tqdm(range(0, total, batch_size), desc="Generating"):
-        batch = questions[batch_start : batch_start + batch_size]
-
-        prompts = []
-        for q in batch:
-            tbl = tables[q["table_id"]]
-            example_row = tbl["rows"][0] if tbl["rows"] else []
-
-            raw_text = prompt_builder(
-                q["question"], tbl["header"], tbl["types"], example_row
-            )
-
-            if use_chat_template:
-                messages = [{"role": "user", "content": raw_text}]
-
-                final_prompt = tokenizer.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
-                )
-            else:
-                final_prompt = raw_text
-
-            prompts.append(final_prompt)
+        batch_prompts = prompts[batch_start : batch_start + batch_size]
+        batch_questions = questions[batch_start : batch_start + batch_size]
 
-        outputs = llm.generate(prompts, sampling_params)
+        outputs = llm.generate(batch_prompts, sampling_params)
 
         batch_results: list[dict[str, str]] = []
-        for q, out in zip(batch, outputs, strict=False):
+        for q, out in zip(batch_questions, outputs, strict=False):
             text = out.outputs[0].text
             batch_results.append(
                 {
diff --git a/llmsql/utils/utils.py b/llmsql/utils/utils.py
index d1f248d..fc9b1ab 100644
--- a/llmsql/utils/utils.py
+++ b/llmsql/utils/utils.py
@@ -59,3 +59,45 @@ def choose_prompt_builder(
     if shots == 5:
         return build_prompt_5shot
     raise ValueError("shots must be one of {0, 1, 5}")
+
+
+def build_all_requests(
+    questions: list[dict],
+    tables: dict,
+    prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str],
+    tokenizer=None,
+    use_chat_template: bool = True,
+) -> list[str]:
+    """
+    Build all prompts from questions and tables.
+
+    Args:
+        questions: List of question dicts with 'question' and 'table_id' keys.
+        tables: Dict mapping table_id to table metadata (with 'header', 'types', 'rows').
+        prompt_builder: Function to build raw prompt text.
+        tokenizer: Optional tokenizer with apply_chat_template method.
+        use_chat_template: Whether to apply chat template (if tokenizer provided).
+
+    Returns:
+        List of final prompts (with chat template applied if requested).
+    """
+    prompts = []
+    for q in questions:
+        tbl = tables[q["table_id"]]
+        example_row = tbl["rows"][0] if tbl["rows"] else []
+
+        raw_text = prompt_builder(
+            q["question"], tbl["header"], tbl["types"], example_row
+        )
+
+        if tokenizer and use_chat_template:
+            messages = [{"role": "user", "content": raw_text}]
+            final_prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            final_prompt = raw_text
+
+        prompts.append(final_prompt)
+
+    return prompts
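
For reviewers, a minimal usage sketch of the new helper (not part of the patch): demo_prompt_builder and the sample data below are hypothetical stand-ins for the builders returned by choose_prompt_builder and for the question/table dicts the real callers load.

# Illustrative sketch only -- demo_prompt_builder and the sample data are made up;
# the real callers pass a builder from choose_prompt_builder and, on the vLLM path,
# a tokenizer exposing apply_chat_template so the chat template is applied.
from llmsql.utils.utils import build_all_requests


def demo_prompt_builder(question, header, types, example_row):
    # Hypothetical stand-in for the 0/1/5-shot prompt builders.
    return f"Columns: {header} ({types})\nExample row: {example_row}\nQuestion: {question}"


questions = [{"question": "Which city has the largest population?", "table_id": "t1"}]
tables = {
    "t1": {
        "header": ["city", "population"],
        "types": ["text", "real"],
        "rows": [["Oslo", 709000]],
    }
}

# Plain prompts (API path): no tokenizer passed, so no chat template is applied.
plain_prompts = build_all_requests(questions, tables, demo_prompt_builder)

# Chat-formatted prompts (vLLM path): pass a tokenizer with apply_chat_template.
# tokenizer = AutoTokenizer.from_pretrained(...)  # e.g. from transformers
# chat_prompts = build_all_requests(
#     questions, tables, demo_prompt_builder, tokenizer=tokenizer, use_chat_template=True
# )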