diff --git a/pageindex/config.yaml b/pageindex/config.yaml
index fd73e3a..e2771a9 100644
--- a/pageindex/config.yaml
+++ b/pageindex/config.yaml
@@ -1,3 +1,4 @@
+provider: "openai"  # "openai" or "gemini"
 model: "gpt-4o-2024-11-20"
 toc_check_page_num: 20
 max_page_num_each_node: 10
diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 882fb5d..82d560b 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -155,7 +155,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None):
     prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
     response = ChatGPT_API(model=model, prompt=prompt)
     json_content = extract_json(response)
-    return json_content['completed']
+    return json_content.get('completed', 'no')
 
 def extract_toc_content(content, model=None):
     prompt = f"""
@@ -289,7 +289,13 @@ def toc_transformer(toc_content, model=None):
     Directly return the final JSON structure, do not output anything else.
     """
     prompt = init_prompt + '\n Given table of contents\n:' + toc_content
-    last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+
+    response_schema = None
+    if GOOGLE_GENAI_AVAILABLE and LLM_PROVIDER == "gemini":
+        from pageindex.utils import TocStructure
+        response_schema = TocStructure
+
+    last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, response_schema=response_schema)
     if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
     if if_complete == "yes" and finish_reason == "finished":
         last_complete = extract_json(last_complete)
@@ -313,7 +319,7 @@ def toc_transformer(toc_content, model=None):
 
         Please continue the json structure, directly output the remaining part of the json structure."""
 
-        new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+        new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, response_schema=response_schema)
 
         if new_complete.startswith('```json'):
             new_complete = get_json_content(new_complete)
@@ -496,7 +502,7 @@ def remove_first_physical_index_section(text):
     return text
 
 ### add verify completeness
-def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
+def generate_toc_continue(toc_content, part, model=None):
     print('start generate_toc_continue')
     prompt = """
     You are an expert in extracting hierarchical tree structure.
@@ -729,7 +735,7 @@ def check_toc(page_list, opt=None):
 
 ################### fix incorrect toc #########################################################
-def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"):
+def single_toc_item_index_fixer(section_title, content, model=None):
     tob_extractor_prompt = """
     You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document.
@@ -1066,7 +1072,7 @@ def page_index_main(doc, opt=None):
         raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
 
     print('Parsing PDF...')
-    page_list = get_page_tokens(doc)
+    page_list = get_page_tokens(doc, model=opt.model)
 
     logger.info({'total_page_number': len(page_list)})
     logger.info({'total_token': sum([page[1] for page in page_list])})
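When the Gemini provider is active, toc_transformer passes the TocStructure schema (defined in pageindex/utils.py below) so the model returns JSON in a fixed shape. Illustratively, the constrained output looks roughly like this (the values are made up):

{
  "table_of_contents": [
    {"structure": "1", "title": "Introduction", "page": "1"},
    {"structure": "1.1", "title": "Background", "page": "2"},
    {"structure": null, "title": "Appendix A", "page": null}
  ]
}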
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd8..79ad4b9 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -18,91 +18,287 @@ from types import SimpleNamespace as config
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai")
+
+# pydantic ships with the openai SDK, so the schema models below stay importable
+# even when google-genai is not installed.
+from pydantic import BaseModel, Field
+
+try:
+    from google import genai as google_genai
+    GOOGLE_GENAI_AVAILABLE = True
+except ImportError:
+    GOOGLE_GENAI_AVAILABLE = False
+
+class TocItem(BaseModel):
+    structure: str | None = Field(description="Structure index like '1', '1.1', '1.2' or None")
+    title: str = Field(description="Title of the section")
+    page: str | None = Field(description="Page number or None")
+
+class TocStructure(BaseModel):
+    table_of_contents: list[TocItem] = Field(description="List of table of contents items")
+
+class LLMProvider:
+    """Abstraction layer for different LLM providers"""
+
+    def __init__(self, provider_name=None, model=None, api_key=None):
+        self.provider = provider_name or LLM_PROVIDER
+        self.model = model or self._get_default_model()
+        self.api_key = api_key or self._get_api_key()
+        self.client = self._initialize_client()
+
+    def _get_default_model(self):
+        if self.provider == "openai":
+            return "gpt-4o-2024-11-20"
+        elif self.provider == "gemini":
+            return "gemini-2.5-flash-lite"
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def _get_api_key(self):
+        if self.provider == "openai":
+            return CHATGPT_API_KEY
+        elif self.provider == "gemini":
+            return GEMINI_API_KEY
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def _initialize_client(self):
+        if self.provider == "openai":
+            return openai.OpenAI(api_key=self.api_key)
+        elif self.provider == "gemini":
+            if not GOOGLE_GENAI_AVAILABLE:
+                raise ImportError("google-genai not installed. Install with: pip install google-genai")
+            return google_genai.Client(api_key=self.api_key)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def generate_content(self, contents, config=None, response_schema=None, **kwargs):
+        """Unified interface for generating content across providers"""
+        if self.provider == "openai":
+            return self._openai_generate_content(contents, config, **kwargs)
+        elif self.provider == "gemini":
+            return self._gemini_generate_content(contents, config, response_schema, **kwargs)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def count_tokens(self, contents):
+        """Unified interface for counting tokens across providers"""
+        if self.provider == "openai":
+            return self._openai_count_tokens(contents)
+        elif self.provider == "gemini":
+            return self._gemini_count_tokens(contents)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def _openai_generate_content(self, contents, config=None, **kwargs):
+        """OpenAI content generation"""
+        messages = [{"role": "user", "content": contents}]
+        request_config = {
+            "model": self.model,
+            "messages": messages,
+            "temperature": 0,
+        }
+
+        if config:
+            if hasattr(config, 'temperature') and config.temperature is not None:
+                request_config["temperature"] = config.temperature
+            if hasattr(config, 'max_output_tokens') and config.max_output_tokens is not None:
+                request_config["max_tokens"] = config.max_output_tokens
+            if hasattr(config, 'top_p') and config.top_p is not None:
+                request_config["top_p"] = config.top_p
+            if hasattr(config, 'stop_sequences') and config.stop_sequences:
+                request_config["stop"] = config.stop_sequences
+
+        response = self.client.chat.completions.create(**request_config)
+        return type('Response', (), {
+            'text': response.choices[0].message.content,
+            'finish_reason': response.choices[0].finish_reason,
+            'usage_metadata': type('Usage', (), {
+                'total_token_count': response.usage.total_tokens if response.usage else None
+            })()
+        })()
+
+    def _gemini_generate_content(self, contents, config=None, response_schema=None, **kwargs):
+        """Gemini content generation"""
+        request_config = {
+            "model": self.model,
+            "contents": contents
+        }
+
+        # Handle config for structured output
+        if config:
+            if response_schema is not None:
+                # Copy the caller's config and attach the response schema
+                config_dict = dict(config) if isinstance(config, dict) else {}
+                config_dict["response_schema"] = response_schema
+                request_config["config"] = config_dict
+            else:
+                request_config["config"] = config
+        elif response_schema is not None:
+            # Only response_schema provided
+            request_config["config"] = {"response_schema": response_schema}
+
+        response = self.client.models.generate_content(**request_config)
+        return response
+
+    def _openai_count_tokens(self, contents):
+        """OpenAI token counting using tiktoken"""
+        try:
+            import tiktoken
+            enc = tiktoken.encoding_for_model(self.model or "gpt-4o-2024-11-20")
+            return len(enc.encode(contents))
+        except Exception:
+            return len(contents) // 4
+
+    def _gemini_count_tokens(self, contents):
+        """Gemini token counting using official API"""
+        try:
+            response = self.client.models.count_tokens(
+                model=self.model,
+                contents=contents
+            )
+            return response.total_tokens
+        except Exception:
+            return len(contents) // 4
 
 def count_tokens(text, model=None):
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
-    tokens = enc.encode(text)
-    return len(tokens)
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+    if model is None:
+        model = "gpt-4o-2024-11-20" if LLM_PROVIDER == "openai" else "gemini-2.5-flash-lite"
+
+    try:
+        provider = LLMProvider(provider_name=LLM_PROVIDER, model=model)
+        return provider.count_tokens(text)
+    except Exception:
+        return len(text) // 4
+
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None, provider=None, response_schema=None):
+    provider = provider or LLM_PROVIDER
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+
     for i in range(max_retries):
         try:
-            if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
-            else:
-                messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-            )
-            if response.choices[0].finish_reason == "length":
-                return response.choices[0].message.content, "max_output_reached"
-            else:
-                return response.choices[0].message.content, "finished"
+            llm_provider = LLMProvider(provider_name=provider, model=model, api_key=api_key)
+
+            if provider == "openai":
+                contents = prompt
+                if chat_history:
+                    contents = ""
+                    for msg in chat_history:
+                        if msg["role"] == "system":
+                            contents += f"System: {msg['content']}\n"
+                        elif msg["role"] == "user":
+                            contents += f"User: {msg['content']}\n"
+                        elif msg["role"] == "assistant":
+                            contents += f"Assistant: {msg['content']}\n"
+                    contents += f"User: {prompt}\n"
+
+                response = llm_provider.generate_content(contents)
+                if getattr(response, 'finish_reason', None) == "length":
+                    return response.text, "max_output_reached"
+                return response.text, "finished"
+
+            elif provider == "gemini":
+                contents = prompt
+                if chat_history:
+                    contents = ""
+                    for msg in chat_history:
+                        contents += f"{msg['role']}: {msg['content']}\n"
+                    contents += f"user: {prompt}\n"
+
+                config = None
+                if response_schema is not None:
+                    config = {
+                        "response_mime_type": "application/json",
+                        "response_schema": response_schema
+                    }
+
+                response = llm_provider.generate_content(contents, config=config, response_schema=response_schema)
+                return response.text, "finished"
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=None, chat_history=None, provider=None):
+    provider = provider or LLM_PROVIDER
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+
     for i in range(max_retries):
         try:
-            if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
-            else:
-                messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-            )
-
-            return response.choices[0].message.content
+            llm_provider = LLMProvider(provider_name=provider, model=model, api_key=api_key)
+
+            if provider == "openai":
+                contents = prompt
+                if chat_history:
+                    contents = ""
+                    for msg in chat_history:
+                        if msg["role"] == "system":
+                            contents += f"System: {msg['content']}\n"
+                        elif msg["role"] == "user":
+                            contents += f"User: {msg['content']}\n"
+                        elif msg["role"] == "assistant":
+                            contents += f"Assistant: {msg['content']}\n"
+                    contents += f"User: {prompt}\n"
+
+                response = llm_provider.generate_content(contents)
+                return response.text
+
+            elif provider == "gemini":
+                contents = prompt
+                if chat_history:
+                    contents = ""
+                    for msg in chat_history:
+                        contents += f"{msg['role']}: {msg['content']}\n"
+                    contents += f"user: {prompt}\n"
+
+                response = llm_provider.generate_content(contents)
+                return response.text
+
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=None, provider=None):
+    provider = provider or LLM_PROVIDER
     max_retries = 10
-    messages = [{"role": "user", "content": prompt}]
+
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-            return response.choices[0].message.content
+            if provider == "openai":
+                api_key = api_key or CHATGPT_API_KEY
+                messages = [{"role": "user", "content": prompt}]
+                async with openai.AsyncOpenAI(api_key=api_key) as client:
+                    response = await client.chat.completions.create(
+                        model=model,
+                        messages=messages,
+                        temperature=0,
+                    )
+                return response.choices[0].message.content
+
+            elif provider == "gemini":
+                llm_provider = LLMProvider(provider_name=provider, model=model, api_key=api_key)
+                response = llm_provider.generate_content(prompt)
+                return response.text
+
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
            if i < max_retries - 1:
-                await asyncio.sleep(1)  # Wait for 1s before retrying
+                await asyncio.sleep(1)
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
                return "Error"
@@ -410,15 +606,14 @@ def add_preface_if_needed(data):
 
 
-def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
-    enc = tiktoken.encoding_for_model(model)
+def get_page_tokens(pdf_path, model=None, pdf_parser="PyPDF2"):
     if pdf_parser == "PyPDF2":
         pdf_reader = PyPDF2.PdfReader(pdf_path)
         page_list = []
         for page_num in range(len(pdf_reader.pages)):
             page = pdf_reader.pages[page_num]
             page_text = page.extract_text()
-            token_length = len(enc.encode(page_text))
+            token_length = count_tokens(page_text, model)
             page_list.append((page_text, token_length))
         return page_list
     elif pdf_parser == "PyMuPDF":
@@ -430,7 +625,7 @@ def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
         page_list = []
         for page in doc:
             page_text = page.get_text()
-            token_length = len(enc.encode(page_text))
+            token_length = count_tokens(page_text, model)
             page_list.append((page_text, token_length))
         return page_list
     else:
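ChatGPT_API, ChatGPT_API_with_finish_reason, and the module-level count_tokens all route through the LLMProvider class above. A minimal usage sketch of that class on its own, assuming google-genai is installed and the CHATGPT_API_KEY and GEMINI_API_KEY environment variables are set (the prompt text here is arbitrary):

from pageindex.utils import LLMProvider, TocStructure

# Gemini with structured output: the schema is merged into the request config
# and sent through client.models.generate_content.
gemini = LLMProvider(provider_name="gemini", model="gemini-2.5-flash-lite")
response = gemini.generate_content(
    "Extract the table of contents of this text as JSON: ...",
    config={"response_mime_type": "application/json"},
    response_schema=TocStructure,
)
print(response.text)                  # JSON string shaped like TocStructure
print(gemini.count_tokens("hello"))   # uses models.count_tokens under the hood

# The OpenAI path exposes the same surface; .text mirrors the Gemini response object.
oai = LLMProvider(provider_name="openai", model="gpt-4o-2024-11-20")
print(oai.generate_content("Say hello").text)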
diff --git a/run_pageindex.py b/run_pageindex.py
index 1070245..13273da 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -10,7 +10,8 @@
     parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
     parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
-    parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use')
+    parser.add_argument('--provider', type=str, default='openai', choices=['openai', 'gemini'], help='LLM provider to use')
+    parser.add_argument('--model', type=str, default=None, help='Model to use (defaults based on provider)')
     parser.add_argument('--toc-check-pages', type=int, default=20,
                         help='Number of pages to check for table of contents (PDF only)')
@@ -51,8 +52,12 @@
             raise ValueError(f"PDF file not found: {args.pdf_path}")
 
         # Process PDF file
+        if args.model is None:
+            args.model = "gpt-4o-2024-11-20" if args.provider == "openai" else "gemini-2.5-flash-lite"
+
         # Configure options
         opt = config(
+            provider=args.provider,
             model=args.model,
             toc_check_page_num=args.toc_check_pages,
             max_page_num_each_node=args.max_pages_per_node,
@@ -95,8 +100,12 @@
         from pageindex.utils import ConfigLoader
         config_loader = ConfigLoader()
 
+        if args.model is None:
+            args.model = "gpt-4o-2024-11-20" if args.provider == "openai" else "gemini-2.5-flash-lite"
+
         # Create options dict with user args
         user_opt = {
+            'provider': args.provider,
             'model': args.model,
             'if_add_node_summary': args.if_add_node_summary,
             'if_add_doc_description': args.if_add_doc_description,