Skip to content

Llm rag integration #156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

CONFIG_DIR = os.path.expanduser('~/.config/yt-fts')


@pytest.fixture(scope="session", autouse=True)
def cleanup_after_tests():
    """Restore the user's real yt-fts config after the whole session.

    Runs once: lets every test execute, then removes the test-created
    config dir and moves the pre-test backup (if any) back into place.
    """
    yield
    backup_dir = f"{CONFIG_DIR}_backup"
    if os.path.exists(CONFIG_DIR):
        shutil.rmtree(CONFIG_DIR)
    if os.path.exists(backup_dir):
        shutil.move(backup_dir, CONFIG_DIR)


@pytest.fixture
def runner():
    """Provide a fresh Click CLI test runner for each test."""
    cli_runner = CliRunner()
    return cli_runner
Expand All @@ -15,7 +25,13 @@ def runner():
def reset_testing_env():
if os.path.exists(CONFIG_DIR):
if os.environ.get('YT_FTS_TEST_RESET', 'true').lower() == 'true':

if os.path.exists(CONFIG_DIR):
if not os.path.exists(f"{CONFIG_DIR}_backup"):
shutil.copytree(CONFIG_DIR, f"{CONFIG_DIR}_backup")

shutil.rmtree(CONFIG_DIR)

else:
print('running tests with existing db')

Expand Down Expand Up @@ -100,6 +116,7 @@ def test_playlist_download(runner, capsys):
assert subtitle_count >= 20970, f"Expected 20970 subtitles, but got {subtitle_count}"




# Allow running this test module directly (python test_download.py)
# in addition to invoking it through the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__])

230 changes: 230 additions & 0 deletions yt_fts/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt
from rich.text import Text
import textwrap
import sys
import traceback
from openai import OpenAI
from .db_utils import (
get_channel_id_from_input,
get_channel_name_from_video_id,
get_title_from_db
)
from .get_embeddings import get_embedding
from .utils import time_to_secs
from .config import get_chroma_client

class LLMHandler:
    """Interactive RAG chat over a channel's subtitle embeddings.

    Pulls semantically similar subtitle chunks for the active channel out
    of the Chroma collection and feeds them as context to an OpenAI chat
    model. Whenever the model answers "I don't know", one retry is made
    with additional context fetched via a model-generated search query.
    """

    def __init__(self, openai_api_key: str, channel: str):
        self.openai_client = OpenAI(api_key=openai_api_key)
        self.channel_id = get_channel_id_from_input(channel)
        self.chroma_client = get_chroma_client()
        self.console = Console()
        # Wrap width (columns) for assistant output rendering.
        self.max_width = 80

    def init_llm(self, prompt: str):
        """Start the chat with *prompt*, then loop on user input until 'exit'."""
        messages = self.start_llm(prompt)
        self.display_message(messages[-1]["content"], "assistant")

        while True:
            user_input = Prompt.ask("> ")
            if user_input.lower() == "exit":
                self.console.print("Goodbye!", style="bold red")
                sys.exit(0)
            messages.append({"role": "user", "content": user_input})
            messages = self.continue_llm(messages)
            self.display_message(messages[-1]["content"], "assistant")

    def display_message(self, content: str, role: str):
        """Render *content* to the console, as markdown for assistant turns."""
        if role == "assistant":
            wrapped_content = self.wrap_text(content)
            md = Markdown(wrapped_content)
            # self.console.print(Panel(md, expand=False, border_style="green"))
            self.console.print(md)
        else:
            wrapped_content = self.wrap_text(content)
            self.console.print(Text(wrapped_content, style="bold blue"))

    def wrap_text(self, text: str) -> str:
        """Wrap *text* to self.max_width columns.

        Code-fence / inline-code lines and blank lines are left untouched.
        Lines are re-joined with a trailing space before the newline so
        markdown treats each wrapped line as a hard break.
        """
        lines = text.split('\n')
        wrapped_lines = []

        for line in lines:
            stripped = line.strip()
            if stripped.startswith('```') or stripped.startswith('`'):
                # Code block / inline code: never re-flow.
                wrapped_lines.append(line)
            elif not stripped:
                # textwrap.wrap('') returns [], which would silently drop
                # blank lines and collapse markdown paragraphs -- keep them.
                wrapped_lines.append(line)
            else:
                wrapped = textwrap.wrap(
                    line,
                    width=self.max_width,
                    break_long_words=False,
                    replace_whitespace=False,
                )
                wrapped_lines.extend(wrapped)

        return " \n".join(wrapped_lines)

    def _complete_with_expansion(self, messages: list) -> list:
        """Get a completion for *messages*, retrying once with more context.

        If the model answers "I don't know", generate a follow-up vector
        search query, append the resulting context as a user message, and
        ask again. The final assistant reply is appended to *messages*.
        """
        response_text = self.get_completion(messages)

        if "i don't know" in response_text.lower():
            expanded_query = self.get_expand_context_query(messages)
            self.console.print(f"[italic]Expanding context with query: {expanded_query}[/italic]")
            expanded_context = self.create_context(expanded_query)
            messages.append({
                "role": "user",
                "content": f"Okay here is some more context:\n---\n\n{expanded_context}\n\n---"
            })
            response_text = self.get_completion(messages)

        messages.append({
            "role": "assistant",
            "content": response_text
        })
        return messages

    def start_llm(self, prompt: str) -> list:
        """Build the initial system/user messages for *prompt* and answer it.

        Returns the full message list ending with the assistant's reply.
        Exits the process via display_error() on failure.
        """
        try:
            context = self.create_context(prompt)
            user_str = f"Context: {context}\n\n---\n\nQuestion: {prompt}\nAnswer:"
            system_prompt = """
            Answer the question based on the context below, The context are
            subtitles and timestamped links from videos related to the question.
            In your answer, provide the link to the video where the answer can
            be found. and if the question can't be answered based on the context,
            say \"I don't know\" AND ONLY I don't know\n\n
            """
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_str},
            ]

            return self._complete_with_expansion(messages)

        except Exception as e:
            self.display_error(e)

    def continue_llm(self, messages: list) -> list:
        """Answer the latest user turn in *messages*; returns updated list."""
        try:
            return self._complete_with_expansion(messages)

        except Exception as e:
            self.display_error(e)

    def display_error(self, error: Exception):
        """Print *error* with traceback and terminate the process."""
        self.console.print(Panel(str(error), title="Error", border_style="red"))
        traceback.print_exc()
        sys.exit(1)

    def create_context(self, text: str) -> str:
        """Vector-search the channel's subtitles for *text*.

        Returns a formatted string of the top-10 matches (title, subs,
        timestamped link, similarity) scoped to self.channel_id.
        """
        collection = self.chroma_client.get_collection(name="subEmbeddings")
        search_embedding = get_embedding(text, "text-embedding-ada-002", self.openai_client)
        scope_options = {"channel_id": self.channel_id}

        chroma_res = collection.query(
            query_embeddings=[search_embedding],
            n_results=10,
            where=scope_options,
        )

        documents = chroma_res["documents"][0]
        metadata = chroma_res["metadatas"][0]
        distances = chroma_res["distances"][0]

        res = []
        for i in range(len(documents)):
            # note: loop variable renamed from `text` so it no longer
            # shadows the query-string parameter.
            sub_text = documents[i]
            video_id = metadata[i]["video_id"]
            start_time = metadata[i]["start_time"]
            link = f"https://youtu.be/{video_id}?t={time_to_secs(start_time)}"
            channel_name = get_channel_name_from_video_id(video_id)
            channel_id = metadata[i]["channel_id"]
            title = get_title_from_db(video_id)

            match = {
                "distance": distances[i],
                "channel_name": channel_name,
                "channel_id": channel_id,
                "video_title": title,
                "subs": sub_text,
                "start_time": start_time,
                "video_id": video_id,
                "link": link,
            }
            res.append(match)

        return self.format_context(res)

    def get_expand_context_query(self, messages: list) -> str:
        """Ask the model for a vector-search query based on the chat so far."""
        try:
            system_prompt = """
            Your task is to generate a question to input into a vector search
            engine of youtube subtitles to find strings that can answer the question
            asked in the previous message.
            """
            formatted_context = self.format_message_history_context(messages)
            query_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_context},
            ]

            return self.get_completion(query_messages)

        except Exception as e:
            self.display_error(e)

    def get_completion(self, messages: list) -> str:
        """Run one deterministic (temperature=0) chat completion, return its text."""
        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4",
                messages=messages,
                temperature=0,
                max_tokens=2000,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
            )
            return response.choices[0].message.content

        except Exception as e:
            self.display_error(e)

    @staticmethod
    def format_message_history_context(messages: list) -> str:
        """Flatten chat history into 'role: content' lines, skipping system turns."""
        formatted_context = ""
        for message in messages:
            if message["role"] == "system":
                continue
            role = message["role"]
            content = message["content"]
            formatted_context += f"{role}: {content}\n"
        return formatted_context

    @staticmethod
    def format_context(chroma_res: list) -> str:
        """Render vector-search match dicts as the context string fed to the model."""
        formatted_context = ""
        for obj in chroma_res:
            tmp = f"""
            Video Title: {obj["video_title"]}
            Text: {obj["subs"]}
            Time: {obj["start_time"]}
            Similarity: {obj["distance"]}
            Link: {obj["link"]}
            -------------------------
            """
            formatted_context += tmp
        return formatted_context
37 changes: 37 additions & 0 deletions yt_fts/yt_fts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,43 @@ def embeddings(channel, openai_api_key, interval=10):
sys.exit(0)


@cli.command(
    name="llm",
    help="""
    Interactive LLM RAG chat bot. Requires that embeddings have already
    been generated for the channel (see the embeddings command).
    """
)
@click.argument("prompt", required=True)
@click.option("-c",
              "--channel",
              default=None,
              required=True,
              # help text fixed: was copy-pasted from the embeddings command
              help="The name or id of the channel to chat about")
@click.option("--openai-api-key",
              default=None,
              help="OpenAI API key. If not provided, the script will attempt to read it from"
                   " the OPENAI_API_KEY environment variable.")
def llm(prompt, channel, openai_api_key=None):
    """Start an interactive RAG chat session scoped to *channel*.

    Resolves the API key from the CLI option or the OPENAI_API_KEY
    environment variable, exiting with an error if neither is set.
    """
    # Local import keeps openai/chromadb out of the CLI's startup path.
    from yt_fts.llm import LLMHandler

    if openai_api_key is None:
        openai_api_key = os.environ.get("OPENAI_API_KEY")

    if openai_api_key is None:
        console.print("""
        [bold][red]Error:[/red][/bold] OPENAI_API_KEY environment variable not set, Run:

        export OPENAI_API_KEY=<your_key> to set the key
        """)
        sys.exit(1)

    llm_handler = LLMHandler(openai_api_key, channel)
    llm_handler.init_llm(prompt)

    sys.exit(0)


@cli.command(
help="""
Show config settings
Expand Down
Loading