Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions notebooks/exploration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,6 @@
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle\n",
"from loguru import logger\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
Expand All @@ -44,12 +30,28 @@
}
],
"source": [
"import os\n",
"\n",
"# Change path to project root\n",
"if os.getcwd().endswith(\"notebooks\"):\n",
" os.chdir(os.path.dirname(os.getcwd()))\n",
"print(os.getcwd())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle\n",
"from loguru import logger\n",
"import json"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
139 changes: 139 additions & 0 deletions notebooks/fuser_nb.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n",
"/Users/shloknatarajan/stanford/research/daneshjou/AutoGKB\n"
]
}
],
"source": [
"# Notebook Setup\n",
"# Run this cell: \n",
"# The lines below will instruct jupyter to reload imported modules before \n",
"# executing code cells. This enables you to quickly iterate and test revisions\n",
"# to your code without having to restart the kernel and reload all of your \n",
"# modules each time you make a code change in a separate python file.\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"\n",
"# Change path to project root\n",
"if os.getcwd().endswith(\"notebooks\"):\n",
" os.chdir(os.path.dirname(os.getcwd()))\n",
"print(os.getcwd())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Testing / Debugging the Fuser Module\n",
    "The goal is to have generators output many possible samples (JSON); the fusers are then used to do one of the following:\n",
    "- Merge into one response\n",
    "- Merge into a reasonable set of responses (mostly deduplication and outlier/weirdness removal)\n",
    "- Merge into a set that reflects a majority vote / is somewhat rule-inclined"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"from src.inference import Generator, Fuser\n",
"from pydantic import BaseModel\n",
"from typing import List"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"class StockPrice(BaseModel):\n",
" ticker: str\n",
" price: float\n",
"\n",
"class StockPriceList(BaseModel):\n",
" stock_prices: List[StockPrice]"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating 10 Responses: 0%| | 0/10 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating 10 Responses: 100%|██████████| 10/10 [00:19<00:00, 1.92s/it]\n"
]
}
],
"source": [
"generator = Generator(model=\"gpt-4o\", samples=10)\n",
"responses = generator.generate(input_prompt=\"Give me a set of 5 stocks and an estimate of their exact share price in dollars.\", response_format=StockPriceList)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"fuser = Fuser(model=\"gpt-4o\", temperature=0.1)\n",
"fused_responses_10 = fuser.generate(input_prompt=responses,response_format=StockPriceList)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fused_responses_10"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "default",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
20 changes: 12 additions & 8 deletions src/components/all_associations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.inference import Generator
from src.inference import Generator, Fuser
from src.variants import QuotedStr
from src.prompts import GeneratorPrompt, ArticlePrompt
from src.utils import get_article_text
Expand Down Expand Up @@ -85,12 +85,14 @@ def get_all_associations(article_text: str) -> List[Dict]:
),
output_format_structure=VariantAssociationList,
).get_hydrated_prompt()
generator = Generator(model="gpt-4o")
response = generator.generate(prompt)
if isinstance(response, dict):
response = VariantAssociationList(**response)
return response.association_list
return response
generator = Generator(model="gpt-4o", samples=2)
responses = generator.generate(prompt)
logger.info(f"Fusing {len(responses)} Responses")

fuser = Fuser(model="gpt-4o", temperature=0.1)
fused_response = fuser.generate(responses, response_format=VariantAssociationList)

return fused_response.association_list


def test_all_associations():
Expand All @@ -105,7 +107,9 @@ def test_all_associations():
file_path = f"data/extractions/all_associations/{pmcid}.jsonl"
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as f:
json.dump(associations, f, indent=4)
json.dump(
[assoc.model_dump(mode="json") for assoc in associations], f, indent=4
)
logger.info(f"Saved to file {file_path}")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run(self, save_path: str = "data/extractions"):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
try:
with open(file_path, "w") as f:
json.dump(final_structure, f, indent=4)
json.dump(final_structure, f, indent=4, default=lambda obj: obj.model_dump() if hasattr(obj, 'model_dump') else str(obj))
logger.info(f"Saved annotations to {file_path}")
except Exception as e:
logger.error(f"Error saving annotations: {e}")
Expand Down
76 changes: 41 additions & 35 deletions src/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from loguru import logger
import litellm
from typing import List, Optional, Union
Expand All @@ -6,6 +7,8 @@
from abc import ABC, abstractmethod
from src.prompts import HydratedPrompt
import json
from src.utils import parse_structured_response
from tqdm import tqdm

load_dotenv()

Expand Down Expand Up @@ -83,10 +86,13 @@ class Generator(LLMInterface):

debug_mode = False

def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1):
def __init__(
self, model: str = "gpt-4o-mini", temperature: float = 0.1, samples: int = 1
):
super().__init__(model, temperature)
if self.debug_mode:
litellm.set_verbose = True
self.samples = samples

def _generate_single(
self,
Expand Down Expand Up @@ -131,28 +137,20 @@ def _generate_single(
logger.error(f"Error generating response: {e}")
raise e
response_content = response.choices[0].message.content
if isinstance(response_content, str) and response_format is not None:
try:
response_content = json.loads(response_content)
except:
logger.warning(
f"Response content was not a valid JSON string. Returning string"
)
return response_content
return parse_structured_response(response_content, response_format)

def generate(
self,
input_prompt: str,
system_prompt: Optional[str] = None,
temperature: Optional[float] = None,
response_format: Optional[BaseModel] = None,
samples: Optional[int] = 1,
) -> LMResponse:
"""
Generate a response from the LLM.
"""
responses = []
for _ in range(samples):
for _ in tqdm(range(self.samples), desc=f"Generating {self.samples} Responses"):
response = self._generate_single(
input_prompt=input_prompt,
system_prompt=system_prompt,
Expand All @@ -163,7 +161,7 @@ def generate(
if len(responses) == 1:
return responses[0]

return responses
return parse_structured_response(responses, response_format)


class Parser(LLMInterface):
Expand Down Expand Up @@ -208,10 +206,15 @@ def generate(
except Exception as e:
logger.error(f"Error generating response: {e}")
raise e
return response.choices[0].message.content
raw_response = response.choices[0].message.content
return parse_structured_response(raw_response, response_format)


class Fuser(LLMInterface):
"""
Fuser Class
Used to fuse multiple responses into a final set of responses, removing duplicates and unreasonable responses
"""

debug_mode = False

Expand All @@ -220,37 +223,40 @@ def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.1):
if self.debug_mode:
litellm.set_verbose = True

self.system_prompt = (
"You are a helpful assistant who fuses multiple responses into a comprehensive final response. You will "
"be given a list of responses and you will merge the responses into a final set of responses while removing "
"duplicates, responses that are extremely similar, and responses that are not reasonable."
)

def generate(
self,
input_prompt: str,
input_prompt: str | List[str],
system_prompt: Optional[str] = None,
temperature: Optional[float] = None,
response_format: Optional[BaseModel] = None,
) -> LMResponse:
temp = temperature if temperature is not None else self.temperature
# Check if system prompt is provided
if system_prompt is not None and system_prompt != "":
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": input_prompt},
]
else:
logger.warning("")
messages = [
{
"role": "system",
"content": "You are a helpful assistant who fuses multiple responses into a comprehensive final response",
},
{"role": "user", "content": input_prompt},
]
self.system_prompt = system_prompt
messages = [
{
"role": "system",
"content": self.system_prompt,
},
{"role": "user", "content": f"Here are the responses: {input_prompt}"},
]
try:
response = litellm.completion(
model=self.model,
messages=messages,
response_format=response_format,
temperature=temp,
)
completion_kwargs = {
"model": self.model,
"messages": messages,
"temperature": temp,
}
if response_format is not None:
completion_kwargs["response_format"] = response_format
response = litellm.completion(**completion_kwargs)
except Exception as e:
logger.error(f"Error generating response: {e}")
raise e
return response.choices[0].message.content
raw_response = response.choices[0].message.content
return parse_structured_response(raw_response, response_format)
Loading
Loading