From aa4b0786f7c0f2d98df83130f28d18ee216abd9a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 12 May 2025 16:32:42 -0700 Subject: [PATCH] Support reasoning tokens, saving/loading chat checkpoints, and adaptive max tokens --- chat_client.py | 116 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 98 insertions(+), 18 deletions(-) diff --git a/chat_client.py b/chat_client.py index 2443096..f9bff56 100644 --- a/chat_client.py +++ b/chat_client.py @@ -17,6 +17,7 @@ import argparse import atexit +import json import os import readline @@ -27,6 +28,12 @@ def chat_loop(model: str, url: str, args): conversation = [] client = openai.OpenAI(api_key="None", base_url=url) temperature = openai.NOT_GIVEN + total_tokens = None + if args.total_tokens is not None: + total_tokens = args.total_tokens + system_prompt = args.system_prompt + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) print( "Type a message to start the chat.", @@ -36,6 +43,11 @@ def chat_loop(model: str, url: str, args): f"Commands are stored in history file {args.history}.", ) + if args.load: + with open(args.load, "r") as fp: + conversation = json.load(fp) + print(f"Loaded history from {args.load}") + try: readline.read_history_file(args.history) readline.set_history_length(1000) @@ -43,11 +55,11 @@ def chat_loop(model: str, url: str, args): pass atexit.register(readline.write_history_file, args.history) - + tokens_so_far = 0 try: while True: + skip = False message = input("> ") - # Commands if message.strip().startswith("/"): command = message[1:].strip() @@ -59,7 +71,11 @@ def chat_loop(model: str, url: str, args): " /clear: Clear the chat context\n", " /temp : Set the temperature for the model\n", " /tokens : Set the maximal number of tokens to use per response\n", - " /cq : Prefix with custom prompt given by --custom-prompt ", + " /totaltokens : Set the maximal number of tokens to use in total\n", + " /cq : Prefix with custom prompt given 
by --custom-prompt \n", + " /save [file.json]: Save current chat history (Default: chat_history.json)\n", + " /load : Load chat history from the given file\n", + " /tool : Responds as \"tool\" instead of \"user\"\n", ) continue elif command == "exit": @@ -67,6 +83,7 @@ def chat_loop(model: str, url: str, args): elif command == "clear": print("[Chat context cleared]") conversation = [] + tokens_so_far = 0 continue elif command.startswith("temp "): try: @@ -75,9 +92,7 @@ def chat_loop(model: str, url: str, args): raise ValueError print(f"[Temperature set to {temperature}]") except ValueError: - print( - "[Invalid temperature. Should be a positive number less than 1]" - ) + print("[Invalid temperature. Should be a positive number less than 1]") continue elif command.startswith("tokens "): try: @@ -87,33 +102,75 @@ def chat_loop(model: str, url: str, args): print(f"[Tokens set to {tokens}]") args.max_tokens = tokens except ValueError: - print( - "[Invalid number of tokens. Should be a positive number less than 131,000]" - ) + print("[Invalid number of tokens. Should be a positive number less than 131,000]") + continue + elif command.startswith("totaltokens "): + try: + tokens = int(command.split(" ")[1]) + if tokens < 0 or tokens > 131000: + raise ValueError + print(f"[Total Tokens set to {tokens}]") + total_tokens = tokens + except ValueError: + print("[Invalid number of tokens. 
Should be a positive number less than 131,000]") + continue + elif command.startswith("sofar"): + print("[Tokens in conversation:", tokens_so_far, "]") continue elif command.startswith("cq "): if args.custom_prompt is None: - print( - "[Error: a custom prompt has not been provided, use --custom-prompt]" - ) + print("[Error: a custom prompt has not been provided, use --custom-prompt]") continue message = command[len("cq ") :] new_message = open(args.custom_prompt, "r").read() new_message += message + "]" message = new_message + elif command.startswith("save"): + filename = None + if command == "save": + filename = "chat_history.json" + elif command.startswith("save "): + filename = command[len("save ") :] + if not filename: + print("[Error: invalid syntax for /save]") + continue + with open(filename, "w") as fp: + json.dump(conversation, fp) + print(f"[Saved history to {filename}]") + continue + elif command.startswith("load "): + filename = command[len("load ") :] + if not os.path.exists(filename): + print(f'Cannot load "{filename}", file does not exist') + continue + with open(filename, "r") as fp: + conversation = json.load(fp) + print(f"[Loaded history from {filename}]") + continue + elif command.startswith("tool "): + response = command[len("tool ") :] + conversation.append({"role": "tool", "content": response}) + skip = True else: print(f"Invalid command '{command}'") continue - conversation.append({"role": "user", "content": message}) + if not skip: + conversation.append({"role": "user", "content": message}) try: + max_tokens = args.max_tokens + if total_tokens is not None: + max_tokens = total_tokens - tokens_so_far + + print(conversation) chat_completion = client.chat.completions.create( model=model, messages=conversation, stream=not args.no_stream, temperature=temperature, - max_tokens=args.max_tokens, + max_tokens=max_tokens, + # tools=tools, ) if args.no_stream: @@ -122,15 +179,34 @@ def chat_loop(model: str, url: str, args): print(response) 
conversation.append({"role": "assistant", "content": response}) else: + response_to_save = "" full_response = "" + in_reasoning = False for chunk in chat_completion: if chunk.choices[0].delta.content is not None: - full_response += chunk.choices[0].delta.content + if chunk.choices[0].delta.content == "": + print("\n==========================REASONING===========================") + in_reasoning = True + continue + if chunk.choices[0].delta.content == "": + print("\n==========================END REASONING===========================") + in_reasoning = False + continue + + if not in_reasoning: + response_to_save += chunk.choices[0].delta.content + full_response += chunk.choices[0].delta.content + tokens_so_far += 1 + else: + full_response += chunk.choices[0].delta.content print(chunk.choices[0].delta.content, end="", flush=True) + if args.stop_after and response_to_save.endswith(args.stop_after): + print("\n[Stopping after seeing the stop pattern]") + break print() full_response += "\n" - response_message = {"role": "assistant", "content": full_response} + response_message = {"role": "assistant", "content": response_to_save} conversation.append(response_message) except KeyboardInterrupt: # Catch ctrl-C if not args.no_stream: @@ -145,11 +221,15 @@ def main(): parser.add_argument("--model", type=str, default="LLLama") parser.add_argument("--url", type=str, default="http://localhost:8123") parser.add_argument("--max-tokens", type=int, default=1024) + parser.add_argument("--total-tokens", type=int, default=None) parser.add_argument("--custom-prompt", type=str, default=None) + parser.add_argument("--system-prompt", type=str, default=None) + parser.add_argument("--history", action="store", type=str, default=".chat-client-history") + parser.add_argument("--no-stream", action="store_true") parser.add_argument( - "--history", action="store", type=str, default=".chat-client-history" + "--load", type=str, default=None, help="Continue from a file containing an existing chat 
history" ) - parser.add_argument("--no-stream", action="store_true") + parser.add_argument("--stop-after", type=str, default=None, help="Checkpoint chat when seeing the given pattern") args = parser.parse_args() chat_loop(args.model, args.url, args)