From 2eb8ee9d888ca6a3beb2325880f3c5b4a8229e97 Mon Sep 17 00:00:00 2001 From: jmansdorfer Date: Wed, 11 Dec 2024 13:34:56 -0500 Subject: [PATCH 1/2] adding max_completion_tokens and deprecation message --- predictionguard/src/chat.py | 23 +++++++++++++++++------ predictionguard/src/completions.py | 21 ++++++++++++++++----- predictionguard/version.py | 2 +- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py index 12532f8..f2309f4 100644 --- a/predictionguard/src/chat.py +++ b/predictionguard/src/chat.py @@ -8,6 +8,7 @@ import urllib.request import urllib.parse import uuid +from warnings import warn from ..version import __version__ @@ -47,7 +48,7 @@ class Chat: ] result = client.chat.completions.create( - model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_tokens=500 + model="Hermes-2-Pro-Llama-3-8B", messages=messages, max_completion_tokens=500 ) print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": "))) @@ -71,7 +72,8 @@ def create( messages: Union[str, List[Dict[str, Any]]], input: Optional[Dict[str, Any]] = None, output: Optional[Dict[str, Any]] = None, - max_tokens: Optional[int] = 100, + max_completion_tokens: Optional[int] = 100, + max_tokens: Optional[int] = None, temperature: Optional[float] = 1.0, top_p: Optional[float] = 0.99, top_k: Optional[float] = 50, @@ -84,7 +86,7 @@ def create( :param messages: The content of the call, an array of dictionaries containing a role and content. :param input: A dictionary containing the PII and injection arguments. :param output: A dictionary containing the consistency, factuality, and toxicity arguments. - :param max_tokens: The maximum amount of tokens the model should return. + :param max_completion_tokens: The maximum amount of tokens the model should return. :param temperature: The consistency of the model responses to the same prompt. The higher the more consistent. :param top_p: The sampling for the model to use. :param top_k: The Top-K sampling for the model to use. @@ -92,6 +94,15 @@ def create( :return: A dictionary containing the chat response. """ + # Handling max_tokens and returning deprecation message + if max_tokens is not None: + max_completion_tokens = max_tokens + warn(""" + The max_tokens argument is deprecated. + Please use max_completion_tokens instead. + """, DeprecationWarning, stacklevel=2 + ) + # Create a list of tuples, each containing all the parameters for # a call to _generate_chat args = ( @@ -99,7 +110,7 @@ def create( messages, input, output, - max_tokens, + max_completion_tokens, temperature, top_p, top_k, @@ -117,7 +128,7 @@ def _generate_chat( messages, input, output, - max_tokens, + max_completion_tokens, temperature, top_p, top_k, @@ -246,7 +257,7 @@ def stream_generator(url, headers, payload, stream): payload_dict = { "model": model, "messages": messages, - "max_tokens": max_tokens, + "max_completion_tokens": max_completion_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, diff --git a/predictionguard/src/completions.py b/predictionguard/src/completions.py index 8a48f1d..845baa1 100644 --- a/predictionguard/src/completions.py +++ b/predictionguard/src/completions.py @@ -2,6 +2,7 @@ import requests from typing import Any, Dict, List, Optional, Union +from warnings import warn from ..version import __version__ @@ -21,7 +22,8 @@ def create( prompt: Union[str, List[str]], input: Optional[Dict[str, Any]] = None, output: Optional[Dict[str, Any]] = None, - max_tokens: Optional[int] = 100, + max_completion_tokens: Optional[int] = 100, + max_tokens: Optional[int] = None, temperature: Optional[float] = 1.0, top_p: Optional[float] = 0.99, top_k: Optional[int] = 50 @@ -33,16 +35,25 @@ def create( :param prompt: The prompt(s) to generate completions for. :param input: A dictionary containing the PII and injection arguments. :param output: A dictionary containing the consistency, factuality, and toxicity arguments. - :param max_tokens: The maximum number of tokens to generate in the completion(s). + :param max_completion_tokens: The maximum number of tokens to generate in the completion(s). :param temperature: The sampling temperature to use. :param top_p: The nucleus sampling probability to use. :param top_k: The Top-K sampling for the model to use. :return: A dictionary containing the completion response. """ + # Handling max_tokens and returning deprecation message + if max_tokens is not None: + max_completion_tokens = max_tokens + warn(""" + The max_tokens argument is deprecated. + Please use max_completion_tokens instead. + """, DeprecationWarning, stacklevel=2 + ) + # Create a list of tuples, each containing all the parameters for # a call to _generate_completion - args = (model, prompt, input, output, max_tokens, temperature, top_p, top_k) + args = (model, prompt, input, output, max_completion_tokens, temperature, top_p, top_k) # Run _generate_completion choices = self._generate_completion(*args) @@ -51,7 +62,7 @@ def create( def _generate_completion( self, model, prompt, - input, output, max_tokens, + input, output, max_completion_tokens, temperature, top_p, top_k ): """ @@ -68,7 +79,7 @@ def _generate_completion( payload_dict = { "model": model, "prompt": prompt, - "max_tokens": max_tokens, + "max_completion_tokens": max_completion_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k diff --git a/predictionguard/version.py b/predictionguard/version.py index 33c738e..9c132bc 100644 --- a/predictionguard/version.py +++ b/predictionguard/version.py @@ -1,2 +1,2 @@ # Setting the package version -__version__ = "2.7.0" +__version__ = "2.7.1" From 9cb6c6da0477f97fc5dd8a7dcfea19f27183e075 Mon Sep 17 00:00:00 2001 From: Jacob Mansdorfer <90076431+jmansdorfer@users.noreply.github.com> Date: Mon, 13 Jan 2025 09:48:40 -0500 Subject: [PATCH 2/2] Update chat.py --- predictionguard/src/chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py index f2309f4..332c91e 100644 --- a/predictionguard/src/chat.py +++ b/predictionguard/src/chat.py @@ -153,7 +153,7 @@ def return_dict(url, headers, payload): ) else: # Check if there is a json body in the response. Read that in, - # print out the error field in the json body, and raise an exception. + # then print out the error field in the json body, and raise an exception. err = "" try: err = response.json()["error"] @@ -303,4 +303,4 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str for model in response.json()["data"]: response_list.append(model["id"]) - return response_list \ No newline at end of file + return response_list