Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions predictionguard/src/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import urllib.request
import urllib.parse
import uuid
from warnings import warn

from ..version import __version__

Expand Down Expand Up @@ -47,7 +48,7 @@ class Chat:
]

result = client.chat.completions.create(
model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_tokens=500
model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_completion_tokens=500
)

print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": ")))
Expand All @@ -71,7 +72,8 @@ def create(
messages: Union[str, List[Dict[str, Any]]],
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
max_tokens: Optional[int] = 100,
max_completion_tokens: Optional[int] = 100,
max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 0.99,
top_k: Optional[float] = 50,
Expand All @@ -84,22 +86,31 @@ def create(
:param messages: The content of the call, an array of dictionaries containing a role and content.
:param input: A dictionary containing the PII and injection arguments.
:param output: A dictionary containing the consistency, factuality, and toxicity arguments.
:param max_tokens: The maximum amount of tokens the model should return.
:param max_completion_tokens: The maximum number of tokens the model should return.
:param temperature: The randomness of the model responses to the same prompt. The lower the value, the more consistent the responses.
:param top_p: The sampling for the model to use.
:param top_k: The Top-K sampling for the model to use.
:param stream: Option to stream the API response
:return: A dictionary containing the chat response.
"""

# Handle the deprecated max_tokens argument and issue a deprecation warning
if max_tokens is not None:
max_completion_tokens = max_tokens
warn("""
The max_tokens argument is deprecated.
Please use max_completion_tokens instead.
""", DeprecationWarning, stacklevel=2
)

# Create a list of tuples, each containing all the parameters for
# a call to _generate_chat
args = (
model,
messages,
input,
output,
max_tokens,
max_completion_tokens,
temperature,
top_p,
top_k,
Expand All @@ -117,7 +128,7 @@ def _generate_chat(
messages,
input,
output,
max_tokens,
max_completion_tokens,
temperature,
top_p,
top_k,
Expand All @@ -142,7 +153,7 @@ def return_dict(url, headers, payload):
)
else:
# Check if there is a json body in the response. Read that in,
# print out the error field in the json body, and raise an exception.
# then print out the error field in the json body, and raise an exception.
err = ""
try:
err = response.json()["error"]
Expand Down Expand Up @@ -246,7 +257,7 @@ def stream_generator(url, headers, payload, stream):
payload_dict = {
"model": model,
"messages": messages,
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
Expand Down Expand Up @@ -292,4 +303,4 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str
for model in response.json()["data"]:
response_list.append(model["id"])

return response_list
return response_list
21 changes: 16 additions & 5 deletions predictionguard/src/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import requests
from typing import Any, Dict, List, Optional, Union
from warnings import warn

from ..version import __version__

Expand All @@ -21,7 +22,8 @@ def create(
prompt: Union[str, List[str]],
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
max_tokens: Optional[int] = 100,
max_completion_tokens: Optional[int] = 100,
max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 0.99,
top_k: Optional[int] = 50
Expand All @@ -33,16 +35,25 @@ def create(
:param prompt: The prompt(s) to generate completions for.
:param input: A dictionary containing the PII and injection arguments.
:param output: A dictionary containing the consistency, factuality, and toxicity arguments.
:param max_tokens: The maximum number of tokens to generate in the completion(s).
:param max_completion_tokens: The maximum number of tokens to generate in the completion(s).
:param temperature: The sampling temperature to use.
:param top_p: The nucleus sampling probability to use.
:param top_k: The Top-K sampling for the model to use.
:return: A dictionary containing the completion response.
"""

# Handle the deprecated max_tokens argument and issue a deprecation warning
if max_tokens is not None:
max_completion_tokens = max_tokens
warn("""
The max_tokens argument is deprecated.
Please use max_completion_tokens instead.
""", DeprecationWarning, stacklevel=2
)

# Create a list of tuples, each containing all the parameters for
# a call to _generate_completion
args = (model, prompt, input, output, max_tokens, temperature, top_p, top_k)
args = (model, prompt, input, output, max_completion_tokens, temperature, top_p, top_k)

# Run _generate_completion
choices = self._generate_completion(*args)
Expand All @@ -51,7 +62,7 @@ def create(

def _generate_completion(
self, model, prompt,
input, output, max_tokens,
input, output, max_completion_tokens,
temperature, top_p, top_k
):
"""
Expand All @@ -68,7 +79,7 @@ def _generate_completion(
payload_dict = {
"model": model,
"prompt": prompt,
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k
Expand Down
2 changes: 1 addition & 1 deletion predictionguard/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Setting the package version
__version__ = "2.7.0"
__version__ = "2.7.1"
Loading