diff --git a/predictionguard/src/audio.py b/predictionguard/src/audio.py index a7348c3..dce5c46 100644 --- a/predictionguard/src/audio.py +++ b/predictionguard/src/audio.py @@ -42,25 +42,31 @@ def __init__(self, api_key, url): def create( self, model: str, - file: str + file: str, + language: Optional[str] = "auto", + temperature: Optional[float] = 0.0, + prompt: Optional[str] = "", ) -> Dict[str, Any]: """ Creates a audio transcription request to the Prediction Guard /audio/transcriptions API :param model: The model to use :param file: Audio file to be transcribed + :param language: The language of the audio file + :param temperature: The temperature parameter for model transcription + :param prompt: A prompt to assist in transcription styling :result: A dictionary containing the transcribed text. """ # Create a list of tuples, each containing all the parameters for # a call to _transcribe_audio - args = (model, file) + args = (model, file, language, temperature, prompt) # Run _transcribe_audio choices = self._transcribe_audio(*args) return choices - def _transcribe_audio(self, model, file): + def _transcribe_audio(self, model, file, language, temperature, prompt): """ Function to transcribe an audio file. """ @@ -72,7 +78,12 @@ def _transcribe_audio(self, model, file): with open(file, "rb") as audio_file: files = {"file": (file, audio_file, "audio/wav")} - data = {"model": model} + data = { + "model": model, + "language": language, + "temperature": temperature, + "prompt": prompt, + } response = requests.request( "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py index 753fbe2..02cf127 100644 --- a/predictionguard/src/chat.py +++ b/predictionguard/src/chat.py @@ -95,7 +95,7 @@ def create( str, Dict[ str, Dict[str, str] ] - ]] = "none", + ]] = None, tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None, top_p: Optional[float] = 0.99, top_k: Optional[float] = 50, @@ -296,22 +296,40 @@ def stream_generator(url, headers, payload, stream): elif entry["type"] == "text": continue - payload_dict = { - "model": model, - "messages": messages, - "frequency_penalty": frequency_penalty, - "logit_bias": logit_bias, - "max_completion_tokens": max_completion_tokens, - "parallel_tool_calls": parallel_tool_calls, - "presence_penalty": presence_penalty, - "stop": stop, - "stream": stream, - "temperature": temperature, - "tool_choice": tool_choice, - "tools": tools, - "top_p": top_p, - "top_k": top_k, - } + # TODO: Remove `tool_choice` check when null value available in API + if tool_choice is None: + payload_dict = { + "model": model, + "messages": messages, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "max_completion_tokens": max_completion_tokens, + "parallel_tool_calls": parallel_tool_calls, + "presence_penalty": presence_penalty, + "stop": stop, + "stream": stream, + "temperature": temperature, + "tools": tools, + "top_p": top_p, + "top_k": top_k, + } + else: + payload_dict = { + "model": model, + "messages": messages, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "max_completion_tokens": max_completion_tokens, + "parallel_tool_calls": parallel_tool_calls, + "presence_penalty": presence_penalty, + "stop": stop, + "stream": stream, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_p": top_p, + "top_k": top_k, + } if input: payload_dict["input"] = input diff --git a/predictionguard/version.py b/predictionguard/version.py index c77fef4..5c1c3b7 100644 --- a/predictionguard/version.py +++ b/predictionguard/version.py @@ -1,2 +1,2 @@ # Setting the package version -__version__ = "2.8.0" +__version__ = "2.8.1"