diff --git a/.gitignore b/.gitignore index 289adb5..f957775 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,8 @@ test.md test.txt *.log debug_transcription.wav +*.pkl # VSCode .vscode/ -node_modules/ \ No newline at end of file +node_modules/ diff --git a/sandbox/t10_state_loop.py b/sandbox/t10_state_loop.py new file mode 100644 index 0000000..517d898 --- /dev/null +++ b/sandbox/t10_state_loop.py @@ -0,0 +1,197 @@ +''' + +Just getting the basic components of the State Loop Working + +''' + +import openai +import yaml +import os +import re + +with open(os.path.expanduser('~/.uniteai.yml'), 'r') as ymlfile: + cfg = yaml.safe_load(ymlfile) +openai.api_key = cfg['openai']['api_key'] + +COMPLETION_ENGINES = [ + "text-davinci-003", + "text-davinci-002", + "ada", + "babbage", + "curie", + "davinci", +] + +CHAT_ENGINES = [ + "gpt-3.5-turbo", + "gpt-4", +] + +ENGINE = 'gpt-3.5-turbo' +# ENGINE = 'gpt-4' + +def openai_autocomplete(engine, text, max_length): + ''' NON-Streaming responses from OpenAI's API.''' + if engine in COMPLETION_ENGINES: + response = openai.Completion.create( + engine=engine, + prompt=text, + max_tokens=max_length, + stream=False + ) + return response + elif engine in CHAT_ENGINES: + response = openai.ChatCompletion.create( + model=engine, + messages=[{"role": "user", "content": text}], + stream=False + ) + return response['choices'][0]['message']['content'] + + +def find_tag(tag: str, doc_lines: [str]): + ''' Find index of first element that contains `tag`. 
''' + ix = 0 + for ix, line in enumerate(doc_lines): + match = re.search(tag, line) + if match: + return ix, match.start(), match.end() + return None + + +def find_block(start_tag, end_tag, doc): + '''Fine the indices of a start/end-tagged block.''' + if doc is None: + return None, None + doc_lines = doc.split('\n') + s = find_tag(start_tag, doc_lines) + e = find_tag(end_tag, doc_lines) + return s, e + + +def extract_block(start, end, doc): + '''Extract block of text between `start` and `end` tag.''' + if doc is None: + return None + doc_lines = doc.split('\n') + if start is None or end is None: + return None + if start[0] > end[0] or (start[0] == end[0] and start[2] > end[1]): + return None + if start[0] == end[0]: + return [doc_lines[start[0]][start[2]: end[1]]] + else: + block = [doc_lines[start[0]][start[2]:]] # portion of start line + block.extend(doc_lines[start[0]+1:end[0]]) # all of middle lines + block.append(doc_lines[end[0]][:end[1]]) # portion of end line + return '\n'.join(block) + + +def start_tag(x): + return f'<{x}_TAG>' + + +def end_tag(x): + return f'' + + +def get_block(tag, doc): + s1, s2 = find_block(start_tag(tag), end_tag(tag), doc) + return extract_block(s1, s2, doc) + + +STATE = 'STATE' +NEW_STATE = 'NEW_STATE' +REQUEST = 'REQUEST' +RESPONSE = 'RESPONSE' +UPDATES_NEEDED = 'UPDATES_NEEDED' + +state = ''' +players: + josh: + items: + location: + kirtley: + items: + location: + +quests: + +obstacles: + +enemies: +''' + +def get_response(request, + running_resp, + state, + prefix=None, + suffix=None): + nl = '\n\n' # can't do newlines inside f-exprs + prompt = f''' +{prefix + nl if prefix else ''}You must assume the role of a finite state machine, but using only natural language. + +You will be given state, and a request. + +You must return a response, and a new state. 
+ +Please format your response like: + +{start_tag(RESPONSE)} +your response +{end_tag(RESPONSE)} + +{start_tag(UPDATES_NEEDED)} +updates that you'll need to apply to the new state +{end_tag(UPDATES_NEEDED)} + +{start_tag(NEW_STATE)} +the new state +{end_tag(NEW_STATE)} + +Here is the current state: + +{start_tag(STATE)} +{state} +{end_tag(STATE)} + +Here is a transcript of your responses so far: +{running_resp} + +Here is the current request: + {request}{nl + suffix if suffix else ''} +'''.strip() + + return openai_autocomplete(ENGINE, prompt, max_length=200) + +prefix = 'You will be a Dungeon Master, and you will keep notes via a natural language-based state machine. Keep notes on: items, players, quests, etc.' +suffix = 'Remember, keep responses brief, invent interesting quests and obstacles, and make sure the state is always accurate.' + +print('Welcome!') + +running_resp = '' +while True: + request = input('Your Command:') + x = get_response(request, + running_resp=running_resp, + state=state, + prefix=prefix, + suffix=suffix) + + # Try extracting new_state + new_state = get_block(NEW_STATE, x) + if new_state is None: + new_state = get_block(STATE, x) + + # Try extracting response + resp = get_block(RESPONSE, x) + if resp is None: + resp = '' + print(f'INVALID RESPONSE: \n{x}') + continue + + if new_state is not None and resp is not None: + state = new_state + print(f'STATE: {state}') + running_resp = f'{running_resp.strip()}\n\n{resp.strip()}' + print(f'RESPONSE: {running_resp}') diff --git a/sandbox/t11_document_chat.py b/sandbox/t11_document_chat.py new file mode 100644 index 0000000..b8a1244 --- /dev/null +++ b/sandbox/t11_document_chat.py @@ -0,0 +1,443 @@ +''' + +Reading in and indexing documents. 
+ +pip install pypdf +pip install InstructorEmbedding +pip install sentence-transformers + +''' + +import os +from InstructorEmbedding import INSTRUCTOR +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from pypdf import PdfReader +from tqdm import tqdm +from typing import List, Dict +import numpy as np +from tqdm import tqdm +from scipy.signal import savgol_filter +import pickle +from dataclasses import dataclass + +@dataclass +class Meta: + name: str + path: str + window_size: int + stride: int + percentile: int + text: str + embeddings: List[np.ndarray] + query_instruction: str + embed_instruction: str + denoise_window_size: int + denoise_poly_order: int + +################################################## +# Load Model + +try: + already_loaded +except: + model = INSTRUCTOR('hkunlp/instructor-base') + already_loaded = True + + + +def embed(xs: List[str]): + ''' Build sentence embeddings for each sentence in `xs` ''' + return model.encode(xs) + + +################################################## +# PDF + +def load_doc(meta: Meta): + ''' Mutate `meta` to include `text` ''' + if meta.text is None: + path = meta.path + path = os.path.expanduser(path) + _, ext = os.path.splitext(path) + if ext == '.pdf': + reader = PdfReader(path) + text = '' + for page in reader.pages: + text += '\n' + page.extract_text() + meta.text = text + else: + with open(path, 'r') as f: + meta.text = f.read() + +def get_file_name(path): + full_name = os.path.basename(path) + name, ext = os.path.splitext(full_name) + return name + + +def load_pkl(pdf_key): + path = f'{pdf_key}.pkl' + if os.path.exists(path): + with open(path, 'rb') as f: + xs = pickle.load(f) + return xs + return None + + +def save_pkl(pdf_key, xs): + path = f'{pdf_key}.pkl' + with open(path, 'wb') as f: + pickle.dump(xs, f) + print(f'Saved: {path}') + + +def load_embeddings(meta): + ''' Mutate `meta` to include embeddings. 
''' + pdf_key = meta.name + document = meta.text + window_size = meta.window_size + stride = meta.stride + embed_instruction = meta.embed_instruction + + # Try to load embeddings from disk + embeddings = load_pkl(pdf_key) + if embeddings is not None: + print(f'Loaded {pdf_key} embeddings') + meta.embeddings = embeddings + else: # If not found, then calculate + print(f'Preparing embeddings for {pdf_key}') + embeddings = [] + # Loop through the document with given stride + offsets = list( + enumerate(range(0, len(document) - window_size + 1, stride))) + for emb_i, doc_i in tqdm(offsets): + # Extract the chunk from document + chunk = document[doc_i:doc_i+window_size] + # Embed the chunk + chunk_e = embed([[embed_instruction, chunk]]) + embeddings.append(chunk_e) + + meta.embeddings = embeddings + save_pkl(pdf_key, embeddings) + + +def similar_tokens(query: str, meta: Meta) -> List[float]: + '''Compare a `query` to a strided window over a `document`.''' + embeddings = meta.embeddings + pdf_key = meta.name + document = meta.text + window_size = meta.window_size + stride = meta.stride + query_instruction = meta.query_instruction + # Initialize a numpy array for storing similarities and overlaps + document_length = len(embeddings) * stride + window_size - 1 # Derive the document length from embeddings + similarities = np.zeros(document_length, dtype=float) + overlaps = np.zeros(document_length, dtype=float) + + query_e = embed([[query_instruction, query]]) + + # Loop through the document with given stride + offsets = list(range(0, document_length - window_size + 1, stride)) + for chunk_e, doc_i in tqdm(zip(embeddings, offsets)): + sim = cosine_similarity(query_e, chunk_e)[0][0] + + # Update the similarities and overlaps array + for j in range(doc_i, doc_i + window_size): + similarities[j] += sim + overlaps[j] += 1 + + # Average the similarities with the number of overlaps + similarities /= np.where(overlaps != 0, overlaps, 1) + return similarities + + +def find_spans(arr, 
threshold=0.5): + ''' ''' + # Create an array that is 1 where arr is above threshold, and padded with 0s at the edges + is_over_threshold = np.concatenate(([0], np.greater(arr, threshold), [0])) + + # Find the indices of rising and falling edges + diffs = np.diff(is_over_threshold) + starts = np.where(diffs > 0)[0] + ends = np.where(diffs < 0)[0] + return list(zip(starts, ends - 1)) + + +def tune_percentile(xs, percentile): + ''' 0-out all elements below percentile. Essentially, this will leave some + `1-percentile` percentage of the document highlighted. ''' + xs = np.copy(xs) # don't mutate original + p = np.percentile(xs, percentile) + xs[xs < p] *= 0 + return xs + + +def segments(similarities, document, threshold=0.0): + out = '' + last_thresh = False # for finding edge + + text = '' + sims = [] + out = [] # [(text, sims), ...] + for sim, char in zip(similarities, document): + super_thresh = sim > threshold + # no longer a super_thresh run + if last_thresh and not super_thresh: + out.append((text, np.array(sims))) + text = '' + sims = [] + + # is a super_thresh run + if super_thresh: + text += char + sims.append(sim) + last_thresh = super_thresh + if len(text) > 0: + out.append((text, np.array(sims))) + + return out + + +def rank(segments, rank_fn): + '''Sort segments according to an aggregate function of their scores.''' + scores = [] + for text, sims in segments: + scores.append(rank_fn(sims)) + out = [] + for score, (text, sims) in reversed(sorted(zip(scores, segments))): + out.append(text) + return out + + +def denoise_similarities(similarities, window_size=2000, poly_order=2): + ''' Apply Savitzky-Golay filter to smooth out the similarity scores. 
''' + return savgol_filter(similarities, window_size, poly_order) + + +def top_segments(query, doc_name, top_n, visualize=False): + meta = cache[doc_name] + document = meta.text + denoise_window_size = meta.denoise_window_size + denoise_poly_order = meta.denoise_poly_order + percentile = meta.percentile + similarities = similar_tokens(query, meta) + + # remove outlier at end + last_edge = int(len(similarities) * 0.01) + similarities[-last_edge:] = similarities[-last_edge] + + # Denoise salience scores + # similarities = tune_percentile(similarities, percentile) + d_similarities = denoise_similarities(similarities, + denoise_window_size, + denoise_poly_order) + d_similarities -= d_similarities.min() # normalize + d_similarities /= d_similarities.max() + d_similarities = tune_percentile(d_similarities, percentile) + + segs = segments(d_similarities, document) + ranked_segments = rank(segs, np.mean)[:top_n] + + if visualize: + import matplotlib.pyplot as plt + plt.plot(similarities) + plt.plot(d_similarities) + plt.show() + + return ranked_segments, d_similarities + + +################################################## +# Visualization + +import webbrowser +from html import escape +import os +from typing import List, Tuple +import numpy as np + + +def hex_to_rgb(hex_color: str): + """ Converts a hexadecimal color string to an RGB tuple. """ + hex_color = hex_color.lstrip('#') + return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + + +def interpolate_color(similarity: float, color): + """ Scales an RGB color tuple by a similarity factor. """ + rgb_color = hex_to_rgb(color) + return tuple(similarity * channel for channel in rgb_color) + + +def add_color(a, b): + """ Adds two RGB color tuples element-wise. """ + return tuple(ai + bi for ai, bi in zip(a, b)) + + +def rgb_to_hex(rgb): + """ Converts an RGB color tuple to a hexadecimal color string. 
""" + r,g,b = rgb + rgb = (int(r), int(g), int(b)) + return '#%02x%02x%02x' % rgb + + +def colorize_text(text: str, queries_and_colors_and_similarity: List[Tuple[str, str, np.ndarray]]) -> str: + """ Colorizes text based on similarity scores for different queries. """ + # Initialize an empty list for the HTML parts + html_parts = [] + + # Get the color names and similarity scores + color_names = [item[1] for item in queries_and_colors_and_similarity] + similarities_tr = zip(*[item[2] for item in queries_and_colors_and_similarity]) # transposed + + # Loop through the text with the corresponding similarity + for similarities, char in zip(similarities_tr, text): + # Initialize color as black + color = (0, 0, 0) + for name, sim in zip(color_names, similarities): + color = add_color(color, interpolate_color(sim, name)) + hex_color = rgb_to_hex(color) + + if char == '\n': + html_parts.append('
') + continue + char = escape(char) + html_parts.append(f'{char}') + + return ''.join(html_parts) + + +def create_colorful_html(queries_and_colors_and_similarity: List[Tuple[str, str, np.ndarray]], + document: str) -> None: + """ Creates an HTML page with colorized query and document text. """ + # Colorize the queries and document + queries_html = [colorize_text(query, [(query, color, np.ones(len(query)))]) + for query, color, self_similarities + in queries_and_colors_and_similarity] + document_html = colorize_text(document, queries_and_colors_and_similarity) + + # Combine the HTML strings for the queries and document + html = ''' + + + +''' + for i, query_html in enumerate(queries_html): + html += f'

<p>QUERY {i+1}:<br/>\n{query_html}\n</p>\n'
+    html += '<hr>\n<p>DOCUMENT:<br/>
\n' + document_html + + # Write the HTML string to a temporary file + with open('temp.html', 'w') as f: + f.write(html) + + # Open the HTML file in the default web browser + webbrowser.open('file://' + os.path.realpath('temp.html')) + + +################################################## +# Ranked Segments + +try: + cache_is_loaded +except: + print('Populating cache') + cache = { + 'bitcoin': Meta( + name='bitcoin', + path='~/Documents/misc/bitcoin.pdf', + window_size=300, + stride=100, + text=None, + embeddings=None, + query_instruction='Represent the Science question for retrieving supporting documents: ', + embed_instruction='Represent the Science document for retrieval: ', + denoise_window_size=2000, + denoise_poly_order=2, + percentile=80, + ), + 'idaho': Meta( + name='idaho', + path='~/Documents/misc/land use and development code.pdf', + window_size=300, + stride=100, + text=None, + embeddings=None, + query_instruction='Represent the wikipedia question for retrieving supporting documents: ', + embed_instruction='Represent the wikipedia document for retrieval: ', + denoise_window_size=5000, + denoise_poly_order=2, + percentile=80, + ), + '2001_positive': Meta( + name='2001_positive', + path='./2001_positive.md', + window_size=500, + stride=25, + text=None, + embeddings=None, + query_instruction='Represent the book review question for retrieving supporting documents: ', + embed_instruction='Represent the book review document for retrieval: ', + denoise_window_size=250, + denoise_poly_order=3, + percentile=80, + ), + '2001_negative': Meta( + name='2001_negative', + path='./2001_negative.md', + window_size=500, + stride=25, + text=None, + embeddings=None, + query_instruction='Represent the book review question for retrieving supporting documents: ', + embed_instruction='Represent the book review document for retrieval: ', + denoise_window_size=250, + denoise_poly_order=3, + percentile=80, + ), + } + for k, m in cache.items(): + load_doc(m) + load_embeddings(m) + 
cache_is_loaded = True + +doc_name = 'bitcoin' +# query = 'whats in it for participants to the blockchain?' +# query = 'how does this protect my anonymity?' +# query = 'im concerned my hdd isnt big enough' +# query = 'who contributed to this paper?' +# query = 'what is the transaction size limit?' +query = 'what game theory problem is it trying to solve?' + +# doc_name = 'idaho' +# # query = 'how close can my silver mine be to a farm?' +# # query = 'how do houses on the lake need to be addressed? marine addressing.' +# # query = 'How can I rezone my property? Rezoning.' +# # query = 'What signs can I put on my property?' +# query = 'How does this document define "sign"?' + +ranked_segments, sims = top_segments(query, doc_name, top_n=3, visualize=True) + +for x in reversed(ranked_segments): + print('\n\nXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n\n') + print(x) + + +# ################################################## +# # Colorize document + +# doc_name = '2001_positive' +# queries = [ +# ('The reviewer loves the book.', '00cc00'), +# ('The reviewer hates the book.', 'cc0000'), +# ] +# query_tups = [] +# for (q, c) in queries: +# _, sims = top_segments(q, doc_name, top_n=0) +# query_tups.append((q, c, sims)) + +# create_colorful_html(query_tups, cache[doc_name].text) diff --git a/test/test_common.py b/test/test_common.py index 3e114b9..bb06ff3 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -4,7 +4,7 @@ ''' -from uniteai.common import insert_text_at, find_pattern_in_document +from uniteai.common import insert_text_at, find_block, extract_block import pytest @@ -49,15 +49,27 @@ def test_insert_text_at(): assert str(e.value) == "Column number out of range" -def test_find_pattern_in_document(): - document = ''' -Hello, world! -Regex is fun. -I like programming in Python. 
-''' +def test_extract_block(): + doc = "This is the first document line.\nThe second line is start_tag and also contains some more text.\nThis is the third line between the tags.\nThis is the fourth line.\nThe fifth line is end_tag and also contains some more text.\nThis is the last document line." + start_tag = "start_tag" + end_tag = "end_tag" + + # Getting the start and end line and column tuples + start, end = find_block(start_tag, end_tag, doc) + + # Expecting three lines as output - the line containing start tag, the line between the tags and the line containing the end tag + expected_output = " and also contains some more text.\nThis is the third line between the tags.\nThis is the fourth line.\nThe fifth line is " + + assert extract_block(start, end, doc) == expected_output, "Test case 1 failed!" + + # Test case when no tag is there in document + start_tag = "no_tag" + end_tag = "no_tag" + start, end = find_block(start_tag, end_tag, doc) + + # Expecting None since no tag is there in document + expected_output = None + + assert extract_block(start, end, doc) == expected_output, "Test case 2 failed!" - assert find_pattern_in_document(document, "o") == [(1, 4, 5), (1, 8, 9), - (3, 9, 10), (3, 26, 27)] - assert find_pattern_in_document(document, "P...on") == [(3, 22, 28)] - assert find_pattern_in_document(document, "Java") == [] - assert find_pattern_in_document('', "o") == [] + print('All test cases passed!') diff --git a/uniteai/common.py b/uniteai/common.py index a76e579..5d28c1b 100644 --- a/uniteai/common.py +++ b/uniteai/common.py @@ -105,20 +105,20 @@ def find_block(start_tag, end_tag, doc): return s, e -def find_pattern_in_document( - document: str, - pattern: str) -> List[Tuple[int, int, int]]: - '''Return (line, start_col, end_col) for each match. 
Regex cannot span - newlines.''' - result = [] - compiled_pattern = re.compile(pattern) - - for line_number, line in enumerate(document.split('\n')): - for match in compiled_pattern.finditer(line): - start, end = match.span() - result.append((line_number, start, end)) - - return result +def extract_block(start, end, doc): + '''Extract block of text between `start` and `end` tag.''' + doc_lines = doc.split('\n') + if start is None or end is None: + return None + if start[0] > end[0] or (start[0] == end[0] and start[2] > end[1]): + return None + if start[0] == end[0]: + return [doc_lines[start[0]][start[2]: end[1]]] + else: + block = [doc_lines[start[0]][start[2]:]] # portion of start line + block.extend(doc_lines[start[0]+1:end[0]]) # all of middle lines + block.append(doc_lines[end[0]][:end[1]]) # portion of end line + return '\n'.join(block) ################################################## diff --git a/uniteai/contrib/state_loop.py b/uniteai/contrib/state_loop.py new file mode 100644 index 0000000..7fa9ca6 --- /dev/null +++ b/uniteai/contrib/state_loop.py @@ -0,0 +1,353 @@ +''' + +TODO: As of now, this is little more than a copy-pasted example bot. 
+
+A Language-model-based State Machine
+
+State + Action -> State' + Response
+
+'''
+
+import re
+from lsprotocol.types import (
+    CodeAction,
+    CodeActionKind,
+    CodeActionParams,
+    Command,
+    Range,
+    TextDocumentIdentifier,
+    WorkspaceEdit,
+)
+from concurrent.futures import ThreadPoolExecutor
+import openai
+from threading import Event
+from thespian.actors import Actor
+import argparse
+import logging
+
+from uniteai.edit import init_block, cleanup_block, BlockJob
+from uniteai.common import extract_range, find_block, mk_logger, get_nested
+from uniteai.server import Server
+
+
+##################################################
+#
+
+def parse_state_and_response(string):
+    state_pattern = r"START_STATE:\n(.*?)\nEND_STATE:"
+    response_pattern = r"START_RESPONSE:\n(.*?)\nEND_RESPONSE:"
+
+    state = re.search(state_pattern, string, re.DOTALL)
+    response = re.search(response_pattern, string, re.DOTALL)
+
+    if state:
+        state = state.group(1)
+    else:
+        state = ""
+
+    if response:
+        response = response.group(1)
+    else:
+        response = ""
+
+    return state, response
+
+##################################################
+# StateLoop
+
+START_TAG = ':STATE_START_TAG:'
+END_TAG = ':STATE_END_TAG:'
+NAME = 'state_loop'
+log = mk_logger(NAME, logging.DEBUG)
+
+class StateLoopActor(Actor):
+    def __init__(self):
+        log.debug('ACTOR INIT')
+        self.is_running = False
+        self.executor = ThreadPoolExecutor(max_workers=3)
+        self.current_future = None
+        self.should_stop = Event()
+        self.tags = [START_TAG, END_TAG]
+
+    def receiveMessage(self, msg, sender):
+        command = msg.get('command')
+        doc = msg.get('doc')
+        edits = msg.get('edits')
+        log.debug(f'''
+%%%%%%%%%%
+ACTOR RECV: {msg["command"]}
+ACTOR STATE:
+is_running: {self.is_running}
+should_stop: {self.should_stop.is_set()}
+current_future: {self.current_future}
+
+EDITS STATE:
+job_thread alive: {edits.job_thread.is_alive() if edits and edits.job_thread else "NOT STARTED"}
+%%%%%%%%%%''')
+        if command == 'start':
+            uri = msg.get('uri')
+            range = msg.get('range')
+            prompt = msg.get('prompt')
+            engine = msg.get('engine')
+            max_length = msg.get('max_length')
+            edits = msg.get('edits')
+
+            # check if block already exists
+            start_ixs, end_ixs = find_block(START_TAG,
+                                            END_TAG,
+                                            doc)
+
+            if not (start_ixs and end_ixs):
+                init_block(NAME, self.tags, uri, range, edits)
+
+            self.start(uri, range, prompt, engine, max_length, edits)
+
+        elif command == 'stop':
+            self.stop()
+
+    def init_state_block():
+        pass
+
+    # TODO(review): a dangling bare `def` was here (syntax error); remaining
+
+
+    def start(self, uri, range, prompt, engine, max_length, edits):
+        if self.is_running:
+            log.info('WARN: ON_START_BUT_RUNNING')
+            return
+        log.debug('ACTOR START')
+
+        self.is_running = True
+        self.should_stop.clear()
+
+        def f(uri_, prompt_, engine_, max_length_, should_stop_, edits_):
+            ''' Compose the streaming fn with some cleanup. '''
+            openai_stream_fn(uri_, prompt_, engine_, max_length_,
+                             should_stop_, edits_)
+
+            # Cleanup
+            log.debug('CLEANING UP')
+            cleanup_block(NAME, self.tags, uri_, edits_)
+            self.is_running = False
+            self.current_future = None
+            self.should_stop.clear()
+
+        self.current_future = self.executor.submit(
+            f, uri, prompt, engine, max_length, self.should_stop, edits
+        )
+        log.debug('START CAN RETURN')
+
+    def stop(self):
+        log.debug('ACTOR STOP')
+        if not self.is_running:
+            log.info('WARN: ON_STOP_BUT_STOPPED')
+
+        self.should_stop.set()
+
+        if self.current_future:
+            self.current_future.result()  # block, wait to finish
+            self.current_future = None
+        log.debug('FINALLY STOPPED')
+
+
+##################################################
+# StateLoop
+
+COMPLETION_ENGINES = [
+    "text-davinci-003",
+    "text-davinci-002",
+    "ada",
+    "babbage",
+    "curie",
+    "davinci",
+]
+
+CHAT_ENGINES = [
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0613",
+    "gpt-4",
+]
+
+
+def openai_autocomplete(engine, text, max_length):
+    ''' Stream responses from StateLoop's API as a generator.
''' + if engine in COMPLETION_ENGINES: + response = openai.Completion.create( + engine=engine, + prompt=text, + max_tokens=max_length, + stream=True + ) + for message in response: + generated_text = message['choices'][0]['text'] + yield generated_text + elif engine in CHAT_ENGINES: + response = openai.ChatCompletion.create( + model=engine, + messages=[{"role": "user", "content": text}], + stream=True + ) + for message in response: + # different json structure than completion endpoint + delta = message['choices'][0]['delta'] + if 'content' in delta: + generated_text = delta['content'] + yield generated_text + +def openai_stream_fn(uri, prompt, engine, max_length, stop_event, edits): + log.debug(f'START: OPENAI_STREAM_FN, max_length={max_length}') + try: + # Stream the results to LSP Client + running_text = '' + for new_text in openai_autocomplete(engine, prompt, max_length): + # For breaking out early + if stop_event.is_set(): + log.debug('STREAM_FN received STOP EVENT') + break + log.debug(f'NEW: {new_text}') + # ignore empty strings + if len(new_text) == 0: + continue + + running_text += new_text + job = BlockJob( + uri=uri, + start_tag=START_TAG, + end_tag=END_TAG, + text=f'\n{running_text}\n', + strict=False, + ) + edits.add_job(NAME, job) + + # Streaming is done, and those added jobs were all non-strict. Let's + # make sure to have one final strict job. Streaming jobs are ok to be + # dropped, but we need to make sure it does finalize, eg before a + # strict delete-tags job is added. + job = BlockJob( + uri=uri, + start_tag=START_TAG, + end_tag=END_TAG, + text=f'\n{running_text}\n', + strict=True, + ) + edits.add_job(NAME, job) + log.debug('STREAM COMPLETE') + except Exception as e: + log.error(f'Error: StateLoop, {e}') + + +def code_action_gpt(engine, max_length, params: CodeActionParams): + '''Trigger a GPT Autocompletion response. 
A code action calls a command, + which is set up below to `tell` the actor to start streaming a response.''' + text_document = params.text_document + range = params.range + return CodeAction( + title='StateLoop GPT', + kind=CodeActionKind.Refactor, + command=Command( + title='StateLoop GPT', + command='command.openaiAutocompleteStream', + # Note: these arguments get jsonified, not passed as python objs + arguments=[text_document, range, engine, max_length] + ) + ) + + +def code_action_chat_gpt(engine, max_length, params: CodeActionParams): + '''Trigger a ChatGPT response. A code action calls a command, which is set + up below to `tell` the actor to start streaming a response. ''' + text_document = params.text_document + range = params.range + return CodeAction( + title='StateLoop ChatGPT', + kind=CodeActionKind.Refactor, + command=Command( + title='StateLoop ChatGPT', + command='command.openaiAutocompleteStream', + # Note: these arguments get jsonified, not passed as python objs + arguments=[text_document, range, engine, max_length] + ) + ) + + +################################################## +# Setup + +def configure(config_yaml): + parser = argparse.ArgumentParser() + parser.add_argument('--openai_completion_engine', default=get_nested(config_yaml, ['openai', 'completion_engine'])) + parser.add_argument('--openai_chat_engine', default=get_nested(config_yaml, ['openai', 'chat_engine'])) + parser.add_argument('--openai_max_length', default=get_nested(config_yaml, ['openai', 'max_length'])) + parser.add_argument('--openai_api_key', default=get_nested(config_yaml, ['openai', 'api_key'])) + + # bc this is only concerned with openai params, do not error if extra params + # are sent via cli. 
+ args, _ = parser.parse_known_args() + return args + + + + +def initialize(config, server): + # Config + openai_chat_engine = config.openai_chat_engine + openai_completion_engine = config.openai_completion_engine + openai_max_length = config.openai_max_length + openai.api_key = config.openai_api_key # make library aware of api key + + # Actor + server.add_actor(NAME, StateLoopActor) + + # CodeActions + server.add_code_action( + lambda params: + code_action_gpt(openai_completion_engine, openai_max_length, params)) + server.add_code_action( + lambda params: + code_action_chat_gpt(openai_chat_engine, openai_max_length, params)) + + # Modify Server + @server.thread() + @server.command('command.openaiAutocompleteStream') + def openai_autocomplete_stream(ls: Server, args): + if len(args) != 4: + log.error(f'command.openaiAutocompleteStream: Wrong arguments, received: {args}') + text_document = ls.converter.structure(args[0], TextDocumentIdentifier) + range = ls.converter.structure(args[1], Range) + uri = text_document.uri + doc = ls.workspace.get_document(uri) + doc_source = doc.source + + # Determine engine, by checking for sentinel values to allow LSP client + # to defer arguments to server's configuration. + if args[2] == FROM_CONFIG_CHAT: + engine = openai_chat_engine + elif args[2] == FROM_CONFIG_COMPLETION: + engine = openai_completion_engine + else: + engine = args[2] + + # Max Length + if args[3] == FROM_CONFIG: + max_length = openai_max_length + else: + max_length = args[3] + + # Extract the highlighted region + prompt = extract_range(doc_source, range) + + # Send a message to start the stream + actor_args = { + 'command': 'start', + 'uri': uri, + 'range': range, + 'prompt': prompt, + 'engine': engine, + 'max_length': max_length, + 'edits': ls.edits, + 'doc': doc_source, + } + ls.tell_actor(NAME, actor_args) + + # Return null-edit immediately (the rest will stream) + return WorkspaceEdit()