From f2d64d2613ed8942cdc3a3235ec1d41671e3a611 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Wed, 2 Jul 2025 16:54:09 +0200 Subject: [PATCH 1/4] feat: better structured outputs --- notebooks/exploration.ipynb | 30 +++--- notebooks/fuser_nb.ipynb | 159 +++++++++++++++++++++++++++++ src/components/all_associations.py | 16 +-- src/inference.py | 53 ++++++---- src/utils.py | 30 ++++++ 5 files changed, 246 insertions(+), 42 deletions(-) create mode 100644 notebooks/fuser_nb.ipynb diff --git a/notebooks/exploration.ipynb b/notebooks/exploration.ipynb index 4f52ae8..5bd78f2 100644 --- a/notebooks/exploration.ipynb +++ b/notebooks/exploration.ipynb @@ -16,20 +16,6 @@ "%autoreload 2" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import os\n", - "from tqdm import tqdm\n", - "import pickle\n", - "from loguru import logger\n", - "import json" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -44,12 +30,28 @@ } ], "source": [ + "import os\n", + "\n", "# Change path to project root\n", "if os.getcwd().endswith(\"notebooks\"):\n", " os.chdir(os.path.dirname(os.getcwd()))\n", "print(os.getcwd())" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "from tqdm import tqdm\n", + "import pickle\n", + "from loguru import logger\n", + "import json" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/notebooks/fuser_nb.ipynb b/notebooks/fuser_nb.ipynb new file mode 100644 index 0000000..35e96db --- /dev/null +++ b/notebooks/fuser_nb.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n", + "/Users/shloknatarajan/stanford/research/daneshjou/AutoGKB\n" + ] + } + ], + "source": [ + "# Notebook Setup\n", + "# Run this cell: \n", + "# The lines below will instruct jupyter to reload imported modules before \n", + "# executing code cells. This enables you to quickly iterate and test revisions\n", + "# to your code without having to restart the kernel and reload all of your \n", + "# modules each time you make a code change in a separate python file.\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "\n", + "# Change path to project root\n", + "if os.getcwd().endswith(\"notebooks\"):\n", + " os.chdir(os.path.dirname(os.getcwd()))\n", + "print(os.getcwd())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Testing / Debugging the Fuser Module\n", + "Goal is to have generators output many possible samples (JSON) and the Fusers are used to either: \n", + "- Merge into one response\n", + "- Merge into a reasonable set of responses (mostly deduplication and outlier/weirdness removing)\n", + "- Merge into a set that is majority vote / somewhat rule inclined" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from src.inference import Generator, Fuser\n", + "from pydantic import BaseModel\n", + "from typing import List" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "class StockPrice(BaseModel):\n", + " ticker: str\n", + " price: float\n", + "\n", + "class StockPriceList(BaseModel):\n", + " stock_prices: List[StockPrice]" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating 10 Responses: 0%| | 0/10 [00:00 List[Dict]: output_format_structure=VariantAssociationList, ).get_hydrated_prompt() generator = Generator(model="gpt-4o") - response = generator.generate(prompt) - if isinstance(response, dict): - response = VariantAssociationList(**response) - return response.association_list - return response + responses = generator.generate(prompt, samples=10) + + fuser = Fuser(model="gpt-4o", temperature=0.1) + fused_response = fuser.generate(responses, response_format=VariantAssociationList) + + if isinstance(fused_response, dict): + fused_response = VariantAssociationList(**fused_response) + return fused_response.association_list + return fused_response def test_all_associations(): diff --git a/src/inference.py b/src/inference.py index a95ab74..7038df0 100644 --- a/src/inference.py +++ b/src/inference.py @@ -1,3 +1,4 @@ +import enum from loguru import logger import litellm from typing import List, Optional, Union @@ -6,6 +7,8 @@ from abc import ABC, abstractmethod from src.prompts import HydratedPrompt import json +from src.utils import parse_structured_response +from tqdm import tqdm load_dotenv() @@ -83,11 +86,12 @@ class Generator(LLMInterface): debug_mode = False - def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1): + def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1, samples: int = 1): super().__init__(model, temperature) if self.debug_mode: litellm.set_verbose = True - + self.samples = samples + def _generate_single( self, input_prompt: str | HydratedPrompt, @@ -146,13 +150,12 @@ def generate( system_prompt: Optional[str] = None, temperature: Optional[float] = None, response_format: Optional[BaseModel] = None, - samples: Optional[int] = 1, ) -> LMResponse: """ Generate a response from the LLM. """ responses = [] - for _ in range(samples): + for _ in tqdm(range(self.samples), desc=f"Generating {self.samples} Responses"): response = self._generate_single( input_prompt=input_prompt, system_prompt=system_prompt, @@ -163,7 +166,7 @@ def generate( if len(responses) == 1: return responses[0] - return responses + return parse_structured_response(responses, response_format) class Parser(LLMInterface): @@ -208,10 +211,15 @@ def generate( except Exception as e: logger.error(f"Error generating response: {e}") raise e - return response.choices[0].message.content + raw_response = response.choices[0].message.content + return parse_structured_response(raw_response, response_format) class Fuser(LLMInterface): + """ + Fuser Class + Used to fuse multiple responses into a final set of responses, removing duplicates and unreasonable responses + """ debug_mode = False @@ -220,29 +228,29 @@ def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1): if self.debug_mode: litellm.set_verbose = True + self.system_prompt = ( + "You are a helpful assistant who fuses multiple responses into a comprehensive final response. You will " + "be given a list of responses and you will merge the responses into a final set of responses while removing " + "duplicates, responses that are extremely similar, and responses that are not reasonable." + ) + def generate( self, - input_prompt: str, + input_prompt: str | List[str], system_prompt: Optional[str] = None, temperature: Optional[float] = None, response_format: Optional[BaseModel] = None, ) -> LMResponse: temp = temperature if temperature is not None else self.temperature - # Check if system prompt is provided if system_prompt is not None and system_prompt != "": - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": input_prompt}, - ] - else: - logger.warning("") - messages = [ - { - "role": "system", - "content": "You are a helpful assistant who fuses multiple responses into a comprehensive final response", - }, - {"role": "user", "content": input_prompt}, - ] + self.system_prompt = system_prompt + messages = [ + { + "role": "system", + "content": self.system_prompt, + }, + {"role": "user", "content": f"Here are the responses: {input_prompt}"}, + ] try: response = litellm.completion( model=self.model, @@ -253,4 +261,5 @@ def generate( except Exception as e: logger.error(f"Error generating response: {e}") raise e - return response.choices[0].message.content + raw_response = response.choices[0].message.content + return parse_structured_response(raw_response, response_format) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 7844266..c519435 100644 --- a/src/utils.py +++ b/src/utils.py @@ -4,6 +4,7 @@ from typing import List, Optional from termcolor import colored from src.article_parser import MarkdownParser +from pydantic import BaseModel, ValidationError _true_variant_cache: Optional[dict] = None @@ -129,3 +130,32 @@ def get_title(markdown_text: str): # remove the # from the title title = title.replace("# ", "") return title + + +def parse_structured_response(raw_response: str | List[str], response_format: BaseModel): + """ + Parse a raw response into a Pydantic model. + """ + + if isinstance(raw_response, list): + try: + parsed_items = [] + for item in raw_response: + if isinstance(item, dict): + # If item is already a dict, validate it directly + parsed_items.append(response_format.model_validate(item)) + elif isinstance(item, str): + # If item is a string, parse as JSON + parsed_items.append(response_format.model_validate_json(item)) + else: + # Convert to JSON string then parse + parsed_items.append(response_format.model_validate_json(json.dumps(item))) + return parsed_items + except ValidationError as e: + logger.error(f"Error parsing response list: {e}. Returning raw response list.") + return raw_response + try: + return response_format.model_validate_json(raw_response) + except ValidationError as e: + logger.error(f"Error parsing response: {e}. Returning raw response.") + return raw_response \ No newline at end of file From 232e41238241f77e7dacf0282d17355b16325774 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Wed, 2 Jul 2025 17:19:43 +0200 Subject: [PATCH 2/4] feat: working fuser in pipeline + better type serialization --- notebooks/fuser_nb.ipynb | 20 ---------------- src/components/all_associations.py | 12 ++++------ ...ons_pipeline.py => annotation_pipeline.py} | 0 src/inference.py | 23 ++++++++----------- src/utils.py | 4 +++- 5 files changed, 17 insertions(+), 42 deletions(-) rename src/components/{annotations_pipeline.py => annotation_pipeline.py} (100%) diff --git a/notebooks/fuser_nb.ipynb b/notebooks/fuser_nb.ipynb index 35e96db..7cea17f 100644 --- a/notebooks/fuser_nb.ipynb +++ b/notebooks/fuser_nb.ipynb @@ -105,26 +105,6 @@ "fused_responses_10 = fuser.generate(input_prompt=responses,response_format=StockPriceList)" ] }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "StockPriceList(stock_prices=[StockPrice(ticker='AAPL', price=175.23), StockPrice(ticker='GOOGL', price=2835.67), StockPrice(ticker='AMZN', price=3456.78), StockPrice(ticker='MSFT', price=299.12), StockPrice(ticker='TSLA', price=759.34)])" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fused_responses_10" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/src/components/all_associations.py b/src/components/all_associations.py index 380d99c..0dab99a 100644 --- a/src/components/all_associations.py +++ b/src/components/all_associations.py @@ -85,16 +85,14 @@ def get_all_associations(article_text: str) -> List[Dict]: ), output_format_structure=VariantAssociationList, ).get_hydrated_prompt() - generator = Generator(model="gpt-4o") - responses = generator.generate(prompt, samples=10) + generator = Generator(model="gpt-4o", samples=2) + responses = generator.generate(prompt) + logger.info(f"Fusing {len(responses)} Responses") fuser = Fuser(model="gpt-4o", temperature=0.1) fused_response = fuser.generate(responses, response_format=VariantAssociationList) - if isinstance(fused_response, dict): - fused_response = VariantAssociationList(**fused_response) - return fused_response.association_list - return fused_response + return fused_response.association_list def test_all_associations(): @@ -109,7 +107,7 @@ def test_all_associations(): file_path = f"data/extractions/all_associations/{pmcid}.jsonl" os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, "w") as f: - json.dump(associations, f, indent=4) + json.dump([assoc.model_dump(mode='json') for assoc in associations], f, indent=4) logger.info(f"Saved to file {file_path}") diff --git a/src/components/annotations_pipeline.py b/src/components/annotation_pipeline.py similarity index 100% rename from src/components/annotations_pipeline.py rename to src/components/annotation_pipeline.py diff --git a/src/inference.py b/src/inference.py index 7038df0..72b3baf 100644 --- a/src/inference.py +++ b/src/inference.py @@ -135,14 +135,7 @@ def _generate_single( logger.error(f"Error generating response: {e}") raise e response_content = response.choices[0].message.content - if isinstance(response_content, str) and response_format is not None: - try: - response_content = json.loads(response_content) - except: - logger.warning( - f"Response content was not a valid JSON string. Returning string" - ) - return response_content + return parse_structured_response(response_content, response_format) def generate( self, @@ -252,12 +245,14 @@ def generate( {"role": "user", "content": f"Here are the responses: {input_prompt}"}, ] try: - response = litellm.completion( - model=self.model, - messages=messages, - response_format=response_format, - temperature=temp, - ) + completion_kwargs = { + "model": self.model, + "messages": messages, + "temperature": temp, + } + if response_format is not None: + completion_kwargs["response_format"] = response_format + response = litellm.completion(**completion_kwargs) except Exception as e: logger.error(f"Error generating response: {e}") raise e diff --git a/src/utils.py b/src/utils.py index c519435..362d6e8 100644 --- a/src/utils.py +++ b/src/utils.py @@ -132,10 +132,12 @@ def get_title(markdown_text: str): return title -def parse_structured_response(raw_response: str | List[str], response_format: BaseModel): +def parse_structured_response(raw_response: str | List[str], response_format: Optional[BaseModel]): """ Parse a raw response into a Pydantic model. """ + if response_format is None: + return raw_response if isinstance(raw_response, list): try: From b6339342226b7693a65f69b125ea62859e0e3781 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Wed, 2 Jul 2025 17:20:09 +0200 Subject: [PATCH 3/4] chore: black format --- src/components/all_associations.py | 4 +++- src/inference.py | 8 +++++--- src/utils.py | 16 +++++++++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/components/all_associations.py b/src/components/all_associations.py index 0dab99a..3c3533e 100644 --- a/src/components/all_associations.py +++ b/src/components/all_associations.py @@ -107,7 +107,9 @@ def test_all_associations(): file_path = f"data/extractions/all_associations/{pmcid}.jsonl" os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, "w") as f: - json.dump([assoc.model_dump(mode='json') for assoc in associations], f, indent=4) + json.dump( + [assoc.model_dump(mode="json") for assoc in associations], f, indent=4 + ) logger.info(f"Saved to file {file_path}") diff --git a/src/inference.py b/src/inference.py index 72b3baf..6eda5c9 100644 --- a/src/inference.py +++ b/src/inference.py @@ -86,12 +86,14 @@ class Generator(LLMInterface): debug_mode = False - def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1, samples: int = 1): + def __init__( + self, model: str = "gpt-4o-mini", temperature: float = 0.1, samples: int = 1 + ): super().__init__(model, temperature) if self.debug_mode: litellm.set_verbose = True self.samples = samples - + def _generate_single( self, input_prompt: str | HydratedPrompt, @@ -257,4 +259,4 @@ def generate( logger.error(f"Error generating response: {e}") raise e raw_response = response.choices[0].message.content - return parse_structured_response(raw_response, response_format) \ No newline at end of file + return parse_structured_response(raw_response, response_format) diff --git a/src/utils.py b/src/utils.py index 362d6e8..7c6db9b 100644 --- a/src/utils.py +++ b/src/utils.py @@ -132,13 +132,15 @@ def get_title(markdown_text: str): return title -def parse_structured_response(raw_response: str | List[str], response_format: Optional[BaseModel]): +def parse_structured_response( + raw_response: str | List[str], response_format: Optional[BaseModel] +): """ Parse a raw response into a Pydantic model. """ if response_format is None: return raw_response - + if isinstance(raw_response, list): try: parsed_items = [] @@ -151,13 +153,17 @@ def parse_structured_response(raw_response: str | List[str], response_format: Op parsed_items.append(response_format.model_validate_json(item)) else: # Convert to JSON string then parse - parsed_items.append(response_format.model_validate_json(json.dumps(item))) + parsed_items.append( + response_format.model_validate_json(json.dumps(item)) + ) return parsed_items except ValidationError as e: - logger.error(f"Error parsing response list: {e}. Returning raw response list.") + logger.error( + f"Error parsing response list: {e}. Returning raw response list." + ) return raw_response try: return response_format.model_validate_json(raw_response) except ValidationError as e: logger.error(f"Error parsing response: {e}. Returning raw response.") - return raw_response \ No newline at end of file + return raw_response From 953d77c61c676656b46672efa6af66ccbed12351 Mon Sep 17 00:00:00 2001 From: Shlok Natarajan Date: Wed, 2 Jul 2025 17:26:47 +0200 Subject: [PATCH 4/4] feat: working pipeline run --- src/components/annotation_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/components/annotation_pipeline.py b/src/components/annotation_pipeline.py index 8e6ba13..b7bfae8 100644 --- a/src/components/annotation_pipeline.py +++ b/src/components/annotation_pipeline.py @@ -70,7 +70,7 @@ def run(self, save_path: str = "data/extractions"): os.makedirs(os.path.dirname(file_path), exist_ok=True) try: with open(file_path, "w") as f: - json.dump(final_structure, f, indent=4) + json.dump(final_structure, f, indent=4, default=lambda obj: obj.model_dump() if hasattr(obj, 'model_dump') else str(obj)) logger.info(f"Saved annotations to {file_path}") except Exception as e: logger.error(f"Error saving annotations: {e}")