diff --git a/database/migrations/2024_01_01_000012_replace_sql_results_with_queries_on_messages.php b/database/migrations/2024_01_01_000012_replace_sql_results_with_queries_on_messages.php new file mode 100644 index 0000000..8bdaa1d --- /dev/null +++ b/database/migrations/2024_01_01_000012_replace_sql_results_with_queries_on_messages.php @@ -0,0 +1,36 @@ +getConnection())->table('sql_agent_messages', function (Blueprint $table) { + $table->dropColumn(['sql', 'results']); + }); + + Schema::connection($this->getConnection())->table('sql_agent_messages', function (Blueprint $table) { + $table->json('queries')->nullable()->after('content'); + }); + } + + public function down(): void + { + Schema::connection($this->getConnection())->table('sql_agent_messages', function (Blueprint $table) { + $table->dropColumn('queries'); + }); + + Schema::connection($this->getConnection())->table('sql_agent_messages', function (Blueprint $table) { + $table->text('sql')->nullable()->after('content'); + $table->json('results')->nullable()->after('sql'); + }); + } + + public function getConnection(): ?string + { + return config('sql-agent.database.storage_connection'); + } +}; diff --git a/docs/src/content/docs/guides/web-interface.md b/docs/src/content/docs/guides/web-interface.md index 9a95bf7..61d5611 100644 --- a/docs/src/content/docs/guides/web-interface.md +++ b/docs/src/content/docs/guides/web-interface.md @@ -1,6 +1,6 @@ --- title: Web Interface -description: Livewire chat UI, streaming, debug mode, and conversation exports. +description: Livewire chat UI, streaming, debug mode, and result exports. sidebar: order: 5 --- @@ -66,16 +66,9 @@ You may use the Livewire components directly in your own Blade templates: Displays a searchable list of previous conversations for the current user. -## Exporting Conversations +## Exporting Results -Conversations can be exported as JSON or CSV via dedicated routes: - -| Route | Named Route | Description | -|-------|-------------|-------------| -| `GET /sql-agent/export/{conversation}/json` | `sql-agent.export.json` | Download as JSON | -| `GET /sql-agent/export/{conversation}/csv` | `sql-agent.export.csv` | Download as CSV | - -These routes share the same middleware as the rest of the UI. +Each result table in the chat interface includes **CSV** and **JSON** export buttons in the header bar. Clicking a button downloads the full result set (all rows, not just the current page) directly from the browser — no server round-trip required. ## Streaming (SSE) @@ -86,7 +79,7 @@ The chat interface uses Server-Sent Events for real-time streaming. The streamin | `conversation` | `{"id": 123}` | Sent first with the conversation ID | | `thinking` | `{"thinking": "..."}` | LLM reasoning chunks (when thinking mode is enabled) | | `content` | `{"text": "..."}` | Response text chunks | -| `done` | `{"sql": "...", "hasResults": true, "resultCount": 5}` | Sent when streaming completes | +| `done` | `{"queryCount": 2}` | Sent when streaming completes | | `error` | `{"message": "..."}` | Sent if an error occurs | ## Debug Mode diff --git a/docs/src/content/docs/reference/api.md b/docs/src/content/docs/reference/api.md index 533d396..20816c4 100644 --- a/docs/src/content/docs/reference/api.md +++ b/docs/src/content/docs/reference/api.md @@ -137,7 +137,7 @@ Cache and thought token fields are `null` when the provider does not support the When using the web interface or the SSE streaming endpoint, usage data is included in the `done` event: ```json -{"event": "done", "data": {"sql": "...", "hasResults": true, "resultCount": 5, "usage": {"prompt_tokens": 1234, "completion_tokens": 567, ...}}} +{"event": "done", "data": {"queryCount": 2, "usage": {"prompt_tokens": 1234, "completion_tokens": 567, ...}}} ``` ### Stored Messages diff --git a/resources/views/components/message.blade.php b/resources/views/components/message.blade.php index e0a05ad..b0b4b24 100644 --- a/resources/views/components/message.blade.php +++ b/resources/views/components/message.blade.php @@ -1,9 +1,9 @@ @props([ 'role' => 'user', 'content' => '', - 'sql' => null, - 'results' => null, + 'queries' => null, 'metadata' => null, + 'messageId' => null, 'isStreaming' => false, ]) @@ -14,6 +14,9 @@ $hasPrompt = $debugEnabled && isset($metadata['prompt']); $usage = $metadata['usage'] ?? null; $truncated = $metadata['truncated'] ?? false; + $hasQueries = !empty($queries); + $queryCount = $hasQueries ? count($queries) : 0; + $isSingleQuery = $queryCount === 1; @endphp
@@ -39,8 +42,141 @@
@endif - @if($isAssistant && ($sql || $results || $hasPrompt || $usage)) -
+ @if($isAssistant && ($hasQueries || $hasPrompt || $usage)) +
null); + throw new Error(errorData?.message || `HTTP ${response.status}`); + } + this.queryResults[index] = await response.json(); + this.queryPages[index] = 0; + } catch (e) { + this.queryErrors[index] = e.message; + } finally { + this.loadingQuery = null; + } + }, + handleAutoExecute(event) { + if (event.detail.messageId !== this.messageId) return; + const count = {{ $queryCount }}; + for (let i = 0; i < count; i++) { + this.executeQuery(i); + } + if (count > 1) { + this.showQueries = true; + } + }, + downloadFile(content, filename, type) { + const blob = new Blob([content], { type }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + URL.revokeObjectURL(url); + }, + exportCsv(index) { + const rows = this.queryResults[index].rows; + if (!rows?.length) return; + const q = v => String.fromCharCode(34) + String(v).replaceAll(String.fromCharCode(34), String.fromCharCode(34,34)) + String.fromCharCode(34); + const cols = Object.keys(rows[0]); + const lines = [cols.map(q).join(',')]; + for (const row of rows) { + lines.push(cols.map(col => { + const val = row[col]; + if (val === null) return ''; + return q(typeof val === 'object' ? JSON.stringify(val) : val); + }).join(',')); + } + this.downloadFile(lines.join('\n'), 'results.csv', 'text/csv'); + }, + exportJson(index) { + const rows = this.queryResults[index].rows; + if (!rows?.length) return; + this.downloadFile(JSON.stringify(rows, null, 2), 'results.json', 'application/json'); + } + }" @auto-execute-queries.window="handleAutoExecute($event)" class="mt-3">
@if($usage) @@ -51,27 +187,52 @@ @endif - @if($sql) + @if($hasQueries && $isSingleQuery) - @endif - - @if($results && count($results) > 0) + + @elseif($hasQueries) + @endif @@ -83,23 +244,277 @@ class="inline-flex items-center gap-1.5 text-xs px-3 py-1.5 rounded-lg bg-amber- - + @endif
- @if($sql) -
- + {{-- Single query panels --}} + @if($hasQueries && $isSingleQuery) +
+
+ + {{-- Query error --}} + + @endif - @if($results && count($results) > 0) -
- + {{-- Multiple queries --}} + @if($hasQueries && !$isSingleQuery) +
+ @foreach($queries as $index => $query) +
+
+ + Query {{ $index + 1 }} + @if($query['connection'] ?? null) + + + + + {{ $query['connection'] }} + + @endif + +
+ + +
+
+ +
+ +
+ + {{-- Query error --}} + + + {{-- Query results rendered via Alpine --}} + +
+ @endforeach
@endif + {{-- Single query results (rendered via Alpine for dynamic data) --}} + @if($hasQueries && $isSingleQuery) + + @endif + @if($hasPrompt)
diff --git a/resources/views/livewire/chat-component.blade.php b/resources/views/livewire/chat-component.blade.php index b440b06..b3804f8 100644 --- a/resources/views/livewire/chat-component.blade.php +++ b/resources/views/livewire/chat-component.blade.php @@ -47,52 +47,6 @@ class="p-2.5 rounded-lg border border-gray-200 dark:border-gray-600 hover:bg-gra - {{-- Export Menu --}} - @if($conversationId) - - @endif
@@ -162,9 +116,9 @@ class="group w-full p-4 text-left bg-white dark:bg-gray-800 hover:bg-gray-50 dar @endforeach @@ -285,6 +239,7 @@ function chatStream() { isFinishing: false, // True while waiting for Livewire refresh streamedContent: '', pendingUserMessage: '', + pendingMessageId: null, conversationId: @json($conversationId), abortController: null, @@ -388,6 +343,12 @@ function chatStream() { } } } + + // Process any remaining data left in the buffer after stream ends + if (buffer.startsWith('data: ')) { + const data = JSON.parse(buffer.slice(6)); + this.handleEvent(data); + } } catch (error) { // Check if this was a user-initiated cancellation if (error.name === 'AbortError') { @@ -419,6 +380,17 @@ function chatStream() { this.isFinishing = false; this.streamedContent = ''; this.abortController = null; + + // Auto-execute queries on the newly rendered message + if (this.pendingMessageId) { + const messageId = this.pendingMessageId; + this.pendingMessageId = null; + this.$nextTick(() => { + window.dispatchEvent(new CustomEvent('auto-execute-queries', { + detail: { messageId: messageId } + })); + }); + } }, cancelStream() { @@ -449,11 +421,16 @@ function chatStream() { // Error event this.streamedContent = 'Error: ' + data.message; this.renderContent(); - } else if (data.truncated) { - // Done event with truncation — model hit max_tokens - this.streamedContent += '\n\n> **Warning:** The response was cut short because the model reached its token limit. You can increase `SQL_AGENT_LLM_MAX_TOKENS` in your configuration.'; - this.renderContent(); - this.scrollToBottom(); + } else if (data.queryCount !== undefined) { + // Done event + if (data.queryCount > 0 && data.messageId) { + this.pendingMessageId = data.messageId; + } + if (data.truncated) { + this.streamedContent += '\n\n> **Warning:** The response was cut short because the model reached its token limit. You can increase `SQL_AGENT_LLM_MAX_TOKENS` in your configuration.'; + this.renderContent(); + this.scrollToBottom(); + } } }, diff --git a/routes/web.php b/routes/web.php index 091bc8c..533a903 100644 --- a/routes/web.php +++ b/routes/web.php @@ -1,7 +1,7 @@ name('stream'); - // Export endpoints - Route::get('/export/{conversation}/json', [ExportController::class, 'json'])->name('export.json'); - Route::get('/export/{conversation}/csv', [ExportController::class, 'csv'])->name('export.csv'); + // On-demand query execution + Route::post('/query/execute', QueryController::class)->name('query.execute'); }); } diff --git a/src/Agent/SqlAgent.php b/src/Agent/SqlAgent.php index cef0706..c58f242 100644 --- a/src/Agent/SqlAgent.php +++ b/src/Agent/SqlAgent.php @@ -30,6 +30,8 @@ class SqlAgent implements Agent protected ?array $lastResults = null; + protected array $allQueries = []; + protected array $iterations = []; protected ?string $currentQuestion = null; @@ -185,6 +187,11 @@ public function getLastResults(): ?array return $this->lastResults; } + public function getAllQueries(): array + { + return $this->allQueries; + } + public function getIterations(): array { return $this->iterations; @@ -231,6 +238,7 @@ protected function syncFromRunSqlTool(array $tools): void if ($tool instanceof RunSqlTool) { $this->lastSql = $tool->lastSql; $this->lastResults = $tool->lastResults; + $this->allQueries = $tool->executedQueries; return; } @@ -319,6 +327,7 @@ protected function reset(): void { $this->lastSql = null; $this->lastResults = null; + $this->allQueries = []; $this->iterations = []; $this->currentQuestion = null; $this->lastPrompt = null; @@ -329,6 +338,7 @@ protected function reset(): void if ($tool instanceof RunSqlTool) { $tool->lastSql = null; $tool->lastResults = null; + $tool->executedQueries = []; } } } diff --git a/src/Http/Actions/ExportConversationCsv.php b/src/Http/Actions/ExportConversationCsv.php deleted file mode 100644 index 9284057..0000000 --- a/src/Http/Actions/ExportConversationCsv.php +++ /dev/null @@ -1,50 +0,0 @@ -id, - now()->format('Y-m-d-His') - ); - - return response()->streamDownload(function () use ($conversation) { - $handle = fopen('php://output', 'w'); - - // Write CSV header - fputcsv($handle, [ - 'Message ID', - 'Role', - 'Content', - 'SQL', - 'Result Count', - 'Created At', - ]); - - // Write messages - foreach ($conversation->messages as $message) { - fputcsv($handle, [ - $message->id, - $message->role->value, - $message->content, - $message->sql ?? '', - $message->results ? count($message->results) : 0, - $message->created_at->toIso8601String(), - ]); - } - - fclose($handle); - }, $filename, [ - 'Content-Type' => 'text/csv', - ]); - } -} diff --git a/src/Http/Actions/ExportConversationJson.php b/src/Http/Actions/ExportConversationJson.php deleted file mode 100644 index 6517c51..0000000 --- a/src/Http/Actions/ExportConversationJson.php +++ /dev/null @@ -1,47 +0,0 @@ - $conversation->id, - 'title' => $conversation->title, - 'connection' => $conversation->getAttribute('connection'), - 'created_at' => $conversation->created_at->toIso8601String(), - 'updated_at' => $conversation->updated_at->toIso8601String(), - 'messages' => $conversation->messages->map(function ($message) { - return [ - 'id' => $message->id, - 'role' => $message->role->value, - 'content' => $message->content, - 'sql' => $message->sql, - 'results' => $message->results, - 'created_at' => $message->created_at->toIso8601String(), - ]; - })->toArray(), - ]; - - $filename = sprintf( - 'conversation-%d-%s.json', - $conversation->id, - now()->format('Y-m-d-His') - ); - - return response( - json_encode($data, JSON_PRETTY_PRINT | JSON_UNESCAPED_UNICODE), - 200, - [ - 'Content-Type' => 'application/json', - 'Content-Disposition' => sprintf('attachment; filename="%s"', $filename), - ] - ); - } -} diff --git a/src/Http/Actions/StreamAgentResponse.php b/src/Http/Actions/StreamAgentResponse.php index 27933d1..c94ea29 100644 --- a/src/Http/Actions/StreamAgentResponse.php +++ b/src/Http/Actions/StreamAgentResponse.php @@ -85,8 +85,7 @@ protected function persistAndFinish( ?array $usage = null, bool $truncated = false, ): void { - $lastSql = $this->agent->getLastSql(); - $lastResults = $this->agent->getLastResults(); + $allQueries = $this->agent->getAllQueries(); $metadata = []; if ($debugEnabled) { @@ -107,19 +106,17 @@ protected function persistAndFinish( $metadata['truncated'] = true; } - $this->conversationService->addMessage( + $message = $this->conversationService->addMessage( $conversationId, MessageRole::Assistant, $fullContent, - $lastSql, - $lastResults, + ! empty($allQueries) ? $allQueries : null, $metadata ?: null, ); $donePayload = [ - 'sql' => $lastSql, - 'hasResults' => ! empty($lastResults), - 'resultCount' => $lastResults ? count($lastResults) : 0, + 'queryCount' => count($allQueries), + 'messageId' => $message->id, ]; if ($usage !== null) { $donePayload['usage'] = $usage; diff --git a/src/Http/Controllers/ExportController.php b/src/Http/Controllers/ExportController.php deleted file mode 100644 index 1fb4065..0000000 --- a/src/Http/Controllers/ExportController.php +++ /dev/null @@ -1,38 +0,0 @@ -findForCurrentUserWithMessages($conversation); - - if (! $conv) { - abort(404, 'Conversation not found'); - } - - return $action($conv); - } - - public function csv(Request $request, int $conversation, ConversationService $conversationService, ExportConversationCsv $action): StreamedResponse - { - $conv = $conversationService->findForCurrentUserWithMessages($conversation); - - if (! $conv) { - abort(404, 'Conversation not found'); - } - - return $action($conv); - } -} diff --git a/src/Http/Controllers/QueryController.php b/src/Http/Controllers/QueryController.php new file mode 100644 index 0000000..cdc1145 --- /dev/null +++ b/src/Http/Controllers/QueryController.php @@ -0,0 +1,72 @@ +getMessageId()); + + // Verify the message's conversation belongs to the current user + $conversation = $conversationService->findForCurrentUser($message->conversation_id); + if (! $conversation) { + return response()->json(['message' => 'Message not found.'], 404); + } + + $queries = $message->getQueries(); + $queryIndex = $request->getQueryIndex(); + + if (! isset($queries[$queryIndex])) { + return response()->json(['message' => 'Query index out of range.'], 422); + } + + $query = $queries[$queryIndex]; + $sql = trim($query['sql']); + $connectionName = $query['connection'] ?? null; + + try { + $sqlValidator->validate($sql, $connectionName); + } catch (RuntimeException $e) { + return response()->json(['message' => $e->getMessage()], 422); + } + + $resolvedConnection = $connectionRegistry->resolveConnection($connectionName); + $maxRows = config('sql-agent.sql.max_rows'); + + try { + $results = DB::connection($resolvedConnection)->select($sql); + } catch (Throwable $e) { + return response()->json(['message' => $e->getMessage()], 422); + } + + $rows = array_map(fn ($row) => (array) $row, $results); + + $totalRows = count($rows); + $rows = array_slice($rows, 0, $maxRows); + + return response()->json([ + 'rows' => $rows, + 'row_count' => count($rows), + 'total_rows' => $totalRows, + 'truncated' => $totalRows > $maxRows, + ]); + } +} diff --git a/src/Http/Requests/ExecuteQueryRequest.php b/src/Http/Requests/ExecuteQueryRequest.php new file mode 100644 index 0000000..3b42a8d --- /dev/null +++ b/src/Http/Requests/ExecuteQueryRequest.php @@ -0,0 +1,36 @@ + + */ + public function rules(): array + { + return [ + 'message_id' => 'required|integer|exists:sql_agent_messages,id', + 'query_index' => 'required|integer|min:0', + ]; + } + + public function getMessageId(): int + { + return (int) $this->input('message_id'); + } + + public function getQueryIndex(): int + { + return (int) $this->input('query_index'); + } +} diff --git a/src/Models/Message.php b/src/Models/Message.php index e6cacb5..c82ea21 100644 --- a/src/Models/Message.php +++ b/src/Models/Message.php @@ -14,8 +14,7 @@ * @property int $conversation_id * @property MessageRole $role * @property string $content - * @property string|null $sql - * @property array|null $results + * @property array|null $queries * @property array|null $metadata * @property array|null $usage * @property Carbon $created_at @@ -31,8 +30,7 @@ class Message extends Model 'conversation_id', 'role', 'content', - 'sql', - 'results', + 'queries', 'metadata', ]; @@ -40,7 +38,7 @@ protected function casts(): array { return [ 'role' => MessageRole::class, - 'results' => 'array', + 'queries' => 'array', 'metadata' => 'array', ]; } @@ -65,9 +63,9 @@ public function scopeFromAssistant($query) return $query->ofRole(MessageRole::Assistant); } - public function scopeWithSql($query) + public function scopeWithQueries($query) { - return $query->whereNotNull('sql'); + return $query->whereNotNull('queries'); } public function isFromUser(): bool @@ -90,19 +88,14 @@ public function isTool(): bool return $this->role === MessageRole::Tool; } - public function hasSql(): bool + public function hasQueries(): bool { - return ! empty($this->sql); + return ! empty($this->queries); } - public function hasResults(): bool + public function getQueries(): array { - return ! empty($this->results); - } - - public function getResultCount(): int - { - return count($this->results ?? []); + return $this->queries ?? []; } public function getToolName(): ?string diff --git a/src/Services/ConversationService.php b/src/Services/ConversationService.php index 55aeed5..d1af993 100644 --- a/src/Services/ConversationService.php +++ b/src/Services/ConversationService.php @@ -85,16 +85,14 @@ public function addMessage( int $conversationId, MessageRole $role, string $content, - ?string $sql = null, - ?array $results = null, + ?array $queries = null, ?array $metadata = null, ): Message { return Message::create([ 'conversation_id' => $conversationId, 'role' => $role, 'content' => $content, - 'sql' => $sql, - 'results' => $results, + 'queries' => $queries, 'metadata' => $metadata, ]); } diff --git a/src/Services/SqlValidator.php b/src/Services/SqlValidator.php new file mode 100644 index 0000000..add8ad4 --- /dev/null +++ b/src/Services/SqlValidator.php @@ -0,0 +1,86 @@ + 1) { + throw new RuntimeException('Multiple SQL statements are not allowed.'); + } + + $this->validateTableAccess($withoutStrings, $connectionName); + } + + protected function validateTableAccess(string $sql, ?string $connectionName = null): void + { + $tables = $this->extractTableNames($sql); + + foreach ($tables as $table) { + if (! $this->tableAccessControl->isTableAllowed($table, $connectionName)) { + throw new RuntimeException( + "Access denied: table '{$table}' is restricted and cannot be queried." + ); + } + } + } + + /** + * @return array + */ + protected function extractTableNames(string $sql): array + { + $tables = []; + + $pattern = '/\b(?:FROM|JOIN|INTO|UPDATE)\s+([`\[\"]?)(\w+(?:\.\w+)?)\1/i'; + if (preg_match_all($pattern, $sql, $matches)) { + foreach ($matches[2] as $match) { + $parts = explode('.', $match); + $tables[] = end($parts); + } + } + + return array_unique($tables); + } +} diff --git a/src/Tools/RunSqlTool.php b/src/Tools/RunSqlTool.php index 0674819..9e16746 100644 --- a/src/Tools/RunSqlTool.php +++ b/src/Tools/RunSqlTool.php @@ -7,7 +7,7 @@ use Illuminate\Support\Facades\DB; use Knobik\SqlAgent\Events\SqlErrorOccurred; use Knobik\SqlAgent\Services\ConnectionRegistry; -use Knobik\SqlAgent\Services\TableAccessControl; +use Knobik\SqlAgent\Services\SqlValidator; use Prism\Prism\Tool; use RuntimeException; use Throwable; @@ -16,7 +16,7 @@ class RunSqlTool extends Tool { protected ?string $question = null; - protected TableAccessControl $tableAccessControl; + protected SqlValidator $sqlValidator; protected ConnectionRegistry $connectionRegistry; @@ -24,9 +24,11 @@ class RunSqlTool extends Tool public ?array $lastResults = null; + public array $executedQueries = []; + public function __construct() { - $this->tableAccessControl = app(TableAccessControl::class); + $this->sqlValidator = app(SqlValidator::class); $this->connectionRegistry = app(ConnectionRegistry::class); $allowed = implode(', ', config('sql-agent.sql.allowed_statements')); @@ -78,6 +80,10 @@ public function __invoke(string $sql, ?string $connection = null): string $this->lastSql = $sql; $this->lastResults = $rows; + $this->executedQueries[] = [ + 'sql' => $sql, + 'connection' => $connection, + ]; return json_encode([ 'rows' => $rows, @@ -106,81 +112,6 @@ protected function resolveConnection(?string $logicalName): ?string protected function validateSql(string $sql, ?string $connectionName = null): void { - $sqlUpper = strtoupper(trim($sql)); - - $allowedStatements = config('sql-agent.sql.allowed_statements'); - $startsWithAllowed = false; - - foreach ($allowedStatements as $statement) { - if (str_starts_with($sqlUpper, $statement)) { - $startsWithAllowed = true; - break; - } - } - - if (! $startsWithAllowed) { - throw new RuntimeException( - 'Only '.implode(' and ', $allowedStatements).' statements are allowed.' - ); - } - - $forbiddenKeywords = config('sql-agent.sql.forbidden_keywords'); - - foreach ($forbiddenKeywords as $keyword) { - $pattern = '/\b'.preg_quote($keyword, '/').'\b/i'; - if (preg_match($pattern, $sql)) { - throw new RuntimeException( - "Forbidden SQL keyword detected: {$keyword}. This query cannot be executed." - ); - } - } - - $withoutStrings = preg_replace("/'[^']*'/", '', $sql); - $withoutStrings = preg_replace('/"[^"]*"/', '', $withoutStrings); - - if (substr_count($withoutStrings, ';') > 1) { - throw new RuntimeException('Multiple SQL statements are not allowed.'); - } - - $this->validateTableAccess($withoutStrings, $connectionName); - } - - /** - * Extract table names from SQL and validate access. - */ - protected function validateTableAccess(string $sql, ?string $connectionName = null): void - { - $tables = $this->extractTableNames($sql); - - foreach ($tables as $table) { - if (! $this->tableAccessControl->isTableAllowed($table, $connectionName)) { - throw new RuntimeException( - "Access denied: table '{$table}' is restricted and cannot be queried." - ); - } - } - } - - /** - * Extract table names from SQL (best-effort regex). - * - * @return array - */ - protected function extractTableNames(string $sql): array - { - $tables = []; - - // Match FROM table, JOIN table, INTO table, UPDATE table patterns - // Handles optional schema prefix (schema.table) and backtick/bracket quoting - $pattern = '/\b(?:FROM|JOIN|INTO|UPDATE)\s+([`\[\"]?)(\w+(?:\.\w+)?)\1/i'; - if (preg_match_all($pattern, $sql, $matches)) { - foreach ($matches[2] as $match) { - // Strip schema prefix if present - $parts = explode('.', $match); - $tables[] = end($parts); - } - } - - return array_unique($tables); + $this->sqlValidator->validate($sql, $connectionName); } } diff --git a/tests/Feature/Http/QueryControllerTest.php b/tests/Feature/Http/QueryControllerTest.php new file mode 100644 index 0000000..fbad9c2 --- /dev/null +++ b/tests/Feature/Http/QueryControllerTest.php @@ -0,0 +1,149 @@ +artisan('migrate'); + config()->set('sql-agent.user.enabled', true); + + DB::statement('CREATE TABLE IF NOT EXISTS test_users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)'); + DB::table('test_users')->insert([ + ['id' => 1, 'name' => 'John Doe', 'email' => 'john@example.com'], + ['id' => 2, 'name' => 'Jane Smith', 'email' => 'jane@example.com'], + ]); + + $this->user = Helpers::createAuthenticatedUser(); + + $conversation = Conversation::create(['user_id' => $this->user->id]); + $this->message = Message::create([ + 'conversation_id' => $conversation->id, + 'role' => MessageRole::Assistant, + 'content' => 'Here are the users.', + 'queries' => [ + ['sql' => 'SELECT * FROM test_users', 'connection' => null], + ['sql' => 'SELECT name FROM test_users WHERE id = 1', 'connection' => null], + ], + ]); +}); + +afterEach(function () { + DB::statement('DROP TABLE IF EXISTS test_users'); +}); + +describe('QueryController', function () { + it('executes a valid query by message id and query index', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => $this->message->id, + 'query_index' => 0, + ]); + + $response->assertOk(); + $response->assertJsonStructure([ + 'rows', + 'row_count', + 'total_rows', + 'truncated', + ]); + $response->assertJson([ + 'row_count' => 2, + 'total_rows' => 2, + 'truncated' => false, + ]); + }); + + it('executes a different query index', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => $this->message->id, + 'query_index' => 1, + ]); + + $response->assertOk(); + $response->assertJson([ + 'row_count' => 1, + 'total_rows' => 1, + ]); + }); + + it('rejects out-of-range query index', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => $this->message->id, + 'query_index' => 99, + ]); + + $response->assertStatus(422); + $response->assertJson(['message' => 'Query index out of range.']); + }); + + it('requires message_id and query_index parameters', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), []); + + $response->assertStatus(422); + $response->assertJsonValidationErrors(['message_id', 'query_index']); + }); + + it('rejects non-existent message id', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => 99999, + 'query_index' => 0, + ]); + + $response->assertStatus(422); + $response->assertJsonValidationErrors('message_id'); + }); + + it('returns structured error for queries that fail at execution', function () { + $conversation = Conversation::create(['user_id' => $this->user->id]); + $message = Message::create([ + 'conversation_id' => $conversation->id, + 'role' => MessageRole::Assistant, + 'content' => 'Query result.', + 'queries' => [ + ['sql' => 'SELECT * FROM nonexistent_table_xyz', 'connection' => null], + ], + ]); + + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => $message->id, + 'query_index' => 0, + ]); + + $response->assertStatus(422); + $response->assertJsonStructure(['message']); + }); + + it('rejects access to another users message', function () { + $otherUser = Helpers::createAuthenticatedUser(); + $otherConversation = Conversation::create(['user_id' => $otherUser->id]); + $otherMessage = Message::create([ + 'conversation_id' => $otherConversation->id, + 'role' => MessageRole::Assistant, + 'content' => 'Secret data.', + 'queries' => [ + ['sql' => 'SELECT * FROM test_users', 'connection' => null], + ], + ]); + + // Try to access other user's message as $this->user + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.query.execute'), [ + 'message_id' => $otherMessage->id, + 'query_index' => 0, + ]); + + $response->assertStatus(404); + $response->assertJson(['message' => 'Message not found.']); + }); +}); diff --git a/tests/Feature/Livewire/ExportControllerTest.php b/tests/Feature/Livewire/ExportControllerTest.php deleted file mode 100644 index ec53574..0000000 --- a/tests/Feature/Livewire/ExportControllerTest.php +++ /dev/null @@ -1,111 +0,0 @@ -set('sql-agent.user.enabled', true); -}); - -it('exports conversation as JSON', function () { - $user = Helpers::createAuthenticatedUser(); - - $conversation = Conversation::create([ - 'user_id' => $user->id, - 'title' => 'Test Conversation', - 'connection' => 'sqlite', - ]); - - Message::create([ - 'conversation_id' => $conversation->id, - 'role' => MessageRole::User, - 'content' => 'Hello world', - ]); - - Message::create([ - 'conversation_id' => $conversation->id, - 'role' => MessageRole::Assistant, - 'content' => 'Hi there!', - 'sql' => 'SELECT * FROM users', - 'results' => [['id' => 1, 'name' => 'John']], - ]); - - $response = $this->actingAs($user) - ->get(route('sql-agent.export.json', $conversation->id)); - - $response->assertStatus(200); - $response->assertHeader('Content-Type', 'application/json'); - - $data = json_decode($response->getContent(), true); - - expect($data['id'])->toBe($conversation->id); - expect($data['title'])->toBe('Test Conversation'); - expect($data['messages'])->toHaveCount(2); - expect($data['messages'][0]['role'])->toBe('user'); - expect($data['messages'][0]['content'])->toBe('Hello world'); - expect($data['messages'][1]['role'])->toBe('assistant'); - expect($data['messages'][1]['sql'])->toBe('SELECT * FROM users'); -}); - -it('exports conversation as CSV', function () { - $user = Helpers::createAuthenticatedUser(); - - $conversation = Conversation::create([ - 'user_id' => $user->id, - 'title' => 'Test Conversation', - 'connection' => 'sqlite', - ]); - - Message::create([ - 'conversation_id' => $conversation->id, - 'role' => MessageRole::User, - 'content' => 'Hello world', - ]); - - Message::create([ - 'conversation_id' => $conversation->id, - 'role' => MessageRole::Assistant, - 'content' => 'Response', - 'sql' => 'SELECT 1', - 'results' => [['value' => 1]], - ]); - - $response = $this->actingAs($user) - ->get(route('sql-agent.export.csv', $conversation->id)); - - $response->assertStatus(200); - // Charset case varies between PHP/Laravel versions (UTF-8 vs utf-8) - expect($response->headers->get('Content-Type'))->toMatch('/^text\/csv; charset=utf-8$/i'); -}); - -it('returns 404 for non-existent conversation', function () { - $user = Helpers::createAuthenticatedUser(); - - $response = $this->actingAs($user) - ->get(route('sql-agent.export.json', 99999)); - - $response->assertStatus(404); -}); - -it('returns 404 when accessing another users conversation', function () { - $user1 = Helpers::createAuthenticatedUser(); - $user2 = Helpers::createAuthenticatedUser(['email' => 'user2@example.com']); - - $conversation = Conversation::create([ - 'user_id' => $user2->id, - 'title' => 'Other User Conversation', - 'connection' => 'sqlite', - ]); - - $response = $this->actingAs($user1) - ->get(route('sql-agent.export.json', $conversation->id)); - - $response->assertStatus(404); -}); diff --git a/tests/Unit/ConversationServiceTest.php b/tests/Unit/ConversationServiceTest.php index 7cc1f7a..3b3b417 100644 --- a/tests/Unit/ConversationServiceTest.php +++ b/tests/Unit/ConversationServiceTest.php @@ -53,21 +53,23 @@ expect($message->conversation_id)->toBe($conversation->id); }); - test('creates an assistant message with sql and results', function () { + test('creates an assistant message with queries', function () { $conversation = Conversation::create(['connection' => 'mysql']); $service = app(ConversationService::class); + $queries = [ + ['sql' => 'SELECT * FROM users', 'connection' => null], + ]; + $message = $service->addMessage( $conversation->id, MessageRole::Assistant, 'Here are the results', - 'SELECT * FROM users', - [['id' => 1, 'name' => 'John']], + $queries, ['thinking' => 'some thoughts'], ); - expect($message->sql)->toBe('SELECT * FROM users'); - expect($message->results)->toBe([['id' => 1, 'name' => 'John']]); + expect($message->queries)->toBe($queries); expect($message->metadata)->toBe(['thinking' => 'some thoughts']); }); }); diff --git a/tests/Unit/Models/MessageTest.php b/tests/Unit/Models/MessageTest.php index 787597e..9495823 100644 --- a/tests/Unit/Models/MessageTest.php +++ b/tests/Unit/Models/MessageTest.php @@ -24,19 +24,34 @@ expect($message->isFromUser())->toBeTrue(); }); - it('can have sql and results', function () { + it('can have queries', function () { $conversation = Conversation::create([]); + $queries = [ + ['sql' => 'SELECT * FROM users', 'connection' => null], + ['sql' => 'SELECT count(*) FROM orders', 'connection' => 'analytics'], + ]; $message = Message::create([ 'conversation_id' => $conversation->id, 'role' => MessageRole::Assistant, 'content' => 'Here are the results', - 'sql' => 'SELECT * FROM users', - 'results' => [['id' => 1, 'name' => 'John']], + 'queries' => $queries, ]); - expect($message->hasSql())->toBeTrue(); - expect($message->hasResults())->toBeTrue(); - expect($message->getResultCount())->toBe(1); + expect($message->hasQueries())->toBeTrue(); + expect($message->getQueries())->toHaveCount(2); + expect($message->getQueries()[0]['sql'])->toBe('SELECT * FROM users'); + }); + + it('returns empty queries when null', function () { + $conversation = Conversation::create([]); + $message = Message::create([ + 'conversation_id' => $conversation->id, + 'role' => MessageRole::Assistant, + 'content' => 'No queries here', + ]); + + expect($message->hasQueries())->toBeFalse(); + expect($message->getQueries())->toBeEmpty(); }); it('scopes by role', function () { diff --git a/tests/Unit/Services/SqlValidatorTest.php b/tests/Unit/Services/SqlValidatorTest.php new file mode 100644 index 0000000..7ab7a7f --- /dev/null +++ b/tests/Unit/Services/SqlValidatorTest.php @@ -0,0 +1,50 @@ +validate('SELECT * FROM users'); + + expect(true)->toBeTrue(); + }); + + it('allows WITH (CTE) statements', function () { + $validator = app(SqlValidator::class); + + $validator->validate('WITH cte AS (SELECT 1) SELECT * FROM cte'); + + expect(true)->toBeTrue(); + }); + + it('rejects INSERT statements', function () { + $validator = app(SqlValidator::class); + + expect(fn () => $validator->validate("INSERT INTO users (name) VALUES ('Test')")) + ->toThrow(RuntimeException::class, 'Only'); + }); + + it('rejects DROP statements', function () { + $validator = app(SqlValidator::class); + + expect(fn () => $validator->validate('DROP TABLE users')) + ->toThrow(RuntimeException::class, 'Only'); + }); + + it('rejects multiple statements', function () { + $validator = app(SqlValidator::class); + + expect(fn () => $validator->validate('SELECT 1; DELETE FROM users')) + ->toThrow(RuntimeException::class); + }); + + it('rejects forbidden keywords', function () { + $validator = app(SqlValidator::class); + + expect(fn () => $validator->validate('SELECT * FROM users; DELETE FROM users')) + ->toThrow(RuntimeException::class); + }); +}); diff --git a/tests/Unit/Tools/RunSqlToolTest.php b/tests/Unit/Tools/RunSqlToolTest.php index a16c4ad..f401bfb 100644 --- a/tests/Unit/Tools/RunSqlToolTest.php +++ b/tests/Unit/Tools/RunSqlToolTest.php @@ -101,6 +101,18 @@ expect($tool->requiredParameters())->toContain('sql'); }); + it('accumulates executedQueries across calls', function () { + $tool = new RunSqlTool; + + $tool('SELECT * FROM test_users'); + $tool('SELECT name FROM test_users WHERE id = 1'); + + expect($tool->executedQueries)->toHaveCount(2); + expect($tool->executedQueries[0]['sql'])->toBe('SELECT * FROM test_users'); + expect($tool->executedQueries[0]['connection'])->toBeNull(); + expect($tool->executedQueries[1]['sql'])->toBe('SELECT name FROM test_users WHERE id = 1'); + }); + it('can set and get question', function () { $tool = new RunSqlTool;