Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/annotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,23 @@ def add_citations(self):
f"Adding Citations to Study Parameters using OneShotCitations with model {self.citation_model}"
)
for field_name in self.study_parameters.__class__.model_fields:
if (
field_name != "additional_resource_links"
): # Skip non-ParameterWithCitations field
if field_name != "additional_resource_links":
param_content = getattr(self.study_parameters, field_name)
if hasattr(param_content, "content"):
citations = (
self.one_shot_citations.get_study_parameter_citations(
field_name,
param_content.content,
model=self.citation_model,
)

if field_name in ["participant_info", "study_design", "study_results"]:
if hasattr(param_content, "items"):
for item in param_content.items:
citations = self.one_shot_citations.get_study_parameter_item_citations(
field_name,
item.content,
model=self.citation_model,
)
item.citations = citations
elif hasattr(param_content, "content"):
citations = self.one_shot_citations.get_study_parameter_citations(
field_name,
param_content.content,
model=self.citation_model,
)
param_content.citations = citations
else:
Expand Down
74 changes: 57 additions & 17 deletions src/citations/line_citation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,10 +607,58 @@ def _get_top_citations_for_parameter(

return top_sentences

def _get_top_citations_for_parameter_item(
    self, item_content: str, parameter_type: str, top_k: int = 2
) -> List[str]:
    """
    Find the top K most relevant sentences for a specific study parameter item.

    Every sentence in ``self.sentences`` is scored against the item with
    ``self._score_sentence_for_study_param``. The ``top_k * 3`` best-scoring
    sentences are passed through ``self._remove_duplicates`` and the first
    ``top_k`` survivors are kept. If fewer than ``top_k`` survive, the result
    is back-filled from the lower-ranked sentences, skipping any that
    ``self._is_duplicate_citation`` flags against a sentence already chosen.

    Args:
        item_content: The content of the specific item to find citations for
        parameter_type: The type of parameter (participant_info, study_design, etc.)
        top_k: Number of top sentences to return; values <= 0 yield an empty list

    Returns:
        List of at most ``top_k`` relevant sentences, best-scoring first
    """
    # A non-positive top_k asks for nothing. Without this guard a negative
    # value would slice from the wrong end of the ranking and return
    # arbitrary sentences, so bail out before doing any scoring work.
    if top_k <= 0:
        return []

    candidate_sentences = self.sentences

    logger.info(
        f"Scoring all {len(candidate_sentences)} sentences for {parameter_type} item"
    )

    sentence_scores = [
        (
            sentence,
            self._score_sentence_for_study_param(
                sentence, item_content, parameter_type
            ),
        )
        for sentence in candidate_sentences
    ]
    # list.sort is stable, so equally scored sentences keep their original order
    sentence_scores.sort(key=lambda pair: pair[1], reverse=True)

    # Over-fetch (3x) so that de-duplication still leaves enough candidates
    shortlist_size = top_k * 3
    shortlist = [sentence for sentence, _ in sentence_scores[:shortlist_size]]
    top_sentences = self._remove_duplicates(shortlist)[:top_k]

    # Back-fill from the remaining ranked sentences when de-duplication
    # left fewer than top_k entries
    for sentence, _ in sentence_scores[shortlist_size:]:
        if len(top_sentences) >= top_k:
            break
        already_covered = any(
            self._is_duplicate_citation(sentence, existing)
            for existing in top_sentences
        )
        if not already_covered:
            top_sentences.append(sentence)

    return top_sentences

def add_citations_to_study_parameters(self, study_parameters):
"""
Add citations to study parameters by finding relevant sentences for each parameter.
Modifies the parameters to include citations nested within each parameter key.
Now handles item-level citations for participant_info, study_design, and study_results.

Args:
study_parameters: StudyParameters object
Expand All @@ -620,10 +668,16 @@ def add_citations_to_study_parameters(self, study_parameters):
"""
logger.info("Adding citations to study parameters")

# Create a new study parameters object with citations
updated_params = study_parameters.model_copy(deep=True)

# Add citations nested within each parameter
for field_name in ["participant_info", "study_design", "study_results"]:
param_obj = getattr(updated_params, field_name)
if hasattr(param_obj, 'items'):
for item in param_obj.items:
item.citations = self._get_top_citations_for_parameter_item(
item.content, field_name
)

updated_params.summary.citations = self._get_top_citations_for_parameter(
study_parameters.summary.content, "summary"
)
Expand All @@ -632,20 +686,6 @@ def add_citations_to_study_parameters(self, study_parameters):
study_parameters.study_type.content, "study_type"
)

updated_params.participant_info.citations = (
self._get_top_citations_for_parameter(
study_parameters.participant_info.content, "participant_info"
)
)

updated_params.study_design.citations = self._get_top_citations_for_parameter(
study_parameters.study_design.content, "study_design"
)

updated_params.study_results.citations = self._get_top_citations_for_parameter(
study_parameters.study_results.content, "study_results"
)

updated_params.allele_frequency.citations = (
self._get_top_citations_for_parameter(
study_parameters.allele_frequency.content, "allele_frequency"
Expand Down
45 changes: 45 additions & 0 deletions src/citations/one_shot_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,51 @@ def get_study_parameter_citations(
logger.error(f"Error getting citations for parameter: {e}")
return []

def get_study_parameter_item_citations(
    self,
    parameter_type: str,
    item_content: str,
    model: str = "openai/gpt-4.1",
    top_k: int = 2,
) -> List[str]:
    """
    Get citations for a single study parameter item using the whole article text.

    Builds a prompt containing the item and ``self.article_text``, sends it to
    the language model through ``completion``, and parses the reply with
    ``self._parse_citation_list``.

    Args:
        parameter_type: The type of parameter (participant_info, study_design, etc.)
        item_content: The content of the specific item to find citations for
        model: The language model to use for citation generation
        top_k: How many sentences to request and return (default 2)

    Returns:
        List of at most ``top_k`` sentences for this specific item. An empty
        list is returned when the item is blank, when ``top_k`` is not
        positive, or when the model call or parsing fails.
    """
    # Nothing to cite for a blank item (or a request for zero sentences):
    # skip the model call entirely instead of spending a request on it.
    if top_k <= 0 or not item_content or not item_content.strip():
        return []

    # The prompt lines stay flush-left so no indentation leaks into the text
    # sent to the model.
    prompt = f"""
Parameter Type: {parameter_type}
Specific Item Content: {item_content}

From the following article text, find the top {top_k} sentences from the article that are most relevant to and support this specific item content.
Article text:
"{self.article_text}"

If a table provides the support warranting of being in the top {top_k}, return the table header (## Table X: ..., etc.) as your sentence.
Output the exact sentences from the article text in a numbered list with each sentence on a new line. No other text.
Keep in mind that headings are text/numbers preceded by hash symbols (#) and should not be included in citations unless referencing the table. Only include content sentences.
"""

    try:
        completion_kwargs = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.1,
        }

        response = completion(**completion_kwargs)
        response_text = response.choices[0].message.content.strip()

        citations = self._parse_citation_list(response_text)
        logger.info(f"Found {len(citations)} citations for {parameter_type} item")
        # The model may return more lines than requested; cap the result
        return citations[:top_k]

    except Exception as e:
        # Deliberate best-effort: a failed lookup for one item is logged and
        # reported as "no citations" rather than raised to the caller.
        logger.error(f"Error getting citations for parameter item: {e}")
        return []

def _parse_citation_list(self, response_text: str) -> List[str]:
"""
Parse the citation list from the model response.
Expand Down
85 changes: 65 additions & 20 deletions src/study_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ class ParameterWithCitations(BaseModel):
citations: Optional[List[str]] = None


class ParameterItemWithCitations(BaseModel):
    """Model for a single parameter item with its own citations.

    One instance holds the text of one item together with the list of
    citation strings attached to that item.
    """

    # Text of this single item (required)
    content: str
    # Citation strings for this item; defaults to None when none are supplied
    citations: Optional[List[str]] = None


class ParameterWithItemCitations(BaseModel):
    """Model for a parameter containing multiple items, each with their own citations.

    Citations are stored on the individual entries of ``items``; this model
    declares no parameter-level ``content`` or ``citations`` field.
    """

    # The parameter's items (required); each carries its own content and citations
    items: List[ParameterItemWithCitations]


def parse_bullets_to_list(text: str) -> List[str]:
"""Parse bulleted text into a list of strings."""
if not text or not text.strip():
Expand Down Expand Up @@ -43,9 +56,9 @@ def parse_bullets_to_list(text: str) -> List[str]:
class StudyParameters(BaseModel):
summary: ParameterWithCitations
study_type: ParameterWithCitations
participant_info: ParameterWithCitations
study_design: ParameterWithCitations
study_results: ParameterWithCitations
participant_info: ParameterWithItemCitations
study_design: ParameterWithItemCitations
study_results: ParameterWithItemCitations
allele_frequency: ParameterWithCitations
additional_resource_links: List[str]

Expand Down Expand Up @@ -147,14 +160,25 @@ def generate_all_parameters(self) -> StudyParameters:
"""Generate all study parameters using separate questions."""
logger.info(f"Extracting study parameters for {self.pmcid}")

participant_items = [
ParameterItemWithCitations(content=item)
for item in self.get_participant_info()
]
study_design_items = [
ParameterItemWithCitations(content=item)
for item in self.get_study_design()
]
study_results_items = [
ParameterItemWithCitations(content=item)
for item in self.get_study_results()
]

return StudyParameters(
summary=ParameterWithCitations(content=self.get_summary()),
study_type=ParameterWithCitations(content=self.get_study_type()),
participant_info=ParameterWithCitations(
content=self.get_participant_info()
),
study_design=ParameterWithCitations(content=self.get_study_design()),
study_results=ParameterWithCitations(content=self.get_study_results()),
participant_info=ParameterWithItemCitations(items=participant_items),
study_design=ParameterWithItemCitations(items=study_design_items),
study_results=ParameterWithItemCitations(items=study_results_items),
allele_frequency=ParameterWithCitations(
content=self.get_allele_frequency()
),
Expand Down Expand Up @@ -190,25 +214,46 @@ def test_study_parameters():
print(f" {study_parameters.study_type.content}")

print(f"\n👥 PARTICIPANT INFO:")
if isinstance(study_parameters.participant_info.content, list):
for i, item in enumerate(study_parameters.participant_info.content, 1):
print(f" • {item}")
if hasattr(study_parameters.participant_info, 'items'):
for i, item in enumerate(study_parameters.participant_info.items, 1):
print(f" • {item.content}")
if item.citations:
for j, citation in enumerate(item.citations, 1):
print(f" Citation {j}: {citation[:100]}...")
else:
print(f" {study_parameters.participant_info.content}")
if isinstance(study_parameters.participant_info.content, list):
for i, item in enumerate(study_parameters.participant_info.content, 1):
print(f" • {item}")
else:
print(f" {study_parameters.participant_info.content}")

print(f"\n🔬 STUDY DESIGN:")
if isinstance(study_parameters.study_design.content, list):
for i, item in enumerate(study_parameters.study_design.content, 1):
print(f" • {item}")
if hasattr(study_parameters.study_design, 'items'):
for i, item in enumerate(study_parameters.study_design.items, 1):
print(f" • {item.content}")
if item.citations:
for j, citation in enumerate(item.citations, 1):
print(f" Citation {j}: {citation[:100]}...")
else:
print(f" {study_parameters.study_design.content}")
if isinstance(study_parameters.study_design.content, list):
for i, item in enumerate(study_parameters.study_design.content, 1):
print(f" • {item}")
else:
print(f" {study_parameters.study_design.content}")

print(f"\n📊 STUDY RESULTS:")
if isinstance(study_parameters.study_results.content, list):
for i, item in enumerate(study_parameters.study_results.content, 1):
print(f" • {item}")
if hasattr(study_parameters.study_results, 'items'):
for i, item in enumerate(study_parameters.study_results.items, 1):
print(f" • {item.content}")
if item.citations:
for j, citation in enumerate(item.citations, 1):
print(f" Citation {j}: {citation[:100]}...")
else:
print(f" {study_parameters.study_results.content}")
if isinstance(study_parameters.study_results.content, list):
for i, item in enumerate(study_parameters.study_results.content, 1):
print(f" • {item}")
else:
print(f" {study_parameters.study_results.content}")

print(f"\n🧬 ALLELE FREQUENCY:")
if isinstance(study_parameters.allele_frequency.content, list):
Expand Down
Loading