From adc15524eef14ac3b4ed125bb0ef3b73ab25de1e Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 16 Aug 2025 05:18:39 +0000 Subject: [PATCH] feat: implement item-level citations for participant_info, study_design, and study_results - Add ParameterItemWithCitations and ParameterWithItemCitations Pydantic models - Update StudyParameters to use item-level citations for target sections - Add get_study_parameter_item_citations method to OneShotCitations - Add _get_top_citations_for_parameter_item method to line_citation_generator - Update annotation pipeline to handle item-level citation generation - Maintain backward compatibility for other sections with section-level citations Co-Authored-By: Shlok Natarajan --- src/annotation_pipeline.py | 26 +++++--- src/citations/line_citation_generator.py | 74 ++++++++++++++++----- src/citations/one_shot_citations.py | 45 +++++++++++++ src/study_parameters.py | 85 ++++++++++++++++++------ 4 files changed, 183 insertions(+), 47 deletions(-) diff --git a/src/annotation_pipeline.py b/src/annotation_pipeline.py index 9543f61..befda54 100644 --- a/src/annotation_pipeline.py +++ b/src/annotation_pipeline.py @@ -51,17 +51,23 @@ def add_citations(self): f"Adding Citations to Study Parameters using OneShotCitations with model {self.citation_model}" ) for field_name in self.study_parameters.__class__.model_fields: - if ( - field_name != "additional_resource_links" - ): # Skip non-ParameterWithCitations field + if field_name != "additional_resource_links": param_content = getattr(self.study_parameters, field_name) - if hasattr(param_content, "content"): - citations = ( - self.one_shot_citations.get_study_parameter_citations( - field_name, - param_content.content, - model=self.citation_model, - ) + + if field_name in ["participant_info", "study_design", "study_results"]: + if hasattr(param_content, "items"): + for item in param_content.items: + citations = self.one_shot_citations.get_study_parameter_item_citations( + field_name, + item.content, + model=self.citation_model, + ) + item.citations = citations + elif hasattr(param_content, "content"): + citations = self.one_shot_citations.get_study_parameter_citations( + field_name, + param_content.content, + model=self.citation_model, ) param_content.citations = citations else: diff --git a/src/citations/line_citation_generator.py b/src/citations/line_citation_generator.py index 2926f60..03fc154 100644 --- a/src/citations/line_citation_generator.py +++ b/src/citations/line_citation_generator.py @@ -607,10 +607,58 @@ def _get_top_citations_for_parameter( return top_sentences + def _get_top_citations_for_parameter_item( + self, item_content: str, parameter_type: str, top_k: int = 2 + ) -> List[str]: + """ + Find the top K most relevant sentences for a specific study parameter item. + + Args: + item_content: The content of the specific item to find citations for + parameter_type: The type of parameter (participant_info, study_design, etc.) + top_k: Number of top sentences to return + + Returns: + List of top relevant sentences + """ + candidate_sentences = self.sentences + + logger.info( + f"Scoring all {len(candidate_sentences)} sentences for {parameter_type} item" + ) + + sentence_scores = [] + for sentence in candidate_sentences: + score = self._score_sentence_for_study_param( + sentence, item_content, parameter_type + ) + sentence_scores.append((sentence, score)) + + sentence_scores.sort(key=lambda x: x[1], reverse=True) + candidate_sentences = [item[0] for item in sentence_scores[: top_k * 3]] + filtered_sentences = self._remove_duplicates(candidate_sentences) + + top_sentences = filtered_sentences[:top_k] + + if len(top_sentences) < top_k: + remaining_needed = top_k - len(top_sentences) + for sentence, score in sentence_scores[top_k * 3 :]: + is_duplicate = any( + self._is_duplicate_citation(sentence, existing) + for existing in top_sentences + ) + if not is_duplicate: + top_sentences.append(sentence) + remaining_needed -= 1 + if remaining_needed == 0: + break + + return top_sentences + def add_citations_to_study_parameters(self, study_parameters): """ Add citations to study parameters by finding relevant sentences for each parameter. - Modifies the parameters to include citations nested within each parameter key. + Now handles item-level citations for participant_info, study_design, and study_results. Args: study_parameters: StudyParameters object @@ -620,10 +668,16 @@ def add_citations_to_study_parameters(self, study_parameters): """ logger.info("Adding citations to study parameters") - # Create a new study parameters object with citations updated_params = study_parameters.model_copy(deep=True) - # Add citations nested within each parameter + for field_name in ["participant_info", "study_design", "study_results"]: + param_obj = getattr(updated_params, field_name) + if hasattr(param_obj, 'items'): + for item in param_obj.items: + item.citations = self._get_top_citations_for_parameter_item( + item.content, field_name + ) + updated_params.summary.citations = self._get_top_citations_for_parameter( study_parameters.summary.content, "summary" ) @@ -632,20 +686,6 @@ def add_citations_to_study_parameters(self, study_parameters): study_parameters.study_type.content, "study_type" ) - updated_params.participant_info.citations = ( - self._get_top_citations_for_parameter( - study_parameters.participant_info.content, "participant_info" - ) - ) - - updated_params.study_design.citations = self._get_top_citations_for_parameter( - study_parameters.study_design.content, "study_design" - ) - - updated_params.study_results.citations = self._get_top_citations_for_parameter( - study_parameters.study_results.content, "study_results" - ) - updated_params.allele_frequency.citations = ( self._get_top_citations_for_parameter( study_parameters.allele_frequency.content, "allele_frequency" diff --git a/src/citations/one_shot_citations.py b/src/citations/one_shot_citations.py index d8ae1eb..06cb6bb 100644 --- a/src/citations/one_shot_citations.py +++ b/src/citations/one_shot_citations.py @@ -186,6 +186,51 @@ def get_study_parameter_citations( logger.error(f"Error getting citations for parameter: {e}") return [] + def get_study_parameter_item_citations( + self, parameter_type: str, item_content: str, model: str = "openai/gpt-4.1" + ) -> List[str]: + """ + Get citations for a single study parameter item using the whole article text. + + Args: + parameter_type: The type of parameter (participant_info, study_design, etc.) + item_content: The content of the specific item to find citations for + model: The language model to use for citation generation + + Returns: + List of top 2 most relevant sentences for this specific item + """ + prompt = f""" +Parameter Type: {parameter_type} +Specific Item Content: {item_content} + +From the following article text, find the top 2 sentences from the article that are most relevant to and support this specific item content. +Article text: +"{self.article_text}" + +If a table provides the support warranting of being in the top 2, return the table header (## Table X: ..., etc.) as your sentence. +Output the exact sentences from the article text in a numbered list with each sentence on a new line. No other text. +Keep in mind that headings are text/numbers preceded by hash symbols (#) and should not be included in citations unless referencing the table. Only include content sentences. +""" + + try: + completion_kwargs = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.1, + } + + response = completion(**completion_kwargs) + response_text = response.choices[0].message.content.strip() + + citations = self._parse_citation_list(response_text) + logger.info(f"Found {len(citations)} citations for {parameter_type} item") + return citations[:2] # Return top 2 for individual items + + except Exception as e: + logger.error(f"Error getting citations for parameter item: {e}") + return [] + def _parse_citation_list(self, response_text: str) -> List[str]: """ Parse the citation list from the model response. diff --git a/src/study_parameters.py b/src/study_parameters.py index 7733707..10de3da 100644 --- a/src/study_parameters.py +++ b/src/study_parameters.py @@ -14,6 +14,19 @@ class ParameterWithCitations(BaseModel): citations: Optional[List[str]] = None +class ParameterItemWithCitations(BaseModel): + """Model for a single parameter item with its own citations""" + + content: str + citations: Optional[List[str]] = None + + +class ParameterWithItemCitations(BaseModel): + """Model for a parameter containing multiple items, each with their own citations""" + + items: List[ParameterItemWithCitations] + + def parse_bullets_to_list(text: str) -> List[str]: """Parse bulleted text into a list of strings.""" if not text or not text.strip(): @@ -43,9 +56,9 @@ def parse_bullets_to_list(text: str) -> List[str]: class StudyParameters(BaseModel): summary: ParameterWithCitations study_type: ParameterWithCitations - participant_info: ParameterWithCitations - study_design: ParameterWithCitations - study_results: ParameterWithCitations + participant_info: ParameterWithItemCitations + study_design: ParameterWithItemCitations + study_results: ParameterWithItemCitations allele_frequency: ParameterWithCitations additional_resource_links: List[str] @@ -147,14 +160,25 @@ def generate_all_parameters(self) -> StudyParameters: """Generate all study parameters using separate questions.""" logger.info(f"Extracting study parameters for {self.pmcid}") + participant_items = [ + ParameterItemWithCitations(content=item) + for item in self.get_participant_info() + ] + study_design_items = [ + ParameterItemWithCitations(content=item) + for item in self.get_study_design() + ] + study_results_items = [ + ParameterItemWithCitations(content=item) + for item in self.get_study_results() + ] + return StudyParameters( summary=ParameterWithCitations(content=self.get_summary()), study_type=ParameterWithCitations(content=self.get_study_type()), - participant_info=ParameterWithCitations( - content=self.get_participant_info() - ), - study_design=ParameterWithCitations(content=self.get_study_design()), - study_results=ParameterWithCitations(content=self.get_study_results()), + participant_info=ParameterWithItemCitations(items=participant_items), + study_design=ParameterWithItemCitations(items=study_design_items), + study_results=ParameterWithItemCitations(items=study_results_items), allele_frequency=ParameterWithCitations( content=self.get_allele_frequency() ), @@ -190,25 +214,46 @@ def test_study_parameters(): print(f" {study_parameters.study_type.content}") print(f"\n👥 PARTICIPANT INFO:") - if isinstance(study_parameters.participant_info.content, list): - for i, item in enumerate(study_parameters.participant_info.content, 1): - print(f" • {item}") + if hasattr(study_parameters.participant_info, 'items'): + for i, item in enumerate(study_parameters.participant_info.items, 1): + print(f" • {item.content}") + if item.citations: + for j, citation in enumerate(item.citations, 1): + print(f" Citation {j}: {citation[:100]}...") else: - print(f" {study_parameters.participant_info.content}") + if isinstance(study_parameters.participant_info.content, list): + for i, item in enumerate(study_parameters.participant_info.content, 1): + print(f" • {item}") + else: + print(f" {study_parameters.participant_info.content}") print(f"\n🔬 STUDY DESIGN:") - if isinstance(study_parameters.study_design.content, list): - for i, item in enumerate(study_parameters.study_design.content, 1): - print(f" • {item}") + if hasattr(study_parameters.study_design, 'items'): + for i, item in enumerate(study_parameters.study_design.items, 1): + print(f" • {item.content}") + if item.citations: + for j, citation in enumerate(item.citations, 1): + print(f" Citation {j}: {citation[:100]}...") else: - print(f" {study_parameters.study_design.content}") + if isinstance(study_parameters.study_design.content, list): + for i, item in enumerate(study_parameters.study_design.content, 1): + print(f" • {item}") + else: + print(f" {study_parameters.study_design.content}") print(f"\n📊 STUDY RESULTS:") - if isinstance(study_parameters.study_results.content, list): - for i, item in enumerate(study_parameters.study_results.content, 1): - print(f" • {item}") + if hasattr(study_parameters.study_results, 'items'): + for i, item in enumerate(study_parameters.study_results.items, 1): + print(f" • {item.content}") + if item.citations: + for j, citation in enumerate(item.citations, 1): + print(f" Citation {j}: {citation[:100]}...") else: - print(f" {study_parameters.study_results.content}") + if isinstance(study_parameters.study_results.content, list): + for i, item in enumerate(study_parameters.study_results.content, 1): + print(f" • {item}") + else: + print(f" {study_parameters.study_results.content}") print(f"\n🧬 ALLELE FREQUENCY:") if isinstance(study_parameters.allele_frequency.content, list):