44import openai
55from logging import INFO , Logger
66
7+ from typing import List
8+
79from langchain_anthropic import ChatAnthropic
810from langchain_community .chat_models .deepinfra import ChatDeepInfraException
911from langchain_core .messages import HumanMessage
1719
1820
1921async def invoke_model (prompt , model , semaphore ):
22+ """
23+ Asynchronously invoke a language model with retry logic.
24+
25+ Args:
26+ prompt (str): The input prompt for the model.
27+ model: The language model to invoke.
28+ semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
29+
30+ Returns:
31+ str: The model's response content.
32+
33+ Raises:
34+ ChatDeepInfraException: If all retry attempts fail.
35+ """
2036 async with semaphore :
2137 max_retries = 100
2238 delay = 3
@@ -33,7 +49,30 @@ async def invoke_model(prompt, model, semaphore):
3349
3450
3551class APILLM :
52+ """
53+ A class to interface with various language models through their respective APIs.
54+
55+ This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
56+ It handles API key management, model initialization, and provides methods for
57+ both synchronous and asynchronous inference.
58+
59+ Attributes:
60+ model: The initialized language model instance.
61+
62+ Methods:
63+ get_response: Synchronously get responses for a list of prompts.
64+ _get_response: Asynchronously get responses for a list of prompts.
65+ """
3666 def __init__ (self , model_id : str ):
67+ """
68+ Initialize the APILLM with a specific model.
69+
70+ Args:
71+ model_id (str): Identifier for the model to use.
72+
73+ Raises:
74+ ValueError: If an unknown model identifier is provided.
75+ """
3776 if "claude" in model_id :
3877 ANTHROPIC_API_KEY = open ("anthropictoken.txt" , "r" ).read ()
3978 self .model = ChatAnthropic (model = model_id , api_key = ANTHROPIC_API_KEY )
@@ -46,7 +85,21 @@ def __init__(self, model_id: str):
4685 else :
4786 raise ValueError (f"Unknown model: { model_id } " )
4887
49- def get_response (self , prompts : list [str ]) -> list [str ]:
88+ def get_response (self , prompts : List [str ]) -> List [str ]:
89+ """
90+ Synchronously get responses for a list of prompts.
91+
92+ This method includes retry logic for handling connection errors and rate limits.
93+
94+ Args:
95+ prompts (list[str]): List of input prompts.
96+
97+ Returns:
98+ list[str]: List of model responses.
99+
100+ Raises:
101+ requests.exceptions.ConnectionError: If max retries are exceeded.
102+ """
50103 max_retries = 100
51104 delay = 3
52105 attempts = 0
@@ -74,6 +127,18 @@ def get_response(self, prompts: list[str]) -> list[str]:
74127 async def _get_response (
75128 self , prompts : list [str ], max_concurrent_calls = 200
76129 ) -> list [str ]: # TODO change name of method
130+ """
131+ Asynchronously get responses for a list of prompts.
132+
133+ This method uses a semaphore to limit the number of concurrent API calls.
134+
135+ Args:
136+ prompts (list[str]): List of input prompts.
137+ max_concurrent_calls (int): Maximum number of concurrent API calls allowed.
138+
139+ Returns:
140+ list[str]: List of model responses.
141+ """
77142 semaphore = asyncio .Semaphore (max_concurrent_calls ) # Limit the number of concurrent calls
78143 tasks = []
79144