diff --git a/src/index.ts b/src/index.ts index 2dcf064..57d85e4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,6 +3,8 @@ import { Hono } from "hono"; import { stream } from "hono/streaming"; import { Octokit } from "@octokit/core"; import * as duckdb from 'duckdb'; +import crypto from 'crypto'; +import path from 'path'; import { createAckEvent, createDoneEvent, @@ -13,12 +15,21 @@ import { verifyAndParseRequest, } from "@copilot-extensions/preview-sdk"; -// Initialize DuckDB -const db = new duckdb.Database(':memory:'); // In-memory database -const connection = db.connect(); +function getUserDBPath(username) { + const hash = crypto + .createHash('sha256') + .update(username) + .digest('hex') + .slice(0, 16); + return path.join('/tmp', `duckdb_${hash}.db`); +} + +// Initialize DuckDB - these will be set per user session +let db; +let connection; // Helper function to execute SQL queries -async function executeQuery(query: string): Promise { +async function executeQuery(query) { return new Promise((resolve, reject) => { connection.all(query, (err, result) => { if (err) reject(err); @@ -28,7 +39,7 @@ async function executeQuery(query: string): Promise { } // Results with printTable(json); -async function executeQueryPretty(query: string): Promise { +async function executeQueryPretty(query) { return new Promise((resolve, reject) => { connection.all(query, (err, result) => { if (err) { @@ -48,7 +59,7 @@ async function executeQueryPretty(query: string): Promise { } // Helper function to execute SQL queries -async function executeQueryTable(query: string): Promise { +async function executeQueryTable(query) { return new Promise((resolve, reject) => { connection.all(query, (err, result) => { if (err) { @@ -74,7 +85,7 @@ async function executeQueryTable(query: string): Promise { } // Dummy helper to filter out non-queries -function containsSQLQuery(message: string): boolean { +function containsSQLQuery(message) { const duckdbPattern = /\b(SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|COPY|ATTACH|FROM|WHERE|GROUP BY|ORDER BY|LIMIT|READ_CSV|READ_PARQUET|READ_JSON_AUTO|UNNEST|PRAGMA|EXPLAIN|DESCRIBE|SHOW|SET|WITH|CASE|JOIN|TABLE)\b/i; return duckdbPattern.test(message.toUpperCase()); } @@ -127,6 +138,15 @@ app.post("/", async (c) => { stream.write(createAckEvent()); const octokit = new Octokit({ auth: tokenForUser }); const user = await octokit.request("GET /user"); + + // Initialize database for this user if not already done + if (!db) { + const dbPath = getUserDBPath(user.data.login); + console.log(`Initializing database at ${dbPath}`); + db = new duckdb.Database(dbPath); + connection = db.connect(); + } + const userPrompt = getUserMessage(payload); // Check if the message contains a SQL query @@ -134,7 +154,7 @@ app.post("/", async (c) => { console.log(user.data.login, userPrompt); try { const resultChunks = await executeQueryTable(userPrompt); - console.log('Query Output:',resultChunks.join()); + console.log('Query Output:', resultChunks.join()); for (const chunk of resultChunks) { stream.write(createTextEvent(chunk)); } @@ -154,10 +174,9 @@ app.post("/", async (c) => { stream.write(createTextEvent(chunk)); } } catch (error) { - stream.write(createTextEvent(`Oops! ${error}`)); + stream.write(createTextEvent(`Oops! ${error}`)); } } - } } else { // Handle non-SQL messages using the normal prompt flow @@ -166,8 +185,6 @@ app.post("/", async (c) => { token: tokenForUser, }); console.log('LLM Output:', message.content); - // If everything fails, return whatever the response was - // stream.write(createTextEvent(`This doesn't look like DuckDB SQL.\n`)); stream.write(createTextEvent(message.content)); }