Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions predictionguard/src/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,31 @@ def __init__(self, api_key, url):
def create(
self,
model: str,
file: str
file: str,
language: Optional[str] = "auto",
temperature: Optional[float] = 0.0,
prompt: Optional[str] = "",
) -> Dict[str, Any]:
"""
Creates an audio transcription request to the Prediction Guard /audio/transcriptions API

:param model: The model to use
:param file: Audio file to be transcribed
:param language: The language of the audio file
:param temperature: The temperature parameter for model transcription
:param prompt: A prompt to assist in transcription styling
:return: A dictionary containing the transcribed text.
"""

# Create a list of tuples, each containing all the parameters for
# a call to _transcribe_audio
args = (model, file)
args = (model, file, language, temperature, prompt)

# Run _transcribe_audio
choices = self._transcribe_audio(*args)
return choices

def _transcribe_audio(self, model, file):
def _transcribe_audio(self, model, file, language, temperature, prompt):
"""
Function to transcribe an audio file.
"""
Expand All @@ -72,7 +78,12 @@ def _transcribe_audio(self, model, file):

with open(file, "rb") as audio_file:
files = {"file": (file, audio_file, "audio/wav")}
data = {"model": model}
data = {
"model": model,
"language": language,
"temperature": temperature,
"prompt": prompt,
}

response = requests.request(
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
Expand Down
52 changes: 35 additions & 17 deletions predictionguard/src/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def create(
str, Dict[
str, Dict[str, str]
]
]] = "none",
]] = None,
tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
top_p: Optional[float] = 0.99,
top_k: Optional[float] = 50,
Expand Down Expand Up @@ -296,22 +296,40 @@ def stream_generator(url, headers, payload, stream):
elif entry["type"] == "text":
continue

payload_dict = {
"model": model,
"messages": messages,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"max_completion_tokens": max_completion_tokens,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"stop": stop,
"stream": stream,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_p": top_p,
"top_k": top_k,
}
# TODO: Remove `tool_choice` check when null value available in API
if tool_choice is None:
payload_dict = {
"model": model,
"messages": messages,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"max_completion_tokens": max_completion_tokens,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"stop": stop,
"stream": stream,
"temperature": temperature,
"tools": tools,
"top_p": top_p,
"top_k": top_k,
}
else:
payload_dict = {
"model": model,
"messages": messages,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"max_completion_tokens": max_completion_tokens,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"stop": stop,
"stream": stream,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_p": top_p,
"top_k": top_k,
}

if input:
payload_dict["input"] = input
Expand Down
2 changes: 1 addition & 1 deletion predictionguard/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Setting the package version
__version__ = "2.8.0"
__version__ = "2.8.1"
Loading