diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py
index 12532f8..332c91e 100644
--- a/predictionguard/src/chat.py
+++ b/predictionguard/src/chat.py
@@ -8,6 +8,7 @@
 import urllib.request
 import urllib.parse
 import uuid
+from warnings import warn
 
 from ..version import __version__
 
@@ -47,7 +48,7 @@ class Chat:
         ]
 
         result = client.chat.completions.create(
-            model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_tokens=500
+            model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_completion_tokens=500
         )
 
         print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": ")))
@@ -71,7 +72,8 @@ def create(
         messages: Union[str, List[Dict[str, Any]]],
         input: Optional[Dict[str, Any]] = None,
         output: Optional[Dict[str, Any]] = None,
-        max_tokens: Optional[int] = 100,
+        max_completion_tokens: Optional[int] = 100,
+        max_tokens: Optional[int] = None,
         temperature: Optional[float] = 1.0,
         top_p: Optional[float] = 0.99,
         top_k: Optional[float] = 50,
@@ -84,7 +86,7 @@
         :param messages: The content of the call, an array of dictionaries containing a role and content.
         :param input: A dictionary containing the PII and injection arguments.
         :param output: A dictionary containing the consistency, factuality, and toxicity arguments.
-        :param max_tokens: The maximum amount of tokens the model should return.
+        :param max_completion_tokens: The maximum amount of tokens the model should return.
         :param temperature: The consistency of the model responses to the same prompt. The higher the more consistent.
         :param top_p: The sampling for the model to use.
         :param top_k: The Top-K sampling for the model to use.
@@ -92,6 +94,15 @@
         :return: A dictionary containing the chat response.
         """
 
+        # Handling max_tokens and returning deprecation message
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+            warn("""
+            The max_tokens argument is deprecated.
+            Please use max_completion_tokens instead.
+            """, DeprecationWarning, stacklevel=2
+            )
+
         # Create a list of tuples, each containing all the parameters for
         # a call to _generate_chat
         args = (
@@ -99,7 +110,7 @@
             messages,
             input,
             output,
-            max_tokens,
+            max_completion_tokens,
             temperature,
             top_p,
             top_k,
@@ -117,7 +128,7 @@ def _generate_chat(
         messages,
         input,
         output,
-        max_tokens,
+        max_completion_tokens,
         temperature,
         top_p,
         top_k,
@@ -142,7 +153,7 @@ def return_dict(url, headers, payload):
                 )
             else:
                 # Check if there is a json body in the response. Read that in,
-                # print out the error field in the json body, and raise an exception.
+                # then print out the error field in the json body, and raise an exception.
                 err = ""
                 try:
                     err = response.json()["error"]
@@ -246,7 +257,7 @@ def stream_generator(url, headers, payload, stream):
         payload_dict = {
             "model": model,
             "messages": messages,
-            "max_tokens": max_tokens,
+            "max_completion_tokens": max_completion_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
@@ -292,4 +303,4 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str
         for model in response.json()["data"]:
             response_list.append(model["id"])
 
-        return response_list
\ No newline at end of file
+        return response_list
diff --git a/predictionguard/src/completions.py b/predictionguard/src/completions.py
index 8a48f1d..845baa1 100644
--- a/predictionguard/src/completions.py
+++ b/predictionguard/src/completions.py
@@ -2,6 +2,7 @@
 import requests
 
 from typing import Any, Dict, List, Optional, Union
+from warnings import warn
 
 from ..version import __version__
 
@@ -21,7 +22,8 @@ def create(
         prompt: Union[str, List[str]],
         input: Optional[Dict[str, Any]] = None,
         output: Optional[Dict[str, Any]] = None,
-        max_tokens: Optional[int] = 100,
+        max_completion_tokens: Optional[int] = 100,
+        max_tokens: Optional[int] = None,
         temperature: Optional[float] = 1.0,
         top_p: Optional[float] = 0.99,
         top_k: Optional[int] = 50
@@ -33,16 +35,25 @@
         :param prompt: The prompt(s) to generate completions for.
         :param input: A dictionary containing the PII and injection arguments.
         :param output: A dictionary containing the consistency, factuality, and toxicity arguments.
-        :param max_tokens: The maximum number of tokens to generate in the completion(s).
+        :param max_completion_tokens: The maximum number of tokens to generate in the completion(s).
         :param temperature: The sampling temperature to use.
         :param top_p: The nucleus sampling probability to use.
         :param top_k: The Top-K sampling for the model to use.
         :return: A dictionary containing the completion response.
         """
 
+        # Handling max_tokens and returning deprecation message
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+            warn("""
+            The max_tokens argument is deprecated.
+            Please use max_completion_tokens instead.
+            """, DeprecationWarning, stacklevel=2
+            )
+
         # Create a list of tuples, each containing all the parameters for
         # a call to _generate_completion
-        args = (model, prompt, input, output, max_tokens, temperature, top_p, top_k)
+        args = (model, prompt, input, output, max_completion_tokens, temperature, top_p, top_k)
 
         # Run _generate_completion
         choices = self._generate_completion(*args)
@@ -51,7 +62,7 @@
 
     def _generate_completion(
         self, model, prompt,
-        input, output, max_tokens,
+        input, output, max_completion_tokens,
         temperature, top_p, top_k
     ):
         """
@@ -68,7 +79,7 @@
         payload_dict = {
             "model": model,
             "prompt": prompt,
-            "max_tokens": max_tokens,
+            "max_completion_tokens": max_completion_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k
diff --git a/predictionguard/version.py b/predictionguard/version.py
index 33c738e..9c132bc 100644
--- a/predictionguard/version.py
+++ b/predictionguard/version.py
@@ -1,2 +1,2 @@
 # Setting the package version
-__version__ = "2.7.0"
+__version__ = "2.7.1"