-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathupdate_embeddings_command.py
More file actions
203 lines (170 loc) · 7.61 KB
/
update_embeddings_command.py
File metadata and controls
203 lines (170 loc) · 7.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import pandas as pd
import tiktoken
from bs4 import BeautifulSoup, Comment
from discord.ext import commands
from util import num_tokens, GPT_MODEL, EMBEDDING_MODEL, SAVE_PATH
class UpdateEmbeddingsCommand(commands.Cog):
    """Admin-only command that rebuilds the wiki-embedding CSV used for search."""

    def __init__(self, bot):
        self.bot = bot

    @commands.command()
    @commands.guild_only()
    @commands.has_permissions(administrator=True)
    async def update_embeddings(self, ctx):
        """Sectionize the wiki, embed every string, and save text+embedding to CSV at SAVE_PATH."""
        await ctx.send("Beginning embeddings update process.")
        sections = await sectionize_articles()
        strings = []
        for section in sections:
            strings.extend(split_strings_from_subsection(section))
        await ctx.send("Wiki articles have been split into strings.")
        embeddings = []
        BATCH_SIZE = 2048  # maximum number of inputs per embeddings request
        # BUG FIX: the original loop ranged over len(sections) but sent the FULL
        # `strings` list on every iteration, duplicating every embedding whenever
        # more than one batch iteration ran (and misaligning text<->embedding).
        # Batch over `strings` and send only the current slice.
        for batch_start in range(0, len(strings), BATCH_SIZE):
            batch = strings[batch_start:batch_start + BATCH_SIZE]
            response = await self.bot.openai_client.embeddings.create(input=batch, model=EMBEDDING_MODEL)
            for i, be in enumerate(response.data):
                assert i == be.index  # double check embeddings are in same order as input
            embeddings.extend(e.embedding for e in response.data)
        await ctx.send("Embeddings have been created. Now saving to data frame.")
        df = pd.DataFrame({"text": strings, "embedding": embeddings})
        df.to_csv(SAVE_PATH, index=False)
        await ctx.send("Data frame save complete. Finished!")

    @update_embeddings.error
    async def on_embeddings_error(self, ctx, error):
        """Report any error raised by update_embeddings back to the invoking channel."""
        await ctx.send(f"Error: {error}")
# See https://help.openai.com/en/articles/6643167-how-to-use-openai-api-for-q-a-and-chatbot-apps
async def process_headings(soup, page_title: str) -> list[tuple[list[str], str]]:
    """Splits article into sections and subsections using its headings"""
    heading_tags = ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']
    result = []
    breadcrumb = [f"{page_title} - "]  # heading path leading to the current section
    for heading in soup.find_all(heading_tags):
        level = int(heading.name[1])  # e.g. 'h3' -> 3
        # Trim the breadcrumb back to this heading's depth, then append it.
        breadcrumb = breadcrumb[:level - 1] + [heading.get_text(strip=True)]
        # Collect the text of every sibling up to the next same-or-higher heading.
        texts = []
        for node in heading.next_siblings:
            if node.name in heading_tags and int(node.name[1]) <= level:
                break  # a heading of equal or higher level ends this section
            if node.name is not None and node.name != 'img':
                texts.append(node.get_text(strip=True, separator=" "))
        result.append((breadcrumb, ' '.join(texts).replace("\n", " ")))
    return result
async def sectionize(directory) -> list[tuple[list[str], str]]:
    """Sectionizes the page at the given directory"""
    with open(directory, 'r', encoding="utf8") as fh:
        soup = BeautifulSoup(fh.read(), 'html.parser')
    # The page title lives in an HTML comment containing a "title:" line.
    # NOTE: if several comments carry a title, the LAST one wins (matches
    # the original behavior, which never broke out of the outer loop).
    title = None
    for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
        if 'title:' not in comment:
            continue
        found = next(
            (ln.split('title:', 1)[1].strip()
             for ln in comment.split('\n')
             if ln.strip().startswith('title:')),
            None,
        )
        if found is not None:
            title = found
    if title is None:
        return []
    return await process_headings(soup, title)
async def sectionize_articles() -> list[tuple[list[str], str]]:
    """Walk the local "wiki" tree and sectionize every .html page found."""
    all_sections = []
    for root, _dirs, filenames in os.walk("wiki"):
        html_paths = (os.path.join(root, name) for name in filenames if name.endswith('.html'))
        for path in html_paths:
            all_sections.extend(await sectionize(path))
    return all_sections
def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str]:
    """Split a string in two, on a delimiter, trying to balance tokens on each side.

    Args:
        string: the text to split.
        delimiter: the boundary to split on (default newline).

    Returns:
        [left, right]; ``right`` is "" when the delimiter is absent.
        (Return annotation fixed: ``list[str, str]`` is not a valid generic.)
    """
    chunks = string.split(delimiter)
    if len(chunks) == 1:
        return [string, ""]  # no delimiter found
    elif len(chunks) == 2:
        return chunks  # no need to search for halfway point
    else:
        total_tokens = num_tokens(string)
        halfway = total_tokens // 2
        best_diff = halfway
        # Grow the left side one chunk at a time until the token balance stops
        # improving; the best split point is just before the chunk that broke it.
        for i, chunk in enumerate(chunks):
            left = delimiter.join(chunks[: i + 1])
            left_tokens = num_tokens(left)
            diff = abs(halfway - left_tokens)
            if diff >= best_diff:
                break
            else:
                best_diff = diff
        left = delimiter.join(chunks[:i])
        right = delimiter.join(chunks[i:])
        return [left, right]
def truncated_string(
    string: str,
    model: str,
    max_tokens: int,
    print_warning: bool = True,
) -> str:
    """Truncate a string to a maximum number of tokens.

    Args:
        string: text to truncate.
        model: model name used to select the tiktoken encoding.
        max_tokens: hard cap on the number of tokens kept.
        print_warning: when True, print a notice if truncation occurred.

    Returns:
        The decoded text of at most ``max_tokens`` tokens.
    """
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(string)
    # Renamed the local (was `truncated_string`) so it no longer shadows this function.
    result = encoding.decode(tokens[:max_tokens])
    if print_warning and len(tokens) > max_tokens:
        print(f"Warning: Truncated string from {len(tokens)} tokens to {max_tokens} tokens.")
    return result
# See https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_Wikipedia_articles_for_search.ipynb
def split_strings_from_subsection(
    subsection: tuple[list[str], str],
    max_tokens: int = 1000,
    model: str = GPT_MODEL,
    max_recursion: int = 5,
) -> list[str]:
    """
    Split a subsection into a list of subsections, each with no more than max_tokens.
    Each subsection is a tuple of parent titles [H1, H2, ...] and text (str).

    Args:
        subsection: (titles, text) pair to split.
        max_tokens: per-string token budget.
        model: model name passed to truncated_string for tokenization.
        max_recursion: remaining split attempts before falling back to truncation.
    """
    titles, text = subsection
    string = "\n\n".join(titles + [text])
    num_tokens_in_string = num_tokens(string)
    # if length is fine, return string
    if num_tokens_in_string <= max_tokens:
        return [string]
    # if recursion hasn't found a split after X iterations, just truncate
    elif max_recursion == 0:
        return [truncated_string(string, model=model, max_tokens=max_tokens)]
    # otherwise, split in half and recurse
    else:
        # (removed a redundant second `titles, text = subsection` unpack here)
        for delimiter in ["\n\n", "\n", ". "]:
            left, right = halved_by_delimiter(text, delimiter=delimiter)
            if left == "" or right == "":
                # if either half is empty, retry with a more fine-grained delimiter
                continue
            else:
                # recurse on each half
                results = []
                for half in [left, right]:
                    half_subsection = (titles, half)
                    results.extend(
                        split_strings_from_subsection(
                            half_subsection,
                            max_tokens=max_tokens,
                            model=model,
                            max_recursion=max_recursion - 1,
                        )
                    )
                return results
        # otherwise no split was found, so just truncate (should be very rare)
        return [truncated_string(string, model=model, max_tokens=max_tokens)]
async def setup(bot):
    """discord.py extension entry point: register this module's cog on the bot."""
    cog = UpdateEmbeddingsCommand(bot)
    await bot.add_cog(cog)